diff --git a/README.md b/README.md index 86dc362..2191865 100644 --- a/README.md +++ b/README.md @@ -50,10 +50,17 @@ ana auth login --endpoint https://app.textql.com ana org show ana connector list ana chat send "show me last month's revenue" +ana update # replace the running binary with the latest release ``` Run `ana --help` or `ana --help` for command-specific flags. +`ana` checks GitHub for a newer release after each verb and prints a one-line +stderr nudge when one exists. The result is cached for 4 h by default; set +`updateCheckInterval` in `config.json` (any `time.ParseDuration`-compatible +value) to change the cadence, or `"0"` / `"disable"` to turn the check off. +`--json` suppresses the nudge so automation pipelines aren't broken. + ## Configuration `ana` stores tokens and per-profile endpoints at diff --git a/cmd/ana/CLAUDE.md b/cmd/ana/CLAUDE.md index d8f25f8..276fab8 100644 --- a/cmd/ana/CLAUDE.md +++ b/cmd/ana/CLAUDE.md @@ -4,6 +4,7 @@ The `ana` binary's main package. Pure wiring: reads global flags + config, const ## Files -- `main.go` — `main` + `run` (the testable entrypoint with injectable args/stdio/env) and the `buildVerbs`/`authDeps`/`profileDeps`/`chatDeps` adapters. Also owns `newUUID` (used for chat `cellId`s) and the projection `profileToAuthConfig` that keeps `internal/auth` from importing `internal/config`. +- `main.go` — `main` + `run` (the testable entrypoint with injectable args/stdio/env) and the `buildVerbs`/`authDeps`/`profileDeps`/`chatDeps` adapters. Also owns `newUUID` (used for chat `cellId`s), the projection `profileToAuthConfig` that keeps `internal/auth` from importing `internal/config`, and the `startNudge`/`drainNudge` helpers that run the passive update-check goroutine in parallel with `cli.Dispatch`. - `version.go` — the `version` leaf command plus the `version`/`commit`/`date` package vars that goreleaser stamps via `-ldflags "-X main.version=..."`. `--version` / `-V` is rewritten to the `version` verb up front so flag and subcommand share one rendering path. -- `main_test.go` — exercises `run` end-to-end with fakes (no live server) and asserts the verb-map shape, version banner, and all adapter closures. +- `update.go` — the `update` leaf command (`ana update`) that delegates to `internal/update.SelfUpdate` to download, verify, and replace the running binary. +- `main_test.go` — exercises `run` end-to-end with fakes (no live server) and asserts the verb-map shape, version banner, adapter closures, `startNudge` skip predicates, `drainNudge` branches, and the `update` help short-circuit. diff --git a/cmd/ana/main.go b/cmd/ana/main.go index e15e4a0..a8c70be 100644 --- a/cmd/ana/main.go +++ b/cmd/ana/main.go @@ -10,6 +10,8 @@ import ( "encoding/hex" "errors" "fmt" + "io" + "net/http" "os" "os/signal" "time" @@ -27,6 +29,7 @@ import ( "github.com/highperformance-tech/ana-cli/internal/playbook" "github.com/highperformance-tech/ana-cli/internal/profile" "github.com/highperformance-tech/ana-cli/internal/transport" + "github.com/highperformance-tech/ana-cli/internal/update" ) func main() { @@ -120,7 +123,79 @@ func run(args []string, stdio cli.IO, env func(string) string) error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) defer stop() - return cli.Dispatch(ctx, verbs, args, stdio) + // Kick the passive update-check goroutine BEFORE Dispatch so the HTTP + // round-trip overlaps the verb's work. drainNudge picks it up after + // Dispatch returns. nudgeCh is nil whenever we decide not to check at + // all (dev build, --json, disabled interval, or no cache path); that + // nil flows through drainNudge as a no-op. + nudgeCh := startNudge(env, loaded, global) + err = cli.Dispatch(ctx, verbs, args, stdio) + drainNudge(nudgeCh, 500*time.Millisecond, err, stdio.Stderr) + return err +} + +// startNudge launches the passive update-check goroutine when enabled and +// returns a buffered channel that drainNudge reads. Returns nil whenever the +// check is skipped so drainNudge can short-circuit without touching the +// channel. +// +// Skip predicates (any true disables the check): +// - version == "dev" — source checkout, no corresponding GitHub release. +// - global.JSON — automation pipeline, extra stderr line would break parsers. +// - ParseInterval reports disabled (config value "0" / "disable"). +// - CachePath fails — no XDG_CACHE_HOME and no HOME means we have nowhere +// to stash freshness state, and we refuse to re-hit the network on every +// run. +func startNudge(env func(string) string, loaded config.Config, global cli.Global) chan string { + if version == "dev" || global.JSON { + return nil + } + ttl, enabled := update.ParseInterval(loaded.UpdateCheckInterval) + if !enabled { + return nil + } + if _, err := update.CachePath(env); err != nil { + return nil + } + ch := make(chan string, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + tag, notify, _ := update.CachedCheck(ctx, update.CacheDeps{ + Env: env, + Now: time.Now, + HTTP: http.DefaultClient, + }, ttl, version) + if notify { + ch <- fmt.Sprintf("A new version of ana-cli is available: v%s → %s Run: ana update", version, tag) + } else { + ch <- "" + } + }() + return ch +} + +// drainNudge waits up to timeout for startNudge's goroutine to report. When +// it produces a non-empty message and the verb did not return a help/usage +// error, the message is written to stderr — we intentionally suppress the +// nudge on help/usage paths so help text doesn't get crowded by an upgrade +// prompt. A nil ch (check was skipped) or a timeout is a clean no-op. +func drainNudge(ch chan string, timeout time.Duration, verbErr error, stderr io.Writer) { + if ch == nil { + return + } + if errors.Is(verbErr, cli.ErrHelp) || errors.Is(verbErr, cli.ErrUsage) { + return + } + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case msg := <-ch: + if msg != "" { + fmt.Fprintln(stderr, msg) + } + case <-timer.C: + } } // buildVerbs wires every verb package's Deps against the shared transport @@ -144,6 +219,7 @@ func buildVerbs(client *transport.Client, env func(string) string, cfgPath, prof "feed": feed.New(feed.Deps{Unary: client.Unary}), "audit": audit.New(audit.Deps{Unary: client.Unary, Now: time.Now}), "version": versionCmd{}, + "update": updateCmd{deps: update.DefaultDeps()}, } } diff --git a/cmd/ana/main_test.go b/cmd/ana/main_test.go index 2619e0e..5fa507f 100644 --- a/cmd/ana/main_test.go +++ b/cmd/ana/main_test.go @@ -287,7 +287,7 @@ func TestBuildVerbs_Shape(t *testing.T) { t.Parallel() client := transport.New("https://example", func(context.Context) (string, error) { return "", nil }) verbs := buildVerbs(client, func(string) string { return "" }, "", "default", "https://example") - want := []string{"auth", "profile", "org", "connector", "chat", "dashboard", "playbook", "ontology", "feed", "audit", "version"} + want := []string{"auth", "profile", "org", "connector", "chat", "dashboard", "playbook", "ontology", "feed", "audit", "version", "update"} for _, v := range want { if _, ok := verbs[v]; !ok { t.Errorf("missing verb: %q", v) @@ -436,7 +436,7 @@ func TestVersionCmd_Help(t *testing.T) { var out bytes.Buffer stdio := cli.IO{Stdout: &out, Stderr: &bytes.Buffer{}} err := (versionCmd{}).Run(context.Background(), []string{"--help"}, stdio) - if err != cli.ErrHelp { + if !errors.Is(err, cli.ErrHelp) { t.Fatalf("err = %v, want ErrHelp", err) } if !strings.Contains(out.String(), "Print ana version") { @@ -536,6 +536,108 @@ func TestRun_LeafUsageErrorReturned(t *testing.T) { } } +// TestStartNudge_SkipConditions covers every reason startNudge returns nil: +// dev version, --json, interval disabled, no HOME/XDG. Each skip must short- +// circuit before the goroutine spawns, which we assert by the returned ch +// being nil. +func TestStartNudge_SkipConditions(t *testing.T) { + // Mutates the package-level version var — must not run in parallel with + // TestVersionCmd_PrintsBanner or TestRun_VersionFlag, both of which read + // it concurrently under -race. + prev := version + t.Cleanup(func() { version = prev }) + envNone := func(string) string { return "" } + envHome := func(k string) string { + if k == "HOME" { + return t.TempDir() + } + return "" + } + disable := "disable" + cases := []struct { + name string + version string + env func(string) string + cfg config.Config + global cli.Global + }{ + {"dev build", "dev", envHome, config.Config{}, cli.Global{}}, + {"json output", "1.0.0", envHome, config.Config{}, cli.Global{JSON: true}}, + {"disabled", "1.0.0", envHome, config.Config{UpdateCheckInterval: &disable}, cli.Global{}}, + {"no home", "1.0.0", envNone, config.Config{}, cli.Global{}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + version = tc.version + if got := startNudge(tc.env, tc.cfg, tc.global); got != nil { + t.Fatalf("expected nil channel, got %v", got) + } + }) + } +} + +// TestDrainNudge covers the four branches: nil channel, help-err suppression, +// non-empty message printed, and empty message (no print). +func TestDrainNudge(t *testing.T) { + t.Parallel() + t.Run("nil channel is a no-op", func(t *testing.T) { + var buf bytes.Buffer + drainNudge(nil, time.Millisecond, nil, &buf) + if buf.Len() != 0 { + t.Fatalf("stderr: %q", buf.String()) + } + }) + t.Run("help err suppresses", func(t *testing.T) { + ch := make(chan string, 1) + ch <- "should not print" + var buf bytes.Buffer + drainNudge(ch, time.Millisecond, cli.ErrHelp, &buf) + if buf.Len() != 0 { + t.Fatalf("stderr: %q", buf.String()) + } + }) + t.Run("message printed", func(t *testing.T) { + ch := make(chan string, 1) + ch <- "hello" + var buf bytes.Buffer + drainNudge(ch, time.Millisecond, nil, &buf) + if !strings.Contains(buf.String(), "hello") { + t.Fatalf("stderr: %q", buf.String()) + } + }) + t.Run("empty message swallowed", func(t *testing.T) { + ch := make(chan string, 1) + ch <- "" + var buf bytes.Buffer + drainNudge(ch, time.Millisecond, nil, &buf) + if buf.Len() != 0 { + t.Fatalf("stderr: %q", buf.String()) + } + }) + t.Run("timeout", func(t *testing.T) { + ch := make(chan string) // no sender + var buf bytes.Buffer + drainNudge(ch, 10*time.Millisecond, nil, &buf) + if buf.Len() != 0 { + t.Fatalf("stderr: %q", buf.String()) + } + }) +} + +// TestUpdateCmd_Help short-circuits on --help like every other leaf verb. +func TestUpdateCmd_Help(t *testing.T) { + t.Parallel() + var out bytes.Buffer + stdio := cli.IO{Stdout: &out, Stderr: &bytes.Buffer{}} + err := (updateCmd{}).Run(context.Background(), []string{"--help"}, stdio) + if !errors.Is(err, cli.ErrHelp) { + t.Fatalf("err = %v, want ErrHelp", err) + } + if !strings.Contains(out.String(), "latest ana release") { + t.Fatalf("help body missing: %q", out.String()) + } +} + // TestRun_UnknownProfile drives the ErrUnknownProfile branch in run: a // --profile pointing at a slot that doesn't exist (and no env fallback) // must print the canonical error to stderr and exit 1 via ErrUsage. diff --git a/cmd/ana/update.go b/cmd/ana/update.go new file mode 100644 index 0000000..113fad2 --- /dev/null +++ b/cmd/ana/update.go @@ -0,0 +1,35 @@ +package main + +import ( + "context" + "errors" + "fmt" + + "github.com/highperformance-tech/ana-cli/internal/cli" + "github.com/highperformance-tech/ana-cli/internal/update" +) + +// updateCmd implements `ana update`: fetches the matching release archive, +// verifies its sha256, and atomically replaces the running binary. Mirrors +// versionCmd's leaf shape — deps are pulled in via the package-level default +// so cmd/ana keeps its "pure wiring" posture. +type updateCmd struct { + deps update.Deps +} + +func (updateCmd) Help() string { + return "Download and install the latest ana release." +} + +func (c updateCmd) Run(ctx context.Context, args []string, stdio cli.IO) error { + if len(args) > 0 && cli.IsHelpArg(args[0]) { + fmt.Fprintln(stdio.Stdout, updateCmd{}.Help()) + return cli.ErrHelp + } + jsonOut := cli.GlobalFrom(ctx).JSON + if err := update.SelfUpdate(ctx, c.deps, version, stdio.Stdout, jsonOut); err != nil { + fmt.Fprintln(stdio.Stderr, err) + return errors.Join(err, cli.ErrReported) + } + return nil +} diff --git a/cmd/ana/version.go b/cmd/ana/version.go index 4982d7d..9255f1f 100644 --- a/cmd/ana/version.go +++ b/cmd/ana/version.go @@ -25,7 +25,7 @@ func (versionCmd) Help() string { } func (versionCmd) Run(_ context.Context, args []string, stdio cli.IO) error { - if len(args) > 0 && (args[0] == "-h" || args[0] == "--help" || args[0] == "help") { + if len(args) > 0 && cli.IsHelpArg(args[0]) { fmt.Fprintln(stdio.Stdout, versionCmd{}.Help()) return cli.ErrHelp } diff --git a/docs/cli-readiness.md b/docs/cli-readiness.md index 455edd7..044692d 100644 --- a/docs/cli-readiness.md +++ b/docs/cli-readiness.md @@ -32,6 +32,7 @@ Confidence key: ✅ full CRUD verified · 🟡 partial / readonly verified · | Packages | 🟡 | List only; what "install/uninstall" looks like is unknown. | | Notifications | 🟡 | Streaming envelope not captured (`StreamNotifications`). | | Feed | 🟡 | Same — `StreamFeed` not captured. | +| Self-update | ✅ | Passive check after every verb (4h cache, `--json` suppresses); `ana update` downloads + sha256-verifies + atomically replaces the running binary from the matching GoReleaser archive. | ## Enum catalog (incomplete but useful) @@ -135,6 +136,9 @@ ana dashboard list / get / spawn / health ana ontology list / get ana audit tail # poll ListAuditLogs + +ana update # replace the running binary with latest release +ana version # banner + build metadata ``` Anything beyond this (`ana dashboard create`, `ana playbook schedule`, etc.) needs a fresh probe — the RPCs exist but their request shapes are not in the catalog yet. diff --git a/e2e/harness/CLAUDE.md b/e2e/harness/CLAUDE.md index bbedcee..798ef0f 100644 --- a/e2e/harness/CLAUDE.md +++ b/e2e/harness/CLAUDE.md @@ -4,7 +4,7 @@ Per-test scaffolding for live smoke tests against a real TextQL endpoint. Duplic ## Files -- `harness.go` — `H`, `Begin`, `End`. Per-test lifecycle with temp config, auth env, verb map, and cleanup stack. +- `harness.go` — `H`, `Begin`, `End`. Per-test lifecycle with temp config, auth env, verb map, and cleanup stack. Exposes `ExpectOrgID()` and `Endpoint()` so tests that assert endpoint/org-referencing stdout (e.g. OAuth callback URLs) read validated values instead of re-querying the env. - `client.go` — mirrors `cmd/ana/main.go`'s verb builder so harness and binary share the same wiring shape. - `guard.go` — wraps mutating RPCs: records them on the ledger before invoking, aborts if the pre-flight guard fails (wrong org, missing env, etc.). - `ledger.go` — `ManualRevertLog` + `Record`/`Close`. Writes any unreverted mutation using `e2e/testdata/manual-revert.template.md`. diff --git a/internal/CLAUDE.md b/internal/CLAUDE.md index f819291..fac24bb 100644 --- a/internal/CLAUDE.md +++ b/internal/CLAUDE.md @@ -24,3 +24,4 @@ Multi-file verb packages use one `_test.go` per source file (e.g. `list. | `ontology/` | `ana ontology` — readonly list/get. | | `feed/` | `ana feed` — show + stats. | | `audit/` | `ana audit tail` — audit-log listing with `--since`. Injectable clock. | +| `update/` | Passive update-check nudge + `ana update` self-update verb. Stdlib-only; 100% covered. | diff --git a/internal/cli/CLAUDE.md b/internal/cli/CLAUDE.md index 1990e80..36d5be5 100644 --- a/internal/cli/CLAUDE.md +++ b/internal/cli/CLAUDE.md @@ -4,7 +4,7 @@ Argument-dispatch core shared by every verb. Defines the `Command` interface, th ## Files -- `cli.go` — `Command`, `IO`, `DefaultIO`, `Group` (nested-verb dispatcher with auto-generated help listing; optional `Flags` closure declares group-level flags that descend to every leaf via `WithAncestorFlags`), `dispatchChild` (the Group/Dispatch handoff that scans a resolved leaf's args for `--help`/`-h` and renders its `Help()` before `Run` is called — and if the leaf implements `Flagger`, appends a `Flags:` block enumerating own + ancestor flags via `renderFlagsAsText`; Groups are skipped so the flag reaches the deepest leaf, and bare positional `help` is left alone so leaves can receive it as an argument), `Flagger` (opt-in interface: leaves that implement `Flags(fs)` get the ancestor-aware `Flags:` block in `--help`), and `renderFlagsAsText` (sorted `--name usage (default: X)` enumeration). Precedence when ancestor and leaf declare the same name: **leaf wins**, because leaves call `ApplyAncestorFlags` AFTER declaring their own flags, and ancestor registrars use `DeclareString` / `DeclareBool` Lookup-guards that skip already-declared names (stdlib `flag.FlagSet.StringVar` panics on duplicate names, verified empirically). +- `cli.go` — `Command`, `IO`, `DefaultIO`, `Group` (nested-verb dispatcher with auto-generated help listing; optional `Flags` closure declares group-level flags that descend to every leaf via `WithAncestorFlags`), `dispatchChild` (the Group/Dispatch handoff that scans a resolved leaf's args for `--help`/`-h` and renders its `Help()` before `Run` is called — and if the leaf implements `Flagger`, appends a `Flags:` block enumerating own + ancestor flags via `renderFlagsAsText`; Groups are skipped so the flag reaches the deepest leaf, and bare positional `help` is left alone so leaves can receive it as an argument), `Flagger` (opt-in interface: leaves that implement `Flags(fs)` get the ancestor-aware `Flags:` block in `--help`), `IsHelpArg` (exported helper `cmd/ana` leaves reuse so the `-h`/`--help`/`help` check lives in one place), and `renderFlagsAsText` (sorted `--name usage (default: X)` enumeration). Precedence when ancestor and leaf declare the same name: **leaf wins**, because leaves call `ApplyAncestorFlags` AFTER declaring their own flags, and ancestor registrars use `DeclareString` / `DeclareBool` Lookup-guards that skip already-declared names (stdlib `flag.FlagSet.StringVar` panics on duplicate names, verified empirically). - `dispatch.go` — `Dispatch` (root entry: short-circuits help, parses globals, routes to the matching verb via `dispatchChild`) and `RootHelp`. - `root.go` — `Global` shape, `WithGlobal`/`GlobalFrom` context helpers (both require a non-nil ctx per stdlib `context.WithValue` convention — nil panics), `ParseGlobal` (stdlib-style front-anchored parse: stops at the first positional), and `StripGlobals` (position-tolerant: walks argv once, consumes known global flags wherever they appear, passes everything else through in order so the leaf's FlagSet reports unknown-flag errors). The two share the authoritative `globalFlagRegistry` list — `TestGlobalFlagsRegistrySync` enforces that the registry matches `ParseGlobal`'s FlagSet shape. `Dispatch` uses `StripGlobals`; `cmd/ana/main.go`'s early config-resolution pre-pass uses it too so `ana org show --profile prod` honours `--profile` even when it's placed after the verb. `globalFlagsHelp` renders the canonical `Global Flags:` block that both `RootHelp` and the leaf `--help` path append so `--json`/`--endpoint`/`--token-file`/`--profile` are discoverable from every help surface. Phase 2 flag-registrar stack: `WithAncestorFlags(ctx, reg)` / `ApplyAncestorFlags(ctx, fs)` (context-carried slice of `func(*flag.FlagSet)` closures that `Group.Run` appends to and leaves replay on their own `FlagSet`), plus `DeclareString` / `DeclareBool` / `DeclareInt` (Lookup-guarded wrappers ancestor closures use instead of raw `StringVar` / `BoolVar` / `IntVar` to avoid the stdlib redeclaration panic when a leaf already declared the same name). - `flags.go` — `ParseFlags`, which tolerates positional args interleaved with flags (stdlib `FlagSet.Parse` stops at the first non-flag, silently dropping later flags); `FlagWasSet`, the `fs.Visit` wrapper partial-update verbs use to tell "user left this alone" from "user explicitly passed the zero value"; `RequireFlags`, which emits a single sorted `missing required flags: --a, --b` usage error for any name not explicitly set on fs; and three typed `flag.Value` constructors — `EnumFlag` (allow-list validation at parse time), `IntListFlag` (CSV → `[]int` with whitespace tolerance), `SinceFlag` (accepts non-negative `time.ParseDuration` or RFC3339, stored UTC via an injected clock). The stdlib `flag.Parse` re-wraps `Set` errors with `%v`, so the `ErrUsage` chain survives only through the outer `ParseFlags` wrap — tests that exercise these helpers must go through `ParseFlags`, not bare `fs.Parse`. diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 949d170..5601c98 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -69,7 +69,7 @@ type Group struct { // before delegating so every descendant leaf can ApplyAncestorFlags and pick // up the group's declared flags. func (g *Group) Run(ctx context.Context, args []string, stdio IO) error { - if len(args) == 0 || isHelpArg(args[0]) { + if len(args) == 0 || IsHelpArg(args[0]) { fmt.Fprintln(stdio.Stdout, g.Help()) return ErrHelp } @@ -207,8 +207,8 @@ func (g *Group) Help() string { return strings.TrimRight(b.String(), "\n") } -// isHelpArg reports whether s is one of the recognized help tokens. -func isHelpArg(s string) bool { +// IsHelpArg reports whether s is one of the recognized help tokens. +func IsHelpArg(s string) bool { return s == "-h" || s == "--help" || s == "help" } diff --git a/internal/cli/dispatch.go b/internal/cli/dispatch.go index 3dcb2af..8ed4e74 100644 --- a/internal/cli/dispatch.go +++ b/internal/cli/dispatch.go @@ -14,7 +14,7 @@ import ( func Dispatch(ctx context.Context, verbs map[string]Command, args []string, stdio IO) error { // A bare help token anywhere up front short-circuits flag parsing so users // can discover commands without first fixing any flag validation errors. - if len(args) > 0 && isHelpArg(args[0]) { + if len(args) > 0 && IsHelpArg(args[0]) { RootHelp(stdio.Stdout, verbs) return ErrHelp } @@ -33,7 +33,7 @@ func Dispatch(ctx context.Context, verbs map[string]Command, args []string, stdi RootHelp(stdio.Stdout, verbs) return ErrHelp } - if isHelpArg(rest[0]) { + if IsHelpArg(rest[0]) { RootHelp(stdio.Stdout, verbs) return ErrHelp } diff --git a/internal/config/CLAUDE.md b/internal/config/CLAUDE.md index 2951bc2..6c4fd10 100644 --- a/internal/config/CLAUDE.md +++ b/internal/config/CLAUDE.md @@ -4,5 +4,5 @@ Reads and writes the ana CLI config file at `$XDG_CONFIG_HOME/ana/config.json` ( ## Files -- `config.go` — `Profile`, `Config`, `DefaultPath`, `Load`, `Save`, `ActiveProfile`, `Upsert`, `Remove`, and `Resolve` (the endpoint + token + profile-name resolver that merges env fallbacks with the loaded config). Owns `DefaultEndpoint` and `ErrUnknownProfile`. +- `config.go` — `Profile`, `Config`, `DefaultPath`, `Load`, `Save`, `ActiveProfile`, `Upsert`, `Remove`, and `Resolve` (the endpoint + token + profile-name resolver that merges env fallbacks with the loaded config). Owns `DefaultEndpoint` and `ErrUnknownProfile`. Also carries the optional `UpdateCheckInterval *string` pointer (interpreted by `internal/update.ParseInterval`) — `omitempty` so existing config files stay unchanged. - `config_test.go` — covers the full round-trip (Load/Save/Upsert) in a `t.TempDir`, profile resolution precedence (flag > env > config > default), and the `ErrUnknownProfile` path. diff --git a/internal/config/config.go b/internal/config/config.go index ab368f5..5fcb70a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -38,9 +38,15 @@ type Profile struct { // Config is the persisted CLI configuration. Profiles maps profile name to // its Profile. Active names the profile selected by default. +// +// UpdateCheckInterval controls the passive self-update nudge cadence (see +// internal/update.ParseInterval). A pointer + omitempty keeps the on-disk +// shape backward-compatible: existing files never acquire the field and a +// nil value means "use the built-in default". type Config struct { - Profiles map[string]Profile `json:"profiles"` - Active string `json:"active"` + Profiles map[string]Profile `json:"profiles"` + Active string `json:"active"` + UpdateCheckInterval *string `json:"updateCheckInterval,omitempty"` } // DefaultPath returns the default path for the config file. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 28e095d..d8f52d3 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -414,6 +414,49 @@ func TestRoundTrip(t *testing.T) { } } +// TestRoundTrip_UpdateCheckInterval verifies the optional nudge-cadence +// field survives Save/Load intact AND that omitting it keeps the on-disk +// shape backward-compatible (no key written when nil). +func TestRoundTrip_UpdateCheckInterval(t *testing.T) { + t.Parallel() + dir := t.TempDir() + // With value: round-trips. + path := filepath.Join(dir, "with.json") + interval := "2h" + want := Config{ + Profiles: map[string]Profile{"default": {Endpoint: "https://e", Token: "t"}}, + Active: "default", + UpdateCheckInterval: &interval, + } + if err := Save(path, want); err != nil { + t.Fatalf("save: %v", err) + } + got, err := Load(path) + if err != nil { + t.Fatalf("load: %v", err) + } + if got.UpdateCheckInterval == nil || *got.UpdateCheckInterval != "2h" { + t.Fatalf("interval not preserved: %+v", got.UpdateCheckInterval) + } + + // Without value: key must be omitted, not serialised as "null". + nilPath := filepath.Join(dir, "without.json") + nilCfg := Config{ + Profiles: map[string]Profile{"default": {Endpoint: "https://e", Token: "t"}}, + Active: "default", + } + if err := Save(nilPath, nilCfg); err != nil { + t.Fatalf("save nil: %v", err) + } + raw, err := os.ReadFile(nilPath) + if err != nil { + t.Fatalf("read: %v", err) + } + if strings.Contains(string(raw), "updateCheckInterval") { + t.Errorf("nil field should be omitted, got: %s", raw) + } +} + func TestActiveProfile(t *testing.T) { t.Parallel() // Active set and present. diff --git a/internal/update/CLAUDE.md b/internal/update/CLAUDE.md new file mode 100644 index 0000000..9e9401b --- /dev/null +++ b/internal/update/CLAUDE.md @@ -0,0 +1,13 @@ +# internal/update + +Powers the passive "new release available" nudge and the `ana update` self-update verb. Stdlib-only by design (the module's zero-deps posture is enforced in review); mirrors the archive layout from `.goreleaser.yml` and the URL template from `install.sh`. Split across small files so each concern (semver compare, nudge check, self-update orchestration, HTTP download, archive extraction) can be read on its own. + +## Files + +- `update.go` — package doc, `HTTPDoer` interface, and the two package-level URL vars (`latestReleaseURL`, `releasesBaseURL`) that tests repoint at `httptest` servers. +- `semver.go` — `CmpSemver` + `parseSemver`. Three-int semver with optional `v` prefix and `-prerelease` suffix; prerelease sorts below release at the same X.Y.Z so `prerelease: auto` beta→stable flows notify correctly. Malformed input returns 0 (never trigger a nudge on junk). +- `check.go` — passive nudge surface: `LatestRelease`, `CacheDeps`, `CachePath` (XDG → HOME), `ParseInterval` (`nil`→4h, `"0"`/`"disable"`→off), `CachedCheck` (reads/writes `update-check.json` atomically), plus the unexported `cacheFile` / `readCache` / `writeCache` / `shouldNotify`. +- `selfupdate.go` — `Deps` + `DefaultDeps` + `resolveDeps`, the `SelfUpdate` orchestration (resolve → compare → download → verify → extract → atomic replace), the `updateStatus` JSON shape + `emitStatus` (routes JSON through `cli.WriteJSON` for byte-compat), and `atomicReplace` (Unix rename-over; Windows rename-aside + rollback). +- `download.go` — HTTP helpers (`httpGet`, `downloadFile` — streams body to disk and returns the sha256 via `TeeReader`, avoiding a re-read, `downloadBody` — 1 MiB-capped for `checksums.txt`) and `verifyChecksum` (parses goreleaser's ` ` or sha256sum's ` *` format and compares against the caller-supplied hex). +- `extract.go` — `extractBinary` dispatches on `archiveExt`; `extractFromTarGz` / `extractFromZip` walk the archive for the matching member and hand the reader to `writeBinary` (0755). +- `_test.go` — one test file per source; shared helpers (`fakeDoer`, `releaseServer`, `stageUpdate`, `wantErr`, `withURLs`, `fakeArchive`) live in `update_test.go` per the repo convention. 100% coverage gate. diff --git a/internal/update/check.go b/internal/update/check.go new file mode 100644 index 0000000..93db5b6 --- /dev/null +++ b/internal/update/check.go @@ -0,0 +1,164 @@ +package update + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// defaultCheckInterval is the cadence of the passive update nudge. 4h +// matches the issue's ask — short enough to notice a fresh release within a +// workday, long enough that CI loops don't spam api.github.com. +const defaultCheckInterval = 4 * time.Hour + +// LatestRelease returns the `tag_name` from GitHub's /releases/latest. +func LatestRelease(ctx context.Context, client HTTPDoer) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, latestReleaseURL, nil) + if err != nil { + return "", fmt.Errorf("update: build request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("update: fetch latest release: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("update: fetch latest release: status %d", resp.StatusCode) + } + var body struct { + TagName string `json:"tag_name"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return "", fmt.Errorf("update: decode latest release: %w", err) + } + if body.TagName == "" { + return "", errors.New("update: empty tag_name in response") + } + return body.TagName, nil +} + +// CacheDeps is the injection boundary for CachedCheck. +type CacheDeps struct { + Env func(string) string + Now func() time.Time + HTTP HTTPDoer +} + +// cacheFile carries the last-observed release so fresh-cache hits can +// still decide whether to nudge without another HTTP call. +type cacheFile struct { + CheckedAt time.Time `json:"checkedAt"` + LatestTag string `json:"latestTag"` +} + +// CachePath resolves $XDG_CACHE_HOME/ana/update-check.json with a +// $HOME/.cache/... fallback; both unset → error so callers skip the check. +func CachePath(env func(string) string) (string, error) { + if xdg := env("XDG_CACHE_HOME"); xdg != "" { + return filepath.Join(xdg, "ana", "update-check.json"), nil + } + if home := env("HOME"); home != "" { + return filepath.Join(home, ".cache", "ana", "update-check.json"), nil + } + return "", errors.New("update: neither XDG_CACHE_HOME nor HOME is set") +} + +// ParseInterval interprets the Config.UpdateCheckInterval pointer: nil → +// (4h, true). "0" or "disable" → (0, false). Any time.ParseDuration-friendly +// string → (d, true). A malformed duration silently falls back to the +// default — we never want a bad config value to block the user's verb. +func ParseInterval(s *string) (time.Duration, bool) { + if s == nil { + return defaultCheckInterval, true + } + v := strings.TrimSpace(*s) + if v == "0" || strings.EqualFold(v, "disable") { + return 0, false + } + d, err := time.ParseDuration(v) + if err != nil || d <= 0 { + return defaultCheckInterval, true + } + return d, true +} + +// CachedCheck returns the newest observed release tag plus a notify flag that +// is true only when currentVersion is older than that tag. When the cache is +// fresh (age < ttl) no HTTP call is made. A stale or missing cache triggers a +// LatestRelease call and an atomic cache rewrite. Any error short-circuits +// to (_, false, err); callers ignore the error since the nudge is best-effort. +func CachedCheck(ctx context.Context, deps CacheDeps, ttl time.Duration, currentVersion string) (string, bool, error) { + now := deps.Now() + path, err := CachePath(deps.Env) + if err != nil { + return "", false, err + } + if cached, ok := readCache(path); ok && now.Sub(cached.CheckedAt) < ttl { + return cached.LatestTag, shouldNotify(currentVersion, cached.LatestTag), nil + } + tag, err := LatestRelease(ctx, deps.HTTP) + if err != nil { + return "", false, err + } + if werr := writeCache(path, cacheFile{CheckedAt: now, LatestTag: tag}); werr != nil { + // Cache-write failure doesn't invalidate the tag we just fetched; + // worst case the next invocation refetches. Still surface it so + // callers can log. + return tag, shouldNotify(currentVersion, tag), werr + } + return tag, shouldNotify(currentVersion, tag), nil +} + +// shouldNotify reports true only when tag strictly exceeds currentVersion. +func shouldNotify(currentVersion, tag string) bool { + if tag == "" { + return false + } + return CmpSemver(currentVersion, strings.TrimPrefix(tag, "v")) < 0 +} + +// readCache returns ok=false on missing/corrupt — the first-run path. +func readCache(path string) (cacheFile, bool) { + data, err := os.ReadFile(path) + if err != nil { + return cacheFile{}, false + } + var c cacheFile + if err := json.Unmarshal(data, &c); err != nil { + return cacheFile{}, false + } + return c, true +} + +// writeCache atomically writes the cache file with 0700 dir + 0600 file +// perms — same pattern as config.Save so permissions stay consistent across +// ana's on-disk state. +func writeCache(path string, c cacheFile) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("update: mkdir %s: %w", dir, err) + } + // time.Time.MarshalJSON can error on years outside [0, 9999]; time.Now() + // won't trip that in practice, but handling it satisfies errchkjson and + // keeps future callers (e.g. a time-injected test) safe. + data, err := json.Marshal(c) + if err != nil { + return fmt.Errorf("update: marshal cache: %w", err) + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o600); err != nil { + return fmt.Errorf("update: write %s: %w", tmp, err) + } + if err := os.Rename(tmp, path); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("update: rename %s -> %s: %w", tmp, path, err) + } + return nil +} diff --git a/internal/update/check_test.go b/internal/update/check_test.go new file mode 100644 index 0000000..ee52df1 --- /dev/null +++ b/internal/update/check_test.go @@ -0,0 +1,351 @@ +package update + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestParseInterval(t *testing.T) { + t.Parallel() + cases := []struct { + in *string + wantD time.Duration + wantEn bool + }{ + {nil, defaultCheckInterval, true}, + {ptr("4h"), 4 * time.Hour, true}, + {ptr("30m"), 30 * time.Minute, true}, + {ptr("0"), 0, false}, + {ptr("disable"), 0, false}, + {ptr(" DISABLE "), 0, false}, + {ptr("garbage"), defaultCheckInterval, true}, + {ptr("-5m"), defaultCheckInterval, true}, + } + for _, tc := range cases { + d, en := ParseInterval(tc.in) + if d != tc.wantD || en != tc.wantEn { + t.Errorf("ParseInterval(%v) = (%v,%v), want (%v,%v)", tc.in, d, en, tc.wantD, tc.wantEn) + } + } +} + +func TestCachePath(t *testing.T) { + t.Parallel() + cases := []struct { + env map[string]string + want string + err bool + }{ + {map[string]string{"XDG_CACHE_HOME": "/xdg", "HOME": "/h"}, filepath.Join("/xdg", "ana", "update-check.json"), false}, + {map[string]string{"HOME": "/h"}, filepath.Join("/h", ".cache", "ana", "update-check.json"), false}, + {nil, "", true}, + } + for _, tc := range cases { + got, err := CachePath(mapEnv(tc.env)) + if (err != nil) != tc.err || got != tc.want { + t.Errorf("CachePath(%v) = (%q,%v); want (%q,err=%v)", tc.env, got, err, tc.want, tc.err) + } + } +} + +func TestLatestRelease(t *testing.T) { + cases := []struct { + name string + handler http.HandlerFunc + doer HTTPDoer + badBaseURL bool + wantTag string + wantErr string + }{ + { + name: "happy", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept"), "vnd.github") { + t.Errorf("missing Accept header") + } + _ = json.NewEncoder(w).Encode(map[string]string{"tag_name": "v1.2.3"}) + }, + wantTag: "v1.2.3", + }, + { + name: "non-200", + handler: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(503) }, + wantErr: "503", + }, + { + name: "bad json", + handler: func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("not json")) }, + wantErr: "decode", + }, + { + name: "empty tag", + handler: func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{"tag_name": ""}) + }, + wantErr: "empty tag_name", + }, + { + name: "do error", + doer: &fakeDoer{handler: func(*http.Request) (*http.Response, error) { return nil, errors.New("net down") }}, + wantErr: "net down", + }, + { + name: "build request error", + badBaseURL: true, + wantErr: "build request", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var client HTTPDoer = http.DefaultClient + if tc.handler != nil { + srv := httptest.NewServer(tc.handler) + defer srv.Close() + prev := latestReleaseURL + latestReleaseURL = srv.URL + defer func() { latestReleaseURL = prev }() + } + if tc.doer != nil { + client = tc.doer + } + if tc.badBaseURL { + prev := latestReleaseURL + latestReleaseURL = "http://\x7f" + defer func() { latestReleaseURL = prev }() + } + tag, err := LatestRelease(context.Background(), client) + if tc.wantErr != "" { + wantErr(t, err, tc.wantErr) + return + } + if err != nil { + t.Fatalf("err: %v", err) + } + if tag != tc.wantTag { + t.Fatalf("tag = %q, want %q", tag, tc.wantTag) + } + }) + } +} + +func TestCachedCheck(t *testing.T) { + freshServerTag := "v1.5.0" + serveFresh := func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{"tag_name": freshServerTag}) + })) + } + + t.Run("fresh hit: no HTTP, notify when behind", func(t *testing.T) { + dir := t.TempDir() + cachePath := filepath.Join(dir, "ana", "update-check.json") + if err := writeCache(cachePath, cacheFile{CheckedAt: time.Now(), LatestTag: "v2.0.0"}); err != nil { + t.Fatalf("seed: %v", err) + } + doer := &fakeDoer{handler: func(*http.Request) (*http.Response, error) { + t.Fatal("HTTP should not be called") + return nil, nil + }} + tag, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: time.Now, + HTTP: doer, + }, time.Hour, "1.0.0") + if err != nil || tag != "v2.0.0" || !notify || doer.calls != 0 { + t.Fatalf("tag=%q notify=%v calls=%d err=%v", tag, notify, doer.calls, err) + } + }) + + t.Run("stale refresh rewrites cache", func(t *testing.T) { + srv := serveFresh(t) + defer srv.Close() + prev := latestReleaseURL + latestReleaseURL = srv.URL + defer func() { latestReleaseURL = prev }() + + dir := t.TempDir() + cachePath := filepath.Join(dir, "ana", "update-check.json") + if err := writeCache(cachePath, cacheFile{CheckedAt: time.Now().Add(-time.Hour), LatestTag: "v1.0.0"}); err != nil { + t.Fatalf("seed: %v", err) + } + now := time.Now() + tag, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: func() time.Time { return now }, + HTTP: http.DefaultClient, + }, time.Second, "1.0.0") + if err != nil || tag != freshServerTag || !notify { + t.Fatalf("tag=%q notify=%v err=%v", tag, notify, err) + } + c, ok := readCache(cachePath) + if !ok || c.LatestTag != freshServerTag || !c.CheckedAt.Equal(now) { + t.Fatalf("cache not rewritten: %+v ok=%v", c, ok) + } + }) + + t.Run("first run writes cache + notifies", func(t *testing.T) { + srv := serveFresh(t) + defer srv.Close() + prev := latestReleaseURL + latestReleaseURL = srv.URL + defer func() { latestReleaseURL = prev }() + + dir := t.TempDir() + _, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: time.Now, + HTTP: http.DefaultClient, + }, time.Hour, "0.0.9") + if err != nil || !notify { + t.Fatalf("notify=%v err=%v", notify, err) + } + if _, ok := readCache(filepath.Join(dir, "ana", "update-check.json")); !ok { + t.Fatal("cache not written") + } + }) + + t.Run("current equal or ahead → no notify", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{"tag_name": "v1.0.0"}) + })) + defer srv.Close() + prev := latestReleaseURL + latestReleaseURL = srv.URL + defer func() { latestReleaseURL = prev }() + + for _, cur := range []string{"1.0.0", "2.0.0"} { + dir := t.TempDir() + _, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: time.Now, + HTTP: http.DefaultClient, + }, time.Hour, cur) + if err != nil { + t.Fatalf("err: %v", err) + } + if notify { + t.Errorf("cur=%q notify=true; want false", cur) + } + } + }) + + t.Run("empty cached tag → no notify", func(t *testing.T) { + dir := t.TempDir() + cachePath := filepath.Join(dir, "ana", "update-check.json") + if err := writeCache(cachePath, cacheFile{CheckedAt: time.Now(), LatestTag: ""}); err != nil { + t.Fatalf("seed: %v", err) + } + _, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: time.Now, + HTTP: &fakeDoer{handler: func(*http.Request) (*http.Response, error) { t.Fatal("no HTTP"); return nil, nil }}, + }, time.Hour, "1.0.0") + if err != nil || notify { + t.Fatalf("notify=%v err=%v", notify, err) + } + }) + + t.Run("path error", func(t *testing.T) { + _, _, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(nil), + Now: time.Now, + HTTP: http.DefaultClient, + }, time.Hour, "1.0.0") + wantErr(t, err, "neither XDG_CACHE_HOME") + }) + + t.Run("latest release error", func(t *testing.T) { + doer := &fakeDoer{handler: func(*http.Request) (*http.Response, error) { return nil, errors.New("offline") }} + dir := t.TempDir() + _, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: time.Now, + HTTP: doer, + }, time.Hour, "1.0.0") + wantErr(t, err, "offline") + if notify { + t.Error("notify should be false on error") + } + }) + + t.Run("write cache error still returns fetched tag", func(t *testing.T) { + dir := t.TempDir() + // Pre-create a directory at the .tmp path so WriteFile can't create the file. + cachePath := filepath.Join(dir, "ana", "update-check.json") + if err := os.MkdirAll(cachePath+".tmp", 0o700); err != nil { + t.Fatalf("seed: %v", err) + } + srv := serveFresh(t) + defer srv.Close() + prev := latestReleaseURL + latestReleaseURL = srv.URL + defer func() { latestReleaseURL = prev }() + + tag, notify, err := CachedCheck(context.Background(), CacheDeps{ + Env: mapEnv(map[string]string{"XDG_CACHE_HOME": dir}), + Now: time.Now, + HTTP: http.DefaultClient, + }, time.Hour, "1.0.0") + if err == nil { + t.Fatal("expected write error") + } + if tag != freshServerTag || !notify { + t.Fatalf("tag/notify should reflect fetched value: tag=%q notify=%v", tag, notify) + } + }) +} + +func TestReadCache_Corrupt(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "c.json") + if err := os.WriteFile(path, []byte("{not json"), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + if _, ok := readCache(path); ok { + t.Fatal("corrupt cache should yield ok=false") + } +} + +func TestWriteCache_Errors(t *testing.T) { + t.Parallel() + t.Run("marshal (unsafe time)", func(t *testing.T) { + // time.Time.MarshalJSON rejects years outside [0, 9999] — the only + // realistic way to trip json.Marshal on our string+time shape. + path := filepath.Join(t.TempDir(), "c.json") + bad := cacheFile{CheckedAt: time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)} + wantErr(t, writeCache(path, bad), "marshal cache") + }) + t.Run("mkdir", func(t *testing.T) { + dir := t.TempDir() + blocker := filepath.Join(dir, "blocker") + if err := os.WriteFile(blocker, []byte("x"), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + wantErr(t, writeCache(filepath.Join(blocker, "sub", "c.json"), cacheFile{}), "mkdir") + }) + t.Run("write", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "c.json") + if err := os.MkdirAll(path+".tmp", 0o700); err != nil { + t.Fatalf("seed: %v", err) + } + wantErr(t, writeCache(path, cacheFile{}), "write") + }) + t.Run("rename", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "c.json") + if err := os.MkdirAll(path, 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(path, "x"), []byte("x"), 0o600); err != nil { + t.Fatalf("populate: %v", err) + } + wantErr(t, writeCache(path, cacheFile{}), "rename") + }) +} diff --git a/internal/update/download.go b/internal/update/download.go new file mode 100644 index 0000000..c97773a --- /dev/null +++ b/internal/update/download.go @@ -0,0 +1,96 @@ +package update + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "strings" +) + +// httpGet performs a GET and returns the response body on 200. +func httpGet(ctx context.Context, client HTTPDoer, url string) (io.ReadCloser, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("update: build request: %w", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("update: download %s: %w", url, err) + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("update: download %s: status %d", url, resp.StatusCode) + } + return resp.Body, nil +} + +// downloadFile streams url into dst and returns the hex sha256 of the body +// (computed via TeeReader so the archive doesn't need to be re-read for +// verification). Close-on-error is wired via named return so a failed flush +// surfaces the real cause instead of producing a silently-truncated archive. +func downloadFile(ctx context.Context, client HTTPDoer, url, dst string) (sha string, err error) { + body, err := httpGet(ctx, client, url) + if err != nil { + return "", err + } + defer body.Close() + f, err := os.Create(dst) + if err != nil { + return "", fmt.Errorf("update: create %s: %w", dst, err) + } + defer func() { + if cerr := f.Close(); err == nil { + err = cerr + } + }() + h := sha256.New() + if _, err = io.Copy(f, io.TeeReader(body, h)); err != nil { + return "", fmt.Errorf("update: write %s: %w", dst, err) + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +// maxChecksumsSize caps the checksums.txt buffer so a hostile / broken CDN +// can't force us to allocate arbitrarily large memory. A goreleaser +// checksums.txt for this project is well under 4 KB in practice; 1 MiB is +// generous. +const maxChecksumsSize = 1 << 20 + +func downloadBody(ctx context.Context, client HTTPDoer, url string) ([]byte, error) { + body, err := httpGet(ctx, client, url) + if err != nil { + return nil, err + } + defer body.Close() + return io.ReadAll(io.LimitReader(body, maxChecksumsSize)) +} + +// verifyChecksum compares got against the entry for archiveName in a +// goreleaser-format checksums.txt (" " or sha256sum's +// " *" binary-mode variant). Caller computes got — typically +// via downloadFile's TeeReader output. +func verifyChecksum(got, archiveName string, checksums []byte) error { + want := "" + for _, line := range strings.Split(string(checksums), "\n") { + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + name := strings.TrimPrefix(fields[len(fields)-1], "*") + if name == archiveName { + want = fields[0] + break + } + } + if want == "" { + return fmt.Errorf("update: no checksum entry for %s", archiveName) + } + if got != want { + return fmt.Errorf("update: checksum mismatch for %s: expected %s, got %s", archiveName, want, got) + } + return nil +} diff --git a/internal/update/download_test.go b/internal/update/download_test.go new file mode 100644 index 0000000..6ed3dd1 --- /dev/null +++ b/internal/update/download_test.go @@ -0,0 +1,61 @@ +package update + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" +) + +func TestDownloadFile(t *testing.T) { + t.Parallel() + call := func(t *testing.T, doer HTTPDoer, url, dst string) error { + t.Helper() + _, err := downloadFile(context.Background(), doer, url, dst) + return err + } + t.Run("build request error", func(t *testing.T) { + wantErr(t, call(t, http.DefaultClient, "http://\x7f", filepath.Join(t.TempDir(), "x")), "build request") + }) + t.Run("do error", func(t *testing.T) { + doer := &fakeDoer{handler: func(*http.Request) (*http.Response, error) { return nil, errors.New("dial") }} + wantErr(t, call(t, doer, "http://x", filepath.Join(t.TempDir(), "x")), "dial") + }) + t.Run("create error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })) + defer srv.Close() + wantErr(t, call(t, http.DefaultClient, srv.URL, filepath.Join(t.TempDir(), "nope", "x")), "create") + }) + t.Run("copy error", func(t *testing.T) { + doer := &fakeDoer{handler: func(*http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200, Body: errBody{err: errors.New("body boom")}, Header: make(http.Header)}, nil + }} + wantErr(t, call(t, doer, "http://x", filepath.Join(t.TempDir(), "x")), "body boom") + }) +} + +func TestDownloadBody_DoError(t *testing.T) { + t.Parallel() + doer := &fakeDoer{handler: func(*http.Request) (*http.Response, error) { return nil, errors.New("dial") }} + _, err := downloadBody(context.Background(), doer, "http://x") + wantErr(t, err, "dial") +} + +func TestVerifyChecksum(t *testing.T) { + t.Parallel() + t.Run("no entry", func(t *testing.T) { + wantErr(t, verifyChecksum("deadbeef", "a.tar.gz", []byte("dead other.tar.gz\n")), "no checksum entry") + }) + t.Run("binary-mode filename prefix accepted", func(t *testing.T) { + // Reaching "checksum mismatch" (rather than "no checksum entry") + // proves the `*` prefix was stripped and the name match succeeded. + wantErr(t, verifyChecksum("cafef00d", "a.tar.gz", []byte("deadbeef *a.tar.gz\n")), "checksum mismatch") + }) + t.Run("happy path: match", func(t *testing.T) { + if err := verifyChecksum("deadbeef", "a.tar.gz", []byte("deadbeef a.tar.gz\n")); err != nil { + t.Fatalf("err: %v", err) + } + }) +} diff --git a/internal/update/extract.go b/internal/update/extract.go new file mode 100644 index 0000000..c63717a --- /dev/null +++ b/internal/update/extract.go @@ -0,0 +1,89 @@ +package update + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "errors" + "fmt" + "io" + "os" + "path/filepath" +) + +// extractBinary extracts exeName from archivePath to dst (0755). +// archiveExt is "tar.gz" (Unix) or "zip" (Windows). +func extractBinary(archivePath, archiveExt, exeName, dst string) error { + if archiveExt == "zip" { + return extractFromZip(archivePath, exeName, dst) + } + return extractFromTarGz(archivePath, exeName, dst) +} + +func extractFromTarGz(archivePath, exeName, dst string) error { + f, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("update: open %s: %w", archivePath, err) + } + defer f.Close() + gz, err := gzip.NewReader(f) + if err != nil { + return fmt.Errorf("update: gzip %s: %w", archivePath, err) + } + defer gz.Close() + tr := tar.NewReader(gz) + for { + h, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("update: tar %s: %w", archivePath, err) + } + if filepath.Base(h.Name) != exeName { + continue + } + return writeBinary(dst, tr) + } + return fmt.Errorf("update: archive %s missing %s", archivePath, exeName) +} + +func extractFromZip(archivePath, exeName, dst string) error { + zr, err := zip.OpenReader(archivePath) + if err != nil { + return fmt.Errorf("update: zip %s: %w", archivePath, err) + } + defer zr.Close() + for _, zf := range zr.File { + if filepath.Base(zf.Name) != exeName { + continue + } + rc, err := zf.Open() + if err != nil { + return fmt.Errorf("update: open %s in %s: %w", zf.Name, archivePath, err) + } + defer rc.Close() + return writeBinary(dst, rc) + } + return fmt.Errorf("update: archive %s missing %s", archivePath, exeName) +} + +// writeBinary copies src to dst with 0755 so the extracted binary is +// immediately executable. +func writeBinary(dst string, src io.Reader) (err error) { + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + return fmt.Errorf("update: create %s: %w", dst, err) + } + // Capture Close error via named return so a failed flush doesn't leave a + // silently-truncated binary staged for atomicReplace. + defer func() { + if cerr := out.Close(); err == nil { + err = cerr + } + }() + if _, err := io.Copy(out, src); err != nil { + return fmt.Errorf("update: write %s: %w", dst, err) + } + return nil +} diff --git a/internal/update/extract_test.go b/internal/update/extract_test.go new file mode 100644 index 0000000..318d047 --- /dev/null +++ b/internal/update/extract_test.go @@ -0,0 +1,117 @@ +package update + +import ( + "archive/zip" + "bytes" + "compress/gzip" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExtractFromTarGz(t *testing.T) { + t.Parallel() + t.Run("open error", func(t *testing.T) { + wantErr(t, extractFromTarGz(filepath.Join(t.TempDir(), "nope"), "ana", "/tmp/out"), "open") + }) + t.Run("broken tar inside valid gzip", func(t *testing.T) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte("garbage")) + _ = gz.Close() + path := filepath.Join(t.TempDir(), "broken.tar.gz") + if err := os.WriteFile(path, buf.Bytes(), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + wantErr(t, extractFromTarGz(path, "ana", filepath.Join(t.TempDir(), "out")), "") + }) +} + +func TestExtractFromZip(t *testing.T) { + t.Parallel() + t.Run("open error", func(t *testing.T) { + wantErr(t, extractFromZip(filepath.Join(t.TempDir(), "nope"), "ana.exe", "/tmp/out"), "zip") + }) + t.Run("unsupported method", func(t *testing.T) { + // Writer claims method 99 (registered as a nop); reader has no + // matching decompressor, so zf.Open fails on that entry. + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + zw.RegisterCompressor(99, func(w io.Writer) (io.WriteCloser, error) { + return nopWriteCloser{Writer: w}, nil + }) + fh := &zip.FileHeader{Name: "ana.exe", Method: 99} + w, err := zw.CreateHeader(fh) + if err != nil { + t.Fatalf("create header: %v", err) + } + if _, err := w.Write([]byte("payload")); err != nil { + t.Fatalf("write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("close: %v", err) + } + archive := filepath.Join(t.TempDir(), "bad.zip") + if err := os.WriteFile(archive, buf.Bytes(), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + wantErr(t, extractFromZip(archive, "ana.exe", filepath.Join(t.TempDir(), "out")), "open") + }) +} + +// TestExtract_PathTraversalSafe proves that a malicious archive entry +// named `../../evil/ana` still lands at the caller-provided `dst` and +// does NOT write outside dst's parent. Regression guard in case a future +// refactor swaps `filepath.Base` for raw `h.Name`. +func TestExtract_PathTraversalSafe(t *testing.T) { + t.Parallel() + cases := []struct { + name, ext, member string + }{ + {"tar.gz", "tar.gz", "../../../evil/ana"}, + {"zip", "zip", "../../../evil/ana.exe"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + archive := filepath.Join(dir, "a") + if err := os.WriteFile(archive, fakeArchive(t, tc.ext, tc.member, []byte("PAYLOAD")), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + dst := filepath.Join(dir, "dst") + target := filepath.Base(tc.member) // "ana" or "ana.exe" + var err error + if tc.ext == "zip" { + err = extractFromZip(archive, target, dst) + } else { + err = extractFromTarGz(archive, target, dst) + } + if err != nil { + t.Fatalf("extract: %v", err) + } + if got, _ := os.ReadFile(dst); string(got) != "PAYLOAD" { + t.Fatalf("dst content = %q", got) + } + // Nothing outside dir (the archive parent) should have been written. + entries, _ := os.ReadDir(filepath.Dir(dir)) + for _, e := range entries { + if e.Name() == "evil" { + t.Fatalf("path traversal wrote outside: %s", e.Name()) + } + } + }) + } +} + +func TestWriteBinary(t *testing.T) { + t.Parallel() + t.Run("create error", func(t *testing.T) { + wantErr(t, writeBinary(filepath.Join(t.TempDir(), "nope", "out"), strings.NewReader("x")), "create") + }) + t.Run("copy error", func(t *testing.T) { + wantErr(t, writeBinary(filepath.Join(t.TempDir(), "out"), errReader{err: errors.New("rd")}), "rd") + }) +} diff --git a/internal/update/selfupdate.go b/internal/update/selfupdate.go new file mode 100644 index 0000000..4185352 --- /dev/null +++ b/internal/update/selfupdate.go @@ -0,0 +1,176 @@ +package update + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/highperformance-tech/ana-cli/internal/cli" +) + +// Deps is the injection boundary for SelfUpdate; zero fields fall +// through to stdlib defaults via resolveDeps. +type Deps struct { + HTTP HTTPDoer + GOOS string + GOARCH string + ExePath func() (string, error) + Rename func(old, newp string) error + TempDir func() (string, error) +} + +// DefaultDeps wires Deps against the process environment. +func DefaultDeps() Deps { + return Deps{ + HTTP: http.DefaultClient, + GOOS: runtime.GOOS, + GOARCH: runtime.GOARCH, + ExePath: os.Executable, + Rename: os.Rename, + TempDir: func() (string, error) { return os.MkdirTemp("", "ana-update-*") }, + } +} + +// resolveDeps fills zero fields with stdlib defaults. +func resolveDeps(d Deps) Deps { + if d.HTTP == nil { + d.HTTP = http.DefaultClient + } + if d.GOOS == "" { + d.GOOS = runtime.GOOS + } + if d.GOARCH == "" { + d.GOARCH = runtime.GOARCH + } + if d.ExePath == nil { + d.ExePath = os.Executable + } + if d.Rename == nil { + d.Rename = os.Rename + } + if d.TempDir == nil { + d.TempDir = func() (string, error) { return os.MkdirTemp("", "ana-update-*") } + } + return d +} + +// updateStatus is the --json output shape. Named type keeps field tags +// stable for scripts that parse `ana update --json`. +type updateStatus struct { + Status string `json:"status"` + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + Archive string `json:"archive,omitempty"` +} + +// SelfUpdate runs the full self-update flow: resolve latest, compare, skip if +// already current, else download + verify + extract + atomic replace. jsonOut +// selects a single-object JSON summary on w; otherwise plain-text progress is +// written line by line. "Already up to date" is success (nil error). +func SelfUpdate(ctx context.Context, deps Deps, currentVersion string, w io.Writer, jsonOut bool) error { + deps = resolveDeps(deps) + + exe, err := deps.ExePath() + if err != nil { + return fmt.Errorf("update: locate executable: %w", err) + } + // Best-effort cleanup of a .old left by a previous Windows update. + // Missing file is expected; other errors just leave it around. + if deps.GOOS == "windows" { + _ = os.Remove(exe + ".old") + } + + tag, err := LatestRelease(ctx, deps.HTTP) + if err != nil { + return err + } + latest := strings.TrimPrefix(tag, "v") + if CmpSemver(currentVersion, latest) >= 0 { + return emitStatus(w, jsonOut, updateStatus{Status: "up-to-date", From: currentVersion}, fmt.Sprintf("ana is already at version %s\n", currentVersion)) + } + + exeName := "ana" + archiveExt := "tar.gz" + if deps.GOOS == "windows" { + exeName = "ana.exe" + archiveExt = "zip" + } + archiveName := fmt.Sprintf("ana_%s_%s_%s.%s", latest, deps.GOOS, deps.GOARCH, archiveExt) + base := releasesBaseURL + "/" + tag + + tmp, err := deps.TempDir() + if err != nil { + return fmt.Errorf("update: create tempdir: %w", err) + } + defer os.RemoveAll(tmp) + + archivePath := filepath.Join(tmp, archiveName) + sum, err := downloadFile(ctx, deps.HTTP, base+"/"+archiveName, archivePath) + if err != nil { + return err + } + checksums, err := downloadBody(ctx, deps.HTTP, base+"/checksums.txt") + if err != nil { + return err + } + if err := verifyChecksum(sum, archiveName, checksums); err != nil { + return err + } + + newBinary := filepath.Join(tmp, exeName+".new") + if err := extractBinary(archivePath, archiveExt, exeName, newBinary); err != nil { + return err + } + + if err := atomicReplace(deps, exe, newBinary); err != nil { + return err + } + return emitStatus(w, jsonOut, + updateStatus{Status: "updated", From: currentVersion, To: latest, Archive: archiveName}, + fmt.Sprintf("Updated ana %s → %s\n", currentVersion, latest), + ) +} + +// emitStatus routes through cli.WriteJSON on --json so output stays +// byte-compatible with the rest of the CLI's --json verbs. +func emitStatus(w io.Writer, jsonOut bool, st updateStatus, plain string) error { + if jsonOut { + return cli.WriteJSON(w, st) + } + if _, err := io.WriteString(w, plain); err != nil { + return fmt.Errorf("update: emit status: %w", err) + } + return nil +} + +// atomicReplace installs newPath over exePath. Unix: a single rename works +// because the old inode stays resident while the process runs. Windows: an +// open .exe cannot be renamed over, so we rename it aside first (to .old) +// and rename the replacement in. If the second rename fails, we roll the +// .old back in place and surface an error that names the recovery path. +func atomicReplace(deps Deps, exePath, newPath string) error { + if deps.GOOS != "windows" { + if err := deps.Rename(newPath, exePath); err != nil { + return fmt.Errorf("update: replace %s: %w", exePath, err) + } + return nil + } + oldPath := exePath + ".old" + if err := deps.Rename(exePath, oldPath); err != nil { + return fmt.Errorf("update: rename %s -> %s: %w", exePath, oldPath, err) + } + if err := deps.Rename(newPath, exePath); err != nil { + // Roll the old binary back; if that also fails we can't fix it for + // the user so tell them where the last-known-good copy lives. + if rbErr := deps.Rename(oldPath, exePath); rbErr != nil { + return fmt.Errorf("update: replace %s failed (%w); rollback also failed (%w); recover from %s", exePath, err, rbErr, oldPath) + } + return fmt.Errorf("update: replace %s: %w", exePath, err) + } + return nil +} diff --git a/internal/update/selfupdate_test.go b/internal/update/selfupdate_test.go new file mode 100644 index 0000000..999eec1 --- /dev/null +++ b/internal/update/selfupdate_test.go @@ -0,0 +1,321 @@ +package update + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "testing" +) + +func TestSelfUpdate_UnixHappyPath(t *testing.T) { + exePath, deps := stageUpdate(t, "linux", "amd64", "ana") + r := &releaseServer{tag: "v1.2.3", archiveName: "ana_1.2.3_linux_amd64.tar.gz"} + r.archiveBody = fakeArchive(t, "tar.gz", "ana", []byte("NEW_BINARY_BYTES")) + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + + var rnCalls int + deps.Rename = func(o, n string) error { rnCalls++; return os.Rename(o, n) } + + var out bytes.Buffer + if err := SelfUpdate(context.Background(), deps, "1.0.0", &out, false); err != nil { + t.Fatalf("SelfUpdate: %v", err) + } + if got, _ := os.ReadFile(exePath); string(got) != "NEW_BINARY_BYTES" { + t.Errorf("exe not replaced: %q", got) + } + if rnCalls != 1 { + t.Errorf("expected 1 rename on unix, got %d", rnCalls) + } + if !strings.Contains(out.String(), "Updated ana 1.0.0 → 1.2.3") { + t.Errorf("unexpected stdout: %q", out.String()) + } +} + +func TestSelfUpdate_JSONAndUpToDate(t *testing.T) { + t.Run("updated", func(t *testing.T) { + _, deps := stageUpdate(t, "linux", "amd64", "ana") + r := &releaseServer{tag: "v1.2.3", archiveName: "ana_1.2.3_linux_amd64.tar.gz"} + r.archiveBody = fakeArchive(t, "tar.gz", "ana", []byte("NEW")) + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + + var out bytes.Buffer + if err := SelfUpdate(context.Background(), deps, "1.0.0", &out, true); err != nil { + t.Fatalf("SelfUpdate: %v", err) + } + var got updateStatus + if err := json.Unmarshal(out.Bytes(), &got); err != nil { + t.Fatalf("json: %v (%q)", err, out.String()) + } + if got.Status != "updated" || got.From != "1.0.0" || got.To != "1.2.3" { + t.Errorf("status: %+v", got) + } + }) + + for _, jsonOut := range []bool{false, true} { + t.Run(fmt.Sprintf("up-to-date json=%v", jsonOut), func(t *testing.T) { + exePath, deps := stageUpdate(t, "linux", "amd64", "ana") + r := &releaseServer{tag: "v1.0.0", archiveName: "ana_1.0.0_linux_amd64.tar.gz"} + r.archiveBody = fakeArchive(t, "tar.gz", "ana", []byte("NEW")) + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + + var out bytes.Buffer + if err := SelfUpdate(context.Background(), deps, "1.0.0", &out, jsonOut); err != nil { + t.Fatalf("SelfUpdate: %v", err) + } + if jsonOut { + var got updateStatus + if err := json.Unmarshal(out.Bytes(), &got); err != nil { + t.Fatalf("json: %v", err) + } + if got.Status != "up-to-date" { + t.Errorf("status: %+v", got) + } + } else if !strings.Contains(out.String(), "already at version 1.0.0") { + t.Errorf("unexpected output: %q", out.String()) + } + if got, _ := os.ReadFile(exePath); string(got) != "old" { + t.Errorf("exe should not change: %q", got) + } + }) + } +} + +// TestSelfUpdate_SadPaths table-drives the one-axis failure modes against an +// otherwise-valid release server. Each case overrides a single field to +// inject the specific failure. +func TestSelfUpdate_SadPaths(t *testing.T) { + validTarGz := fakeArchive(t, "tar.gz", "ana", []byte("NEW")) + wrongMember := fakeArchive(t, "tar.gz", "README.md", []byte("x")) + validZip := fakeArchive(t, "zip", "ana.exe", []byte("NEW")) + zipWrongMember := fakeArchive(t, "zip", "README.md", []byte("x")) + cases := []struct { + name string + goos string + ext string + overrides func(r *releaseServer) + wantErr string + }{ + {"archive 404", "", "tar.gz", func(r *releaseServer) { r.archive404 = true }, "status 404"}, + {"checksums 404", "", "tar.gz", func(r *releaseServer) { r.checksums404 = true }, "status 404"}, + {"checksum mismatch", "", "tar.gz", func(r *releaseServer) { r.checksums = "deadbeef " + r.archiveName + "\n" }, "checksum mismatch"}, + {"no checksum entry", "", "tar.gz", func(r *releaseServer) { r.checksums = "deadbeef unrelated.tar.gz\n" }, "no checksum entry"}, + {"archive missing member (tar.gz)", "", "tar.gz", func(r *releaseServer) { r.archiveBody = wrongMember }, "missing ana"}, + {"archive missing member (zip)", "windows", "zip", func(r *releaseServer) { r.archiveBody = zipWrongMember }, "missing ana.exe"}, + {"corrupt gzip", "", "tar.gz", func(r *releaseServer) { r.archiveBody = []byte("not a gzip stream") }, "gzip"}, + {"corrupt zip", "windows", "zip", func(r *releaseServer) { r.archiveBody = []byte("not a zip either") }, "zip"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + goos := tc.goos + if goos == "" { + goos = "linux" + } + exeName := "ana" + body := validTarGz + if tc.ext == "zip" { + exeName = "ana.exe" + body = validZip + } + _, deps := stageUpdate(t, goos, "amd64", exeName) + r := &releaseServer{ + tag: "v2.0.0", + archiveName: fmt.Sprintf("ana_2.0.0_%s_amd64.%s", goos, tc.ext), + archiveBody: body, + } + if tc.overrides != nil { + tc.overrides(r) + } + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), tc.wantErr) + }) + } +} + +func TestSelfUpdate_WindowsAtomicReplace(t *testing.T) { + setup := func(t *testing.T) (string, Deps, func()) { + exePath, deps := stageUpdate(t, "windows", "amd64", "ana.exe") + r := &releaseServer{tag: "v2.0.0", archiveName: "ana_2.0.0_windows_amd64.zip"} + r.archiveBody = fakeArchive(t, "zip", "ana.exe", []byte("NEW_BINARY_BYTES")) + srv := r.serve(t) + withURLs(t, srv) + return exePath, deps, srv.Close + } + + t.Run("happy path: rename aside + in, .old cleanup", func(t *testing.T) { + exePath, deps, done := setup(t) + defer done() + if err := os.WriteFile(exePath+".old", []byte("stale"), 0o600); err != nil { + t.Fatalf("seed .old: %v", err) + } + var renames [][2]string + deps.Rename = func(o, n string) error { renames = append(renames, [2]string{o, n}); return os.Rename(o, n) } + + if err := SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false); err != nil { + t.Fatalf("SelfUpdate: %v", err) + } + if len(renames) != 2 || renames[0][1] != exePath+".old" || renames[1][1] != exePath { + t.Fatalf("rename sequence wrong: %v", renames) + } + if got, _ := os.ReadFile(exePath); string(got) != "NEW_BINARY_BYTES" { + t.Errorf("exe not replaced: %q", got) + } + // Pre-seeded .old ("stale") was removed; aside-rename then wrote the + // prior exe ("old") into .old for rollback. + if got, _ := os.ReadFile(exePath + ".old"); string(got) != "old" { + t.Errorf(".old content = %q; want \"old\"", got) + } + }) + + t.Run("rename-in fails: rollback succeeds", func(t *testing.T) { + exePath, deps, done := setup(t) + defer done() + var n int + deps.Rename = func(o, np string) error { + n++ + if n == 2 { + return errors.New("simulated replace failure") + } + return os.Rename(o, np) + } + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "simulated replace failure") + if got, _ := os.ReadFile(exePath); string(got) != "old" { + t.Errorf("rollback did not restore: %q", got) + } + }) + + t.Run("rename-in fails: rollback also fails", func(t *testing.T) { + _, deps, done := setup(t) + defer done() + var n int + deps.Rename = func(o, np string) error { + n++ + if n >= 2 { + return errors.New("all renames fail") + } + return os.Rename(o, np) + } + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "recover from") + }) + + t.Run("rename aside fails", func(t *testing.T) { + _, deps, done := setup(t) + defer done() + deps.Rename = func(string, string) error { return errors.New("locked") } + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "locked") + }) +} + +// TestSelfUpdate_CtxCancel asserts context cancellation mid-flow surfaces +// as a ctx error and doesn't replace the executable. Guards against a future +// regression that swaps ctx-aware helpers for context.Background(). +func TestSelfUpdate_CtxCancel(t *testing.T) { + exePath, deps := stageUpdate(t, "linux", "amd64", "ana") + r := &releaseServer{tag: "v1.2.3", archiveName: "ana_1.2.3_linux_amd64.tar.gz"} + r.archiveBody = fakeArchive(t, "tar.gz", "ana", []byte("NEW")) + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := SelfUpdate(ctx, deps, "1.0.0", io.Discard, false) + if err == nil { + t.Fatal("expected cancellation error") + } + if got, _ := os.ReadFile(exePath); string(got) != "old" { + t.Errorf("exe changed despite cancel: %q", got) + } +} + +func TestSelfUpdate_EarlyErrors(t *testing.T) { + t.Run("exe path error", func(t *testing.T) { + deps := Deps{ExePath: func() (string, error) { return "", errors.New("no exe") }} + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "no exe") + }) + t.Run("latest release error", func(t *testing.T) { + _, deps := stageUpdate(t, "linux", "amd64", "ana") + deps.HTTP = &fakeDoer{handler: func(*http.Request) (*http.Response, error) { return nil, errors.New("net down") }} + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "net down") + }) + t.Run("tempdir error", func(t *testing.T) { + _, deps := stageUpdate(t, "linux", "amd64", "ana") + deps.TempDir = func() (string, error) { return "", errors.New("tmp down") } + r := &releaseServer{tag: "v2.0.0", archiveName: "ana_2.0.0_linux_amd64.tar.gz"} + r.archiveBody = fakeArchive(t, "tar.gz", "ana", []byte("NEW")) + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "tmp down") + }) + t.Run("unix rename fails", func(t *testing.T) { + _, deps := stageUpdate(t, "linux", "amd64", "ana") + deps.Rename = func(string, string) error { return errors.New("denied") } + r := &releaseServer{tag: "v2.0.0", archiveName: "ana_2.0.0_linux_amd64.tar.gz"} + r.archiveBody = fakeArchive(t, "tar.gz", "ana", []byte("NEW")) + srv := r.serve(t) + defer srv.Close() + withURLs(t, srv) + wantErr(t, SelfUpdate(context.Background(), deps, "1.0.0", io.Discard, false), "denied") + }) +} + +// failingWriter errors on any Write call. Used to exercise emitStatus's +// WriteString-error branch without wiring a whole SelfUpdate invocation. +type failingWriter struct{} + +func (failingWriter) Write([]byte) (int, error) { return 0, errors.New("write denied") } + +func TestEmitStatus_WriteError(t *testing.T) { + t.Parallel() + err := emitStatus(failingWriter{}, false, updateStatus{Status: "x"}, "line\n") + wantErr(t, err, "emit status") +} + +func TestDeps(t *testing.T) { + t.Parallel() + t.Run("DefaultDeps populates everything", func(t *testing.T) { + d := DefaultDeps() + if d.HTTP == nil || d.ExePath == nil || d.Rename == nil || d.TempDir == nil || d.GOOS == "" || d.GOARCH == "" { + t.Fatalf("missing field: %+v", d) + } + p, err := d.TempDir() + if err != nil { + t.Fatalf("tempdir: %v", err) + } + defer os.RemoveAll(p) + }) + t.Run("resolveDeps zero fills and runs", func(t *testing.T) { + d := resolveDeps(Deps{}) + p, err := d.TempDir() + if err != nil { + t.Fatalf("tempdir: %v", err) + } + defer os.RemoveAll(p) + }) + t.Run("resolveDeps preserves explicit", func(t *testing.T) { + d := resolveDeps(Deps{ + HTTP: http.DefaultClient, GOOS: "plan9", GOARCH: "arm", + ExePath: func() (string, error) { return "/x", nil }, + Rename: func(string, string) error { return nil }, + TempDir: func() (string, error) { return "/t", nil }, + }) + if d.GOOS != "plan9" || d.GOARCH != "arm" { + t.Fatalf("explicit overwritten: %+v", d) + } + }) +} diff --git a/internal/update/semver.go b/internal/update/semver.go new file mode 100644 index 0000000..948a834 --- /dev/null +++ b/internal/update/semver.go @@ -0,0 +1,70 @@ +package update + +import ( + "strconv" + "strings" +) + +// CmpSemver returns -1/0/+1 comparing two three-int semver strings. Accepts +// optional leading `v` and an optional `-prerelease` suffix. Prerelease +// versions sort below their non-prerelease counterparts at the same X.Y.Z +// (so goreleaser's `prerelease: auto` beta→stable path notifies correctly). +// Malformed input on either side returns 0 — a bad tag must never trigger a +// nudge. +func CmpSemver(a, b string) int { + av, aPre, aOK := parseSemver(a) + bv, bPre, bOK := parseSemver(b) + if !aOK || !bOK { + return 0 + } + for i := range 3 { + if av[i] < bv[i] { + return -1 + } + if av[i] > bv[i] { + return 1 + } + } + // X.Y.Z match: prerelease < release at the same core version. + if aPre == "" && bPre != "" { + return 1 + } + if aPre != "" && bPre == "" { + return -1 + } + if aPre < bPre { + return -1 + } + if aPre > bPre { + return 1 + } + return 0 +} + +// parseSemver splits s into (major, minor, patch), a prerelease string (may +// be ""), and an ok flag. Rejects any component that isn't a non-negative +// int. Build metadata after `+` is stripped per semver. +func parseSemver(s string) ([3]int, string, bool) { + s = strings.TrimPrefix(s, "v") + if i := strings.IndexByte(s, '+'); i >= 0 { + s = s[:i] + } + var pre string + if i := strings.IndexByte(s, '-'); i >= 0 { + pre = s[i+1:] + s = s[:i] + } + parts := strings.SplitN(s, ".", 3) + if len(parts) != 3 { + return [3]int{}, "", false + } + var out [3]int + for i, p := range parts { + n, err := strconv.Atoi(p) + if err != nil || n < 0 { + return [3]int{}, "", false + } + out[i] = n + } + return out, pre, true +} diff --git a/internal/update/semver_test.go b/internal/update/semver_test.go new file mode 100644 index 0000000..9ed6903 --- /dev/null +++ b/internal/update/semver_test.go @@ -0,0 +1,29 @@ +package update + +import "testing" + +func TestCmpSemver(t *testing.T) { + t.Parallel() + cases := []struct { + a, b string + want int + }{ + {"1.2.3", "1.2.3", 0}, {"v1.2.3", "1.2.3", 0}, + {"1.2.3", "1.2.4", -1}, {"1.2.4", "1.2.3", 1}, + {"1.2.3", "1.3.0", -1}, {"2.0.0", "1.9.9", 1}, {"0.0.1", "0.0.2", -1}, + {"1.2.3", "1.2.3+build.5", 0}, + // prerelease < release at same core + {"1.2.3-beta", "1.2.3", -1}, {"1.2.3", "1.2.3-beta", 1}, + {"1.2.3-alpha", "1.2.3-beta", -1}, {"1.2.3-beta", "1.2.3-alpha", 1}, + {"1.2.3-beta", "1.2.3-beta", 0}, + // malformed → 0 + {"dev", "1.2.3", 0}, {"1.2.3", "dev", 0}, + {"1.2", "1.2.0", 0}, {"-1.0.0", "0.0.1", 0}, + {"1.2.x", "1.2.3", 0}, {"1.-2.3", "1.2.3", 0}, + } + for _, tc := range cases { + if got := CmpSemver(tc.a, tc.b); got != tc.want { + t.Errorf("CmpSemver(%q,%q) = %d, want %d", tc.a, tc.b, got, tc.want) + } + } +} diff --git a/internal/update/update.go b/internal/update/update.go new file mode 100644 index 0000000..73979be --- /dev/null +++ b/internal/update/update.go @@ -0,0 +1,22 @@ +// Package update powers the passive update nudge (CachedCheck) and the +// `ana update` self-update verb (SelfUpdate). All logic sits behind +// injected deps so cmd/ana stays thin wiring. +// +// Stdlib-only; mirrors goreleaser's archive layout and install.sh's URL +// template. +package update + +import "net/http" + +// HTTPDoer is the narrow HTTP interface this package consumes. http.Client +// satisfies it; tests pass a fake that returns canned responses. +type HTTPDoer interface { + Do(*http.Request) (*http.Response, error) +} + +// Repository coordinates. Package vars (not constants) so tests can repoint +// them at a httptest.Server. Callers must not mutate these at runtime. +var ( + latestReleaseURL = "https://api.github.com/repos/highperformance-tech/ana-cli/releases/latest" + releasesBaseURL = "https://github.com/highperformance-tech/ana-cli/releases/download" +) diff --git a/internal/update/update_test.go b/internal/update/update_test.go new file mode 100644 index 0000000..7c475a8 --- /dev/null +++ b/internal/update/update_test.go @@ -0,0 +1,183 @@ +package update + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +// Shared test helpers used across semver_test / check_test / selfupdate_test +// / download_test / extract_test. Per the project convention, helpers used by +// more than one test file live in _test.go. + +type fakeDoer struct { + handler func(*http.Request) (*http.Response, error) + calls int +} + +func (f *fakeDoer) Do(r *http.Request) (*http.Response, error) { + f.calls++ + return f.handler(r) +} + +type errReader struct{ err error } + +func (r errReader) Read([]byte) (int, error) { return 0, r.err } + +type errBody struct{ err error } + +func (b errBody) Read([]byte) (int, error) { return 0, b.err } +func (b errBody) Close() error { return nil } + +type nopWriteCloser struct{ io.Writer } + +func (nopWriteCloser) Close() error { return nil } + +func ptr(s string) *string { return &s } + +func mapEnv(m map[string]string) func(string) string { + return func(k string) string { return m[k] } +} + +func sha256Hex(b []byte) string { + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) +} + +// withURLs points both package-level URL vars at a test server. Tests that +// use it MUST NOT call t.Parallel() — they share global state. +func withURLs(t *testing.T, srv *httptest.Server) { + t.Helper() + prevLatest, prevBase := latestReleaseURL, releasesBaseURL + latestReleaseURL = srv.URL + "/api/releases/latest" + releasesBaseURL = srv.URL + "/releases" + t.Cleanup(func() { + latestReleaseURL = prevLatest + releasesBaseURL = prevBase + }) +} + +// fakeArchive builds a single-member tar.gz or zip in memory. +func fakeArchive(t *testing.T, ext, name string, payload []byte) []byte { + t.Helper() + var buf bytes.Buffer + switch ext { + case "tar.gz": + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + if err := tw.WriteHeader(&tar.Header{Name: name, Mode: 0o755, Size: int64(len(payload))}); err != nil { + t.Fatalf("tar header: %v", err) + } + if _, err := tw.Write(payload); err != nil { + t.Fatalf("tar write: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("gz close: %v", err) + } + case "zip": + zw := zip.NewWriter(&buf) + f, err := zw.Create(name) + if err != nil { + t.Fatalf("zip create: %v", err) + } + if _, err := f.Write(payload); err != nil { + t.Fatalf("zip write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("zip close: %v", err) + } + } + return buf.Bytes() +} + +// releaseServer holds the three routes the update flow expects. Zero-value +// fields use sensible defaults; tests override individual fields to simulate +// specific failure modes. +type releaseServer struct { + tag string + archiveName string + latestStatus int + latestBody string + archiveBody []byte + archive404 bool + checksums string + checksums404 bool +} + +func (r *releaseServer) serve(t *testing.T) *httptest.Server { + t.Helper() + sum := sha256Hex(r.archiveBody) + if r.checksums == "" { + r.checksums = fmt.Sprintf("%s %s\n", sum, r.archiveName) + } + mux := http.NewServeMux() + mux.HandleFunc("/api/releases/latest", func(w http.ResponseWriter, _ *http.Request) { + if r.latestStatus != 0 { + w.WriteHeader(r.latestStatus) + } + if r.latestBody != "" { + fmt.Fprint(w, r.latestBody) + return + } + _ = json.NewEncoder(w).Encode(map[string]string{"tag_name": r.tag}) + }) + mux.HandleFunc("/releases/"+r.tag+"/"+r.archiveName, func(w http.ResponseWriter, _ *http.Request) { + if r.archive404 { + w.WriteHeader(404) + return + } + _, _ = w.Write(r.archiveBody) + }) + mux.HandleFunc("/releases/"+r.tag+"/checksums.txt", func(w http.ResponseWriter, _ *http.Request) { + if r.checksums404 { + w.WriteHeader(404) + return + } + fmt.Fprint(w, r.checksums) + }) + return httptest.NewServer(mux) +} + +// stageUpdate seeds an exe + staging root and returns Deps wired at +// those paths. Callers mutate the returned struct for per-test tweaks. +func stageUpdate(t *testing.T, goos, goarch, exeName string) (string, Deps) { + t.Helper() + tmp := t.TempDir() + exePath := filepath.Join(tmp, exeName) + if err := os.WriteFile(exePath, []byte("old"), 0o755); err != nil { + t.Fatalf("seed exe: %v", err) + } + return exePath, Deps{ + GOOS: goos, + GOARCH: goarch, + ExePath: func() (string, error) { return exePath, nil }, + TempDir: func() (string, error) { return os.MkdirTemp(tmp, "stage-*") }, + } +} + +// wantErr fatally fails the test when err is nil or (when substr is +// non-empty) when it doesn't contain substr. +func wantErr(t *testing.T, err error, substr string) { + t.Helper() + if err == nil { + t.Fatalf("expected error containing %q, got nil", substr) + } + if substr != "" && !strings.Contains(err.Error(), substr) { + t.Fatalf("err = %v; want substring %q", err, substr) + } +}