diff --git a/cmd/ans-verify/main.go b/cmd/ans-verify/main.go index 6169487..a43cd17 100644 --- a/cmd/ans-verify/main.go +++ b/cmd/ans-verify/main.go @@ -50,6 +50,16 @@ import ( ) func main() { + // Subcommand dispatch: `ans-verify list ...` enumerates agents + // under a provider via tile-walk. Any other first-arg form falls + // through to the original single-agent verify path so existing + // invocations (`ans-verify ` or `ans-verify -agent `) + // keep working unchanged. + if len(os.Args) > 1 && os.Args[1] == "list" { + listMain(os.Args[2:]) + return + } + var ( baseURL string agentID string diff --git a/cmd/ans-verify/walk.go b/cmd/ans-verify/walk.go new file mode 100644 index 0000000..3a42da8 --- /dev/null +++ b/cmd/ans-verify/walk.go @@ -0,0 +1,736 @@ +package main + +import ( + "bytes" + "context" + "crypto/ecdsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net/http" + "os" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/transparency-dev/tessera/api" + "github.com/transparency-dev/tessera/api/layout" + + anscrypto "github.com/godaddy/ans/internal/crypto" + "github.com/godaddy/ans/internal/tl/logstore" + "github.com/godaddy/ans/internal/tl/receipt" +) + +// maxResponseBytes caps any single HTTP body the walker is willing +// to read. Sized for a full 256-leaf tile of ~64 KiB envelopes +// (16 MiB) with headroom; protects against a hostile or buggy TL +// streaming an unbounded body and OOMing the verifier. +const maxResponseBytes = 32 * 1024 * 1024 + +// agentIDPattern is the allowed shape for an agentId interpolated +// into a /v1/agents/{id}/... URL. Restricting to UUID syntax means a +// malicious TL leaf can't smuggle path traversal or query-string +// fragments through verifyMatches. The RA only ever issues UUIDs, so +// rejecting anything else is a no-cost defense. +var agentIDPattern = regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + +// AgentMatch is one agent leaf that matched the provider filter +// during a tile-walk enumeration. LeafBytes is the JCS-canonical +// envelope bytes that were actually appended to the log — kept so +// the verify step can cross-check that the receipt's payload equals +// what we walked off the tile (defense against leaf substitution). +type AgentMatch struct { + LeafIndex uint64 + AgentID string + AnsName string + Host string + EventType string + LeafBytes []byte +} + +// Terminal lifecycle states — events with these eventTypes are not +// considered "live" by reduceToLive. Kept as a small set rather than +// a "live whitelist" so that any new active-state event type added +// to the schema later (e.g. AGENT_SUSPENDED → reactivated) isn't +// silently filtered out by this client. +var terminalEventTypes = map[string]bool{ + "AGENT_REVOKED": true, + "AGENT_DEPRECATED": true, +} + +// providerMatches reports whether host belongs to providerSuffix — +// either an exact match (case-insensitive, trailing dot tolerated) +// or a strict subdomain (`host = "x.suffix"`). +// +// Empty inputs never match. The "x.suffix" rule rejects accidental +// substring matches like "evilsuffix.com" vs "suffix.com". +func providerMatches(host, providerSuffix string) bool { + if host == "" || providerSuffix == "" { + return false + } + h := strings.ToLower(strings.TrimSuffix(host, ".")) + s := strings.ToLower(strings.TrimSuffix(providerSuffix, ".")) + if h == s { + return true + } + return strings.HasSuffix(h, "."+s) +} + +// agentIdentity is the subset of an envelope the walker needs to +// decide whether a leaf matches the provider filter. +type agentIdentity struct { + AgentID string + AnsName string + Host string + EventType string +} + +// extractAgentIdentity pulls the identity fields needed for provider +// filtering out of a leaf's JCS-canonical envelope JSON. V1 and V2 +// envelopes share the `payload.producer.event.{ansName, eventType, +// agent.host}` path, so a single decoder serves both lanes. +// +// Returns ok=false when the bytes don't parse as JSON or neither +// ansName nor agent.host is populated (which lets the caller skip +// non-event leaves without surfacing an error). +func extractAgentIdentity(envelope []byte) (agentIdentity, bool) { + var env struct { + Payload struct { + Producer struct { + Event struct { + AnsID string `json:"ansId"` + AnsName string `json:"ansName"` + EventType string `json:"eventType"` + Agent struct { + Host string `json:"host"` + } `json:"agent"` + } `json:"event"` + } `json:"producer"` + } `json:"payload"` + } + if err := json.Unmarshal(envelope, &env); err != nil { + return agentIdentity{}, false + } + e := env.Payload.Producer.Event + if e.AnsName == "" && e.Agent.Host == "" { + return agentIdentity{}, false + } + return agentIdentity{ + AgentID: e.AnsID, + AnsName: e.AnsName, + Host: e.Agent.Host, + EventType: e.EventType, + }, true +} + +// decodeEntryBundle parses a c2sp tlog-tiles entry bundle into its +// leaf byte-slices, hiding the tessera api dependency behind a thin +// helper so the walker logic is easy to unit-test against synthetic +// bundles. +func decodeEntryBundle(raw []byte) ([][]byte, error) { + var b api.EntryBundle + if err := b.UnmarshalText(raw); err != nil { + return nil, err + } + return b.Entries, nil +} + +// walkProviderAgents enumerates every leaf in [0, treeSize) by +// fetching entry tiles from baseURL, decoding each leaf envelope, +// and returning the subset whose agent.host falls under +// providerSuffix. +// +// Tile fetches run concurrently bounded by concurrency (clamped to +// [1, 64]). Matches are emitted in log (leaf-index) order regardless +// of fetch completion order so downstream consumers can treat the +// result as a stable timeline. concurrency=0 selects a sensible +// default for the caller. +// +// Note: this is a per-leaf scan, not per-agent deduplication. An +// agent that has multiple events in the log will appear multiple +// times. Lifecycle reduction is reduceToLive's job — keeps the +// walker's contract narrow ("every matching leaf, in order"). +func walkProviderAgents( + ctx context.Context, + client *http.Client, + baseURL, providerSuffix string, + treeSize uint64, + concurrency int, +) ([]AgentMatch, error) { + if treeSize == 0 { + return nil, nil + } + concurrency = clampConcurrency(concurrency) + + const w = uint64(layout.EntryBundleWidth) // 256 + nTiles := (treeSize + w - 1) / w + + // Per-tile result slot: either parsed entries or the first error. + // Indexed by tile position so we can emit matches in log order + // without sorting later. Slots a worker never reaches stay + // zero-value; the post-loop scan treats that as "fetch was + // cancelled" and prefers the captured firstErr. + type tileResult struct { + entries [][]byte + err error + } + results := make([]tileResult, nTiles) + + jobs := make(chan uint64, concurrency) + var wg sync.WaitGroup + wctx, cancel := context.WithCancel(ctx) + defer cancel() + + // firstErr captures the actual triggering error so the caller + // sees the root cause, not whichever tile happens to be lowest- + // indexed when other slots are still zero from an early cancel. + var firstErr atomic.Pointer[error] + recordErr := func(err error) { + // CompareAndSwap-style: only the first failing worker wins. + firstErr.CompareAndSwap(nil, &err) + cancel() + } + + for range concurrency { + wg.Add(1) + go func() { + defer wg.Done() + for tileIdx := range jobs { + if wctx.Err() != nil { + // Drain remaining jobs without doing work so the + // producer's `jobs <- tileIdx` send completes for + // every tile, otherwise close(jobs) deadlocks. + continue + } + partial := layout.PartialTileSize(0, tileIdx, treeSize) + path := layout.EntriesPath(tileIdx, partial) + raw, err := httpGetBytes(wctx, client, baseURL+"/"+path) + if err != nil { + wrapped := fmt.Errorf("fetch %s: %w", path, err) + results[tileIdx] = tileResult{err: wrapped} + recordErr(wrapped) + continue + } + entries, err := decodeEntryBundle(raw) + if err != nil { + wrapped := fmt.Errorf("decode %s: %w", path, err) + results[tileIdx] = tileResult{err: wrapped} + recordErr(wrapped) + continue + } + // Tile-size guard: a full tile MUST be EntryBundleWidth + // entries; a partial tile MUST be exactly `partial`. A + // hostile or buggy TL serving a truncated or oversized + // bundle would otherwise slip through silently — the + // checkpoint signature only binds the tree shape, not + // the bytes of any individual tile. Fail closed. + wantLen := uint64(layout.EntryBundleWidth) + if partial != 0 { + wantLen = uint64(partial) + } + if uint64(len(entries)) != wantLen { + wrapped := fmt.Errorf("%s: bundle has %d entries, want %d", path, len(entries), wantLen) + results[tileIdx] = tileResult{err: wrapped} + recordErr(wrapped) + continue + } + results[tileIdx] = tileResult{entries: entries} + } + }() + } + // Producer: stop enqueueing on cancel so a failure in one worker + // doesn't force the producer to push N more indices into the + // channel before close(jobs) unblocks. Workers still drain any + // already-queued indices via their wctx.Err() check. +producer: + for tileIdx := range nTiles { + select { + case jobs <- tileIdx: + case <-wctx.Done(): + break producer + } + } + close(jobs) + wg.Wait() + + // Prefer the captured first error (most accurate root cause). + // Fall back to the caller's ctx error so an external timeout or + // cancel returns a non-nil error even when no fetch reported one. + if perr := firstErr.Load(); perr != nil { + return nil, *perr + } + if err := ctx.Err(); err != nil { + return nil, err + } + + matches := make([]AgentMatch, 0) + for tileIdx, r := range results { + if r.err != nil { + return nil, r.err + } + base := uint64(tileIdx) * w + for i, leaf := range r.entries { + id, ok := extractAgentIdentity(leaf) + if !ok { + continue + } + if !providerMatches(id.Host, providerSuffix) { + continue + } + matches = append(matches, AgentMatch{ + LeafIndex: base + uint64(i), + AgentID: id.AgentID, + AnsName: id.AnsName, + Host: id.Host, + EventType: id.EventType, + LeafBytes: leaf, + }) + } + } + return matches, nil +} + +// clampConcurrency normalizes user input. 0 → 8 (the default), then +// floor at 1 and ceiling at 64 so a misconfigured CLI flag can't +// either deadlock the walker or DOS the TL. +func clampConcurrency(c int) int { + const def, lo, hi = 8, 1, 64 + if c == 0 { + return def + } + if c < lo { + return lo + } + if c > hi { + return hi + } + return c +} + +// reduceToLive collapses a per-leaf match list into one row per +// AnsName, keeping the most recent leaf, then drops agents whose +// latest event puts them in a terminal lifecycle state (revoked / +// deprecated). This is the answer to "what agents currently live +// under provider X" — distinct from the raw walker output which is +// "every event ever logged under provider X". +// +// Matches with an empty AnsName are passed through individually so +// the caller doesn't lose data from a malformed legacy leaf. +func reduceToLive(matches []AgentMatch) []AgentMatch { + latest := make(map[string]AgentMatch, len(matches)) + out := make([]AgentMatch, 0, len(matches)) + for _, m := range matches { + if m.AnsName == "" { + out = append(out, m) + continue + } + if prev, ok := latest[m.AnsName]; !ok || m.LeafIndex > prev.LeafIndex { + latest[m.AnsName] = m + } + } + for _, m := range latest { + if terminalEventTypes[m.EventType] { + continue + } + out = append(out, m) + } + sort.Slice(out, func(i, j int) bool { return out[i].LeafIndex < out[j].LeafIndex }) + return out +} + +// VerifyResult is the outcome of one per-match receipt verification. +type VerifyResult struct { + Match AgentMatch + OK bool + Err error +} + +// verifyMatches fetches and verifies the SCITT COSE receipt for each +// match, running concurrency workers in parallel. Results are +// returned in match-input order. A missing agentId is a hard error +// for that match — the receipt URL is keyed by agentId, so without +// one we can't even attempt a fetch. +// +// verifyOne is injected so callers can swap the verification logic +// in tests; production callers pass makeReceiptVerifier(keys). +// +// The function returns nil error overall — per-match failures are +// surfaced via VerifyResult.Err so the caller can decide whether a +// partial failure is fatal or informational. +func verifyMatches( + ctx context.Context, + client *http.Client, + baseURL string, + matches []AgentMatch, + verifyOne func(receiptBytes []byte) error, + concurrency int, +) []VerifyResult { + if len(matches) == 0 { + return nil + } + concurrency = clampConcurrency(concurrency) + + out := make([]VerifyResult, len(matches)) + jobs := make(chan int, concurrency) + var wg sync.WaitGroup + for range concurrency { + wg.Add(1) + go func() { + defer wg.Done() + for idx := range jobs { + m := matches[idx] + if m.AgentID == "" { + out[idx] = VerifyResult{Match: m, Err: errors.New("match has no agentId")} + continue + } + if !agentIDPattern.MatchString(m.AgentID) { + // Defense in depth: the agentId came from a TL + // leaf we don't fully trust at this point in the + // pipeline. Anything that isn't a UUID can't be a + // real RA-issued id, so refuse to interpolate it + // into a URL. + out[idx] = VerifyResult{Match: m, Err: fmt.Errorf("agentId %q is not a valid UUID", m.AgentID)} + continue + } + rec, err := httpGetBytes(ctx, client, baseURL+"/v1/agents/"+m.AgentID+"/receipt") + if err != nil { + out[idx] = VerifyResult{Match: m, Err: fmt.Errorf("fetch receipt: %w", err)} + continue + } + if err := verifyOne(rec); err != nil { + out[idx] = VerifyResult{Match: m, Err: fmt.Errorf("verify: %w", err)} + continue + } + // Leaf-substitution guard: the receipt's payload IS + // the canonical envelope bytes that were appended to + // the log. If they don't match the bytes we walked + // off the tile, the TL served a forged tile for a + // real agentId — receipt-only verification would + // silently pass. Skip the check when the walker + // didn't capture LeafBytes (legacy callers). + if m.LeafBytes != nil { + payload, perr := receipt.ExtractPayload(rec) + if perr != nil { + out[idx] = VerifyResult{Match: m, Err: fmt.Errorf("extract receipt payload: %w", perr)} + continue + } + if !bytes.Equal(payload, m.LeafBytes) { + out[idx] = VerifyResult{Match: m, Err: errors.New("receipt payload does not match tile leaf (possible leaf substitution)")} + continue + } + } + out[idx] = VerifyResult{Match: m, OK: true} + } + }() + } + for i := range matches { + jobs <- i + } + close(jobs) + wg.Wait() + return out +} + +// makeReceiptVerifier returns a closure that tries each key against +// a receipt and returns nil on the first success. Used by callers +// that have already loaded /root-keys; tests substitute their own +// stub directly into verifyMatches. +func makeReceiptVerifier(keys []*ecdsa.PublicKey) func([]byte) error { + return func(b []byte) error { + if len(keys) == 0 { + return errors.New("no verification keys available") + } + var lastErr error + for _, k := range keys { + err := receipt.Verify(b, k) + if err == nil { + return nil + } + lastErr = err + } + return lastErr + } +} + +// checkpointTreeSize fetches /v1/log/checkpoint and returns the +// declared logSize WITHOUT verifying the checkpoint signature. +// Retained for tests and for callers that don't have the verifier +// keys handy; production list-mode uses verifiedCheckpoint instead. +func checkpointTreeSize(ctx context.Context, client *http.Client, baseURL string) (uint64, error) { + body, err := httpGetBytes(ctx, client, baseURL+"/v1/log/checkpoint") + if err != nil { + return 0, err + } + var cp struct { + LogSize uint64 `json:"logSize"` + } + if err := json.Unmarshal(body, &cp); err != nil { + return 0, fmt.Errorf("decode checkpoint json: %w", err) + } + return cp.LogSize, nil +} + +// VerifiedCheckpoint is a parsed + signature-verified checkpoint. +type VerifiedCheckpoint struct { + Origin string + Size uint64 + RootHash []byte +} + +// verifiedCheckpoint fetches /checkpoint (raw C2SP signed note), +// verifies the signature against one of keysByHash, and returns the +// parsed origin/size/rootHash. +// +// Without this step, a hostile TL could return a smaller logSize on +// /v1/log/checkpoint than the real tree contains and the walker +// would never fetch the tiles holding agents the attacker wants +// hidden — a textbook omission attack against a transparency log. +func verifiedCheckpoint( + ctx context.Context, + client *http.Client, + baseURL string, + keysByHash map[string]*ecdsa.PublicKey, +) (*VerifiedCheckpoint, error) { + if len(keysByHash) == 0 { + return nil, errors.New("no verification keys available") + } + body, err := httpGetBytes(ctx, client, baseURL+"/checkpoint") + if err != nil { + return nil, fmt.Errorf("fetch /checkpoint: %w", err) + } + return verifyCheckpointNote(body, keysByHash) +} + +// verifyCheckpointNote is the pure (network-free) half of +// verifiedCheckpoint, split out so the parsing + signature +// verification can be unit-tested against synthetic notes. +// +// A C2SP-shaped signed note is: +// +// \n +// \n +// \n +// \n +// — \n +// [— ...] (optional additional signature lines) +// +// Verification succeeds when at least one signature line's keyhash +// matches a known verifier key and the ECDSA P-256 signature +// validates against the body bytes (everything up to and including +// the blank separator line). +func verifyCheckpointNote(raw []byte, keysByHash map[string]*ecdsa.PublicKey) (*VerifiedCheckpoint, error) { + // Body / signature split is the first "\n\n" separator. + sep := bytes.Index(raw, []byte("\n\n")) + if sep < 0 { + return nil, errors.New("checkpoint note: missing body/signature separator") + } + body := raw[:sep+1] // body INCLUDES the trailing newline per signed-note spec + sigLines := bytes.Split(bytes.TrimRight(raw[sep+2:], "\n"), []byte("\n")) + + bodyLines := bytes.Split(bytes.TrimRight(body, "\n"), []byte("\n")) + if len(bodyLines) < 3 { + return nil, fmt.Errorf("checkpoint note: body must have ≥3 lines, got %d", len(bodyLines)) + } + origin := string(bodyLines[0]) + size, err := strconv.ParseUint(string(bodyLines[1]), 10, 64) + if err != nil { + return nil, fmt.Errorf("checkpoint note: parse size %q: %w", bodyLines[1], err) + } + rootHash, err := base64.StdEncoding.DecodeString(string(bodyLines[2])) + if err != nil { + return nil, fmt.Errorf("checkpoint note: decode rootHash: %w", err) + } + + var lastSigErr error + for _, line := range sigLines { + if !bytes.HasPrefix(line, []byte("— ")) { + // Not a signature line (or non-UTF8 prefix on Windows + // CRLF input); skip. + continue + } + // Format: "— ". Last space-separated token is + // the base64-encoded keyhash+signature. + fields := bytes.Fields(line) + if len(fields) < 3 { + continue + } + blob, err := base64.StdEncoding.DecodeString(string(fields[len(fields)-1])) + if err != nil { + lastSigErr = fmt.Errorf("decode sig line: %w", err) + continue + } + if len(blob) < 4 { + lastSigErr = errors.New("sig line: blob shorter than keyhash") + continue + } + keyhashHex := fmt.Sprintf("%08x", binary.BigEndian.Uint32(blob[:4])) + pub, ok := keysByHash[keyhashHex] + if !ok { + continue // signature is for an unknown key — try the next line + } + if !logstore.VerifyC2SPECDSA(pub, body, blob[4:]) { + lastSigErr = fmt.Errorf("sig for kid %s did not verify", keyhashHex) + continue + } + return &VerifiedCheckpoint{Origin: origin, Size: size, RootHash: rootHash}, nil + } + if lastSigErr == nil { + lastSigErr = errors.New("no signature line matched a known verifier key") + } + return nil, fmt.Errorf("checkpoint note: %w", lastSigErr) +} + +// keyHashHex returns the 8-char hex string the /root-keys line +// publishes for pub. Used by tests to wire keysByHash without +// reaching into anscrypto. +func keyHashHex(pub *ecdsa.PublicKey) (string, error) { + h, err := anscrypto.SPKIKeyHash4(pub) + if err != nil { + return "", err + } + return fmt.Sprintf("%08x", binary.BigEndian.Uint32(h)), nil +} + +// httpGetBytes is a minimal GET helper for the walker. Distinct from +// main.go's fetchBinary because that one returns the content-type the +// status-token path needs; the walker only ever wants the body. +// +// Bodies are capped at maxResponseBytes — a hostile or buggy TL +// streaming an unbounded response cannot OOM the verifier. Hitting +// the cap is surfaced as an explicit error rather than a silent +// truncation so callers don't decode partial JSON / partial tiles. +func httpGetBytes(ctx context.Context, client *http.Client, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d", resp.StatusCode) + } + // LimitReader+1 trick: read one byte past the cap so we can tell + // "exactly at the cap" from "tried to overflow it". + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) + if err != nil { + return nil, err + } + if int64(len(body)) > maxResponseBytes { + return nil, fmt.Errorf("response body exceeds %d-byte cap", maxResponseBytes) + } + return body, nil +} + +// listMain implements the `ans-verify list` subcommand: walk the log +// from index 0..treeSize, decoding each leaf envelope and printing +// the ones whose agent.host falls under -provider. +func listMain(args []string) { + // ContinueOnError lets us print a consistent custom usage and + // exit with code 1 on parse failure. flag.ExitOnError calls + // os.Exit(2) internally before Parse returns, so the err-check + // below would never fire. + fs := flag.NewFlagSet("list", flag.ContinueOnError) + fs.Usage = func() { + fmt.Fprintln(os.Stderr, + "usage: ans-verify list -provider [-url ] [-live=false] [-verify] [-concurrency N]") + fs.PrintDefaults() + } + var ( + baseURL string + provider string + live bool + doVerify bool + concurrency int + ) + fs.StringVar(&baseURL, "url", "http://localhost:18081", + "Base URL of the transparency log") + fs.StringVar(&provider, "provider", "", + "Provider host suffix to filter on (e.g. darknetian.com)") + fs.BoolVar(&live, "live", true, + "Collapse to one row per agent and drop revoked/deprecated agents") + fs.BoolVar(&doVerify, "verify", false, + "After listing, fetch and verify each matched agent's SCITT receipt") + fs.IntVar(&concurrency, "concurrency", 8, + "Number of parallel HTTP workers (1-64)") + if err := fs.Parse(args); err != nil { + // fs.Usage already printed (Parse calls it on error). + os.Exit(1) + } + if provider == "" { + fs.Usage() + os.Exit(1) + } + baseURL = strings.TrimRight(baseURL, "/") + client := &http.Client{Timeout: 30 * time.Second} + ctx := context.Background() + + // Always fetch /root-keys first — both the checkpoint signature + // verification AND the per-match receipt verification depend on + // them, so a missing-keys failure should fail fast before we do + // any walking. + keys, keysByHash, err := fetchRootKeys(baseURL) + if err != nil { + fatalf("fetch root-keys: %v", err) + } + cp, err := verifiedCheckpoint(ctx, client, baseURL, keysByHash) + if err != nil { + fatalf("verified checkpoint: %v", err) + } + matches, err := walkProviderAgents(ctx, client, baseURL, provider, cp.Size, concurrency) + if err != nil { + fatalf("walk: %v", err) + } + rawCount := len(matches) + if live { + matches = reduceToLive(matches) + } + + fmt.Printf("=== ANS provider walk ===\n") + fmt.Printf("TL Base URL: %s\n", baseURL) + fmt.Printf("Provider: %s\n", provider) + fmt.Printf("Origin: %s\n", cp.Origin) + fmt.Printf("Tree size: %d leaves (checkpoint signature ✓)\n", cp.Size) + if live { + fmt.Printf("Matched: %d live agents (from %d raw leaves)\n\n", + len(matches), rawCount) + } else { + fmt.Printf("Matched: %d leaves\n\n", len(matches)) + } + for _, m := range matches { + // %q on every TL-supplied string to neutralize newlines and + // terminal-control characters that could spoof output. + fmt.Printf(" [%d] ansName=%q host=%q eventType=%q agentId=%q\n", + m.LeafIndex, m.AnsName, m.Host, m.EventType, m.AgentID) + } + + if !doVerify { + return + } + results := verifyMatches(ctx, client, baseURL, matches, makeReceiptVerifier(keys), concurrency) + var passed, failed int + fmt.Println("\n── Receipt verification ──") + for _, r := range results { + if r.OK { + passed++ + fmt.Printf(" ✓ ansName=%q agentId=%q\n", r.Match.AnsName, r.Match.AgentID) + continue + } + failed++ + fmt.Printf(" ✗ ansName=%q agentId=%q: %v\n", r.Match.AnsName, r.Match.AgentID, r.Err) + } + fmt.Printf("\nVerified %d/%d receipts (%d failed)\n", passed, passed+failed, failed) + if failed > 0 { + os.Exit(1) + } +} diff --git a/cmd/ans-verify/walk_test.go b/cmd/ans-verify/walk_test.go new file mode 100644 index 0000000..b73d7bc --- /dev/null +++ b/cmd/ans-verify/walk_test.go @@ -0,0 +1,877 @@ +package main + +import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/asn1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "sort" + "strings" + "sync" + "testing" +) + +func TestProviderMatches(t *testing.T) { + t.Parallel() + cases := []struct { + host, suffix string + want bool + }{ + {"darknetian.com", "darknetian.com", true}, + {"agent.darknetian.com", "darknetian.com", true}, + {"a.b.darknetian.com", "darknetian.com", true}, + {"DARKNETIAN.COM", "darknetian.com", true}, + {"agent.darknetian.com.", "darknetian.com", true}, // trailing dot + {"agent.darknetian.com", "DARKNETIAN.COM.", true}, // suffix case + dot + {"evildarknetian.com", "darknetian.com", false}, // substring guard + {"darknetian.com.evil", "darknetian.com", false}, + {"", "darknetian.com", false}, + {"darknetian.com", "", false}, + {"darknetian.com", "agent.darknetian.com", false}, // suffix longer than host + } + for _, c := range cases { + name := fmt.Sprintf("%s_under_%s", c.host, c.suffix) + t.Run(name, func(t *testing.T) { + t.Parallel() + if got := providerMatches(c.host, c.suffix); got != c.want { + t.Fatalf("providerMatches(%q, %q) = %v, want %v", + c.host, c.suffix, got, c.want) + } + }) + } +} + +func TestExtractAgentIdentity(t *testing.T) { + t.Parallel() + cases := map[string]struct { + envelope string + wantAns, wantHost, wantE string + wantOK bool + }{ + "v1 register": { + envelope: `{"payload":{"producer":{"event":{"ansName":"ans://v1.0.0.agent.darknetian.com","eventType":"REGISTERED","agent":{"host":"agent.darknetian.com"}}}}}`, + wantAns: "ans://v1.0.0.agent.darknetian.com", + wantHost: "agent.darknetian.com", + wantE: "REGISTERED", + wantOK: true, + }, + "v2 revoke": { + envelope: `{"schemaVersion":"V2","payload":{"producer":{"event":{"ansName":"ans://v1.0.0.x.example","eventType":"REVOKED","agent":{"host":"x.example"}}}}}`, + wantAns: "ans://v1.0.0.x.example", + wantHost: "x.example", + wantE: "REVOKED", + wantOK: true, + }, + "missing both ans and host": { + envelope: `{"payload":{"producer":{"event":{"eventType":"REGISTERED"}}}}`, + wantOK: false, + }, + "ans only (no agent block)": { + envelope: `{"payload":{"producer":{"event":{"ansName":"ans://v1.0.0.x.example"}}}}`, + wantAns: "ans://v1.0.0.x.example", + wantOK: true, + }, + "garbage json": { + envelope: `{not json`, + wantOK: false, + }, + "empty object": { + envelope: `{}`, + wantOK: false, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + t.Parallel() + id, ok := extractAgentIdentity([]byte(c.envelope)) + if ok != c.wantOK { + t.Fatalf("ok = %v, want %v", ok, c.wantOK) + } + if id.AnsName != c.wantAns { + t.Errorf("ansName = %q, want %q", id.AnsName, c.wantAns) + } + if id.Host != c.wantHost { + t.Errorf("host = %q, want %q", id.Host, c.wantHost) + } + if id.EventType != c.wantE { + t.Errorf("eventType = %q, want %q", id.EventType, c.wantE) + } + }) + } +} + +// encodeEntryBundle is the test-side inverse of api.EntryBundle's +// UnmarshalText — `[2-byte BE size][size bytes of leaf]` repeated. +// Kept local to the test file so the production walker only depends +// on the reader half. +func encodeEntryBundle(leaves [][]byte) []byte { + out := make([]byte, 0, 1024) + var sizeBuf [2]byte + for _, leaf := range leaves { + if len(leaf) > 0xFFFF { + panic("test bundle leaf > 64 KiB") + } + binary.BigEndian.PutUint16(sizeBuf[:], uint16(len(leaf))) + out = append(out, sizeBuf[:]...) + out = append(out, leaf...) + } + return out +} + +func TestDecodeEntryBundle_RoundTrip(t *testing.T) { + t.Parallel() + leaves := [][]byte{ + []byte("first"), + []byte(`{"some":"json"}`), + {}, // zero-length leaf + []byte("last"), + } + raw := encodeEntryBundle(leaves) + got, err := decodeEntryBundle(raw) + if err != nil { + t.Fatalf("decode: %v", err) + } + if len(got) != len(leaves) { + t.Fatalf("got %d leaves, want %d", len(got), len(leaves)) + } + for i := range got { + if string(got[i]) != string(leaves[i]) { + t.Errorf("leaf %d = %q, want %q", i, got[i], leaves[i]) + } + } +} + +func TestDecodeEntryBundle_DanglingBytes(t *testing.T) { + t.Parallel() + // Length prefix says 4 bytes but only 2 follow. + raw := []byte{0x00, 0x04, 'a', 'b'} + if _, err := decodeEntryBundle(raw); err == nil { + t.Fatal("want error for truncated bundle, got nil") + } +} + +// makeEnvelope produces a minimal V1-shaped envelope JSON for tests. +// The walker only inspects the fields extractAgentIdentity reads, so +// we don't need to populate the full schema. +func makeEnvelope(ansName, host, eventType string) []byte { + return makeEnvelopeWithID("", ansName, host, eventType) +} + +// makeEnvelopeWithID is the four-field variant used by tests that +// care about agentId (reduceToLive, verifyMatches). +func makeEnvelopeWithID(ansID, ansName, host, eventType string) []byte { + return []byte(fmt.Sprintf( + `{"payload":{"producer":{"event":{"ansId":%q,"ansName":%q,"eventType":%q,"agent":{"host":%q}}}}}`, + ansID, ansName, eventType, host, + )) +} + +// tileServer stands up a httptest.Server that serves tlog-tiles +// entry-tile paths backed by an in-memory `tile index -> bundle` +// map. Lets us exercise the walker without standing up tessera. +func tileServer(t *testing.T, tiles map[string][]byte) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Strip the leading slash so the keys can match what + // layout.EntriesPath returns. + key := strings.TrimPrefix(r.URL.Path, "/") + body, ok := tiles[key] + if !ok { + http.NotFound(w, r) + return + } + _, _ = w.Write(body) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestWalkProviderAgents_FiltersAndIndexes(t *testing.T) { + t.Parallel() + // Build a 3-leaf log: two under darknetian.com, one under example.org. + leaves := [][]byte{ + makeEnvelope("ans://v1.0.0.alpha.darknetian.com", "alpha.darknetian.com", "REGISTERED"), + makeEnvelope("ans://v1.0.0.other.example.org", "other.example.org", "REGISTERED"), + makeEnvelope("ans://v1.0.0.beta.darknetian.com", "beta.darknetian.com", "REGISTERED"), + } + bundle := encodeEntryBundle(leaves) + srv := tileServer(t, map[string][]byte{ + "tile/entries/000.p/3": bundle, // partial tile: 3 of 256 + }) + got, err := walkProviderAgents(context.Background(), srv.Client(), srv.URL, "darknetian.com", 3, 0) + if err != nil { + t.Fatalf("walk: %v", err) + } + if len(got) != 2 { + t.Fatalf("got %d matches, want 2 (matches=%+v)", len(got), got) + } + // Matches must come back in log order. + sort.Slice(got, func(i, j int) bool { return got[i].LeafIndex < got[j].LeafIndex }) + if got[0].LeafIndex != 0 || got[0].Host != "alpha.darknetian.com" { + t.Errorf("got[0] = %+v, want leaf 0 alpha.darknetian.com", got[0]) + } + if got[1].LeafIndex != 2 || got[1].Host != "beta.darknetian.com" { + t.Errorf("got[1] = %+v, want leaf 2 beta.darknetian.com", got[1]) + } +} + +func TestWalkProviderAgents_MultipleTiles(t *testing.T) { + t.Parallel() + // 257 leaves: one full tile (256) plus a 1-leaf partial. + // Put a darknetian match at the start of each tile to exercise + // the base-index arithmetic across the tile boundary. + full := make([][]byte, 256) + for i := range full { + full[i] = makeEnvelope( + fmt.Sprintf("ans://v1.0.0.a%d.example.org", i), + fmt.Sprintf("a%d.example.org", i), + "REGISTERED", + ) + } + full[0] = makeEnvelope("ans://v1.0.0.first.darknetian.com", "first.darknetian.com", "REGISTERED") + partial := [][]byte{ + makeEnvelope("ans://v1.0.0.second.darknetian.com", "second.darknetian.com", "REGISTERED"), + } + srv := tileServer(t, map[string][]byte{ + "tile/entries/000": encodeEntryBundle(full), + "tile/entries/001.p/1": encodeEntryBundle(partial), + }) + got, err := walkProviderAgents(context.Background(), srv.Client(), srv.URL, "darknetian.com", 257, 4) + if err != nil { + t.Fatalf("walk: %v", err) + } + if len(got) != 2 { + t.Fatalf("got %d matches, want 2 (matches=%+v)", len(got), got) + } + sort.Slice(got, func(i, j int) bool { return got[i].LeafIndex < got[j].LeafIndex }) + if got[0].LeafIndex != 0 || got[1].LeafIndex != 256 { + t.Errorf("leaf indices = %d,%d; want 0,256", + got[0].LeafIndex, got[1].LeafIndex) + } +} + +func TestWalkProviderAgents_EmptyTree(t *testing.T) { + t.Parallel() + // No tiles registered — walker should not even attempt a fetch. + srv := tileServer(t, map[string][]byte{}) + got, err := walkProviderAgents(context.Background(), srv.Client(), srv.URL, "darknetian.com", 0, 0) + if err != nil { + t.Fatalf("walk: %v", err) + } + if len(got) != 0 { + t.Fatalf("got %d matches, want 0", len(got)) + } +} + +func TestWalkProviderAgents_FetchError(t *testing.T) { + t.Parallel() + // Tile path not present → 404 → walker surfaces the error. + srv := tileServer(t, map[string][]byte{}) + _, err := walkProviderAgents(context.Background(), srv.Client(), srv.URL, "darknetian.com", 1, 0) + if err == nil { + t.Fatal("want error for missing tile, got nil") + } +} + +func TestWalkProviderAgents_SkipsUnparsableLeaves(t *testing.T) { + t.Parallel() + leaves := [][]byte{ + []byte("not json at all"), + makeEnvelope("ans://v1.0.0.real.darknetian.com", "real.darknetian.com", "REGISTERED"), + []byte(`{"payload":{}}`), // valid json, no event — skipped + } + srv := tileServer(t, map[string][]byte{ + "tile/entries/000.p/3": encodeEntryBundle(leaves), + }) + got, err := walkProviderAgents(context.Background(), srv.Client(), srv.URL, "darknetian.com", 3, 0) + if err != nil { + t.Fatalf("walk: %v", err) + } + if len(got) != 1 || got[0].LeafIndex != 1 { + t.Fatalf("got %+v, want one match at leaf 1", got) + } +} + +func TestCheckpointTreeSize(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/log/checkpoint" { + http.NotFound(w, r) + return + } + _, _ = w.Write([]byte(`{"logSize":42,"originName":"test","rootHash":"AAAA"}`)) + })) + t.Cleanup(srv.Close) + got, err := checkpointTreeSize(context.Background(), srv.Client(), srv.URL) + if err != nil { + t.Fatalf("checkpoint: %v", err) + } + if got != 42 { + t.Fatalf("got %d, want 42", got) + } +} + +func TestCheckpointTreeSize_BadJSON(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{not json`)) + })) + t.Cleanup(srv.Close) + if _, err := checkpointTreeSize(context.Background(), srv.Client(), srv.URL); err == nil { + t.Fatal("want error for bad json, got nil") + } +} + +func TestCheckpointTreeSize_HTTPError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(srv.Close) + if _, err := checkpointTreeSize(context.Background(), srv.Client(), srv.URL); err == nil { + t.Fatal("want error for 500, got nil") + } +} + +func TestClampConcurrency(t *testing.T) { + t.Parallel() + cases := map[int]int{ + 0: 8, // default + 1: 1, // floor + -5: 1, // negative → floor + 8: 8, // passthrough + 32: 32, // passthrough + 64: 64, // boundary + 500: 64, // ceiling + } + for in, want := range cases { + if got := clampConcurrency(in); got != want { + t.Errorf("clampConcurrency(%d) = %d, want %d", in, got, want) + } + } +} + +func TestReduceToLive(t *testing.T) { + t.Parallel() + // Two agents: + // alpha — registered (leaf 0), renewed (leaf 3) → live + // beta — registered (leaf 1), revoked (leaf 4) → dropped + // gamma — registered (leaf 2) → live + // (one leaf with no AnsName — should pass through untouched) + input := []AgentMatch{ + {LeafIndex: 0, AgentID: "id-a", AnsName: "ans://alpha", Host: "a.x", EventType: "AGENT_REGISTERED"}, + {LeafIndex: 1, AgentID: "id-b", AnsName: "ans://beta", Host: "b.x", EventType: "AGENT_REGISTERED"}, + {LeafIndex: 2, AgentID: "id-g", AnsName: "ans://gamma", Host: "g.x", EventType: "AGENT_REGISTERED"}, + {LeafIndex: 3, AgentID: "id-a", AnsName: "ans://alpha", Host: "a.x", EventType: "AGENT_RENEWED"}, + {LeafIndex: 4, AgentID: "id-b", AnsName: "ans://beta", Host: "b.x", EventType: "AGENT_REVOKED"}, + {LeafIndex: 5, AgentID: "id-x", AnsName: "", Host: "weird.x", EventType: "AGENT_REGISTERED"}, + } + got := reduceToLive(input) + wantAns := map[string]string{ + "ans://alpha": "AGENT_RENEWED", // dedup keeps latest + "ans://gamma": "AGENT_REGISTERED", // unchanged + "": "AGENT_REGISTERED", // empty-name passthrough + } + if len(got) != len(wantAns) { + t.Fatalf("got %d rows, want %d: %+v", len(got), len(wantAns), got) + } + for _, m := range got { + wt, ok := wantAns[m.AnsName] + if !ok { + t.Errorf("unexpected AnsName %q in result", m.AnsName) + continue + } + if m.EventType != wt { + t.Errorf("AnsName=%q eventType = %q, want %q", m.AnsName, m.EventType, wt) + } + } + // Drop a deprecated agent and a renewed one to exercise the + // other terminal type and confirm sort-by-leafIndex on the way out. + input = []AgentMatch{ + {LeafIndex: 10, AnsName: "ans://x", EventType: "AGENT_REGISTERED"}, + {LeafIndex: 5, AnsName: "ans://y", EventType: "AGENT_DEPRECATED"}, + } + got = reduceToLive(input) + if len(got) != 1 || got[0].AnsName != "ans://x" { + t.Fatalf("got %+v, want [ans://x]", got) + } +} + +func TestReduceToLive_Empty(t *testing.T) { + t.Parallel() + got := reduceToLive(nil) + if len(got) != 0 { + t.Fatalf("got %d, want 0", len(got)) + } +} + +func TestWalkProviderAgents_PopulatesAgentID(t *testing.T) { + t.Parallel() + leaves := [][]byte{ + makeEnvelopeWithID("uuid-1", "ans://v1.alpha.darknetian.com", + "alpha.darknetian.com", "AGENT_REGISTERED"), + } + srv := tileServer(t, map[string][]byte{ + "tile/entries/000.p/1": encodeEntryBundle(leaves), + }) + got, err := walkProviderAgents(context.Background(), srv.Client(), srv.URL, "darknetian.com", 1, 0) + if err != nil { + t.Fatalf("walk: %v", err) + } + if len(got) != 1 || got[0].AgentID != "uuid-1" { + t.Fatalf("got %+v, want one match with agentId=uuid-1", got) + } +} + +func TestVerifyMatches_HappyPath(t *testing.T) { + t.Parallel() + var seen sync.Mutex + seenIDs := map[string]int{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/receipt") { + http.NotFound(w, r) + return + } + seen.Lock() + seenIDs[r.URL.Path]++ + seen.Unlock() + _, _ = w.Write([]byte("RECEIPT-BYTES")) + })) + t.Cleanup(srv.Close) + + const ( + aID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + bID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + ) + matches := []AgentMatch{ + {AgentID: aID, AnsName: "ans://a"}, + {AgentID: bID, AnsName: "ans://b"}, + {AgentID: "", AnsName: "ans://no-id"}, // missing agentId → per-match err + } + stub := func(b []byte) error { + if string(b) != "RECEIPT-BYTES" { + return errors.New("unexpected body") + } + return nil + } + results := verifyMatches(context.Background(), srv.Client(), srv.URL, matches, stub, 4) + if len(results) != 3 { + t.Fatalf("got %d results, want 3", len(results)) + } + if !results[0].OK || !results[1].OK { + t.Errorf("results[0].OK=%v results[1].OK=%v; both want true", results[0].OK, results[1].OK) + } + if results[2].OK || results[2].Err == nil { + t.Errorf("results[2] = %+v, want err for missing agentId", results[2]) + } + if seenIDs["/v1/agents/"+aID+"/receipt"] != 1 || seenIDs["/v1/agents/"+bID+"/receipt"] != 1 { + t.Errorf("expected one fetch each for /a and /b receipts, got %+v", seenIDs) + } +} + +func TestVerifyMatches_FetchAndVerifyErrors(t *testing.T) { + t.Parallel() + // Three UUIDs that round-trip through agentIDPattern. The server + // serves a distinct body per agentId so the test verifier can + // route deterministically without depending on call order. + const ( + goodID = "11111111-1111-1111-1111-111111111111" + fetID = "22222222-2222-2222-2222-222222222222" + verID = "33333333-3333-3333-3333-333333333333" + ) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasPrefix(r.URL.Path, "/v1/agents/"+fetID+"/"): + w.WriteHeader(http.StatusInternalServerError) + case strings.HasPrefix(r.URL.Path, "/v1/agents/"+goodID+"/"): + _, _ = w.Write([]byte("body:good")) + case strings.HasPrefix(r.URL.Path, "/v1/agents/"+verID+"/"): + _, _ = w.Write([]byte("body:verify-fail")) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(srv.Close) + matches := []AgentMatch{ + {AgentID: goodID, AnsName: "ans://good"}, + {AgentID: fetID, AnsName: "ans://bf"}, + {AgentID: verID, AnsName: "ans://bv"}, + } + // Verifier routes by body, not by call order — concurrency-safe, + // refactor-safe. + verifier := func(b []byte) error { + switch string(b) { + case "body:good": + return nil + case "body:verify-fail": + return errors.New("synthetic verify failure") + default: + return fmt.Errorf("unexpected body %q", b) + } + } + results := verifyMatches(context.Background(), srv.Client(), srv.URL, matches, verifier, 4) + if !results[0].OK { + t.Errorf("results[0] = %+v, want OK", results[0]) + } + if results[1].OK || results[1].Err == nil { + t.Errorf("results[1] = %+v, want fetch err", results[1]) + } + if results[2].OK || results[2].Err == nil { + t.Errorf("results[2] = %+v, want verify err", results[2]) + } +} + +func TestVerifyMatches_RejectsBadAgentID(t *testing.T) { + t.Parallel() + // Server should NEVER be hit — the path-injection guard fires + // first. Failing the test if it is reached gives a clear signal + // that the guard regressed. + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + t.Errorf("guard failed: server was hit with path %q", r.URL.Path) + })) + t.Cleanup(srv.Close) + matches := []AgentMatch{ + {AgentID: "../etc/passwd", AnsName: "ans://traversal"}, + {AgentID: "foo?injected=1", AnsName: "ans://query"}, + {AgentID: "not-a-uuid", AnsName: "ans://wrong-shape"}, + } + results := verifyMatches(context.Background(), srv.Client(), srv.URL, matches, + func([]byte) error { return nil }, 1) + for i, r := range results { + if r.OK || r.Err == nil { + t.Errorf("results[%d] = %+v, want guard err", i, r) + } + } +} + +func TestHTTPGetBytes_BodyCapped(t *testing.T) { + t.Parallel() + // Stream just over the cap — guard must reject, not truncate. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + buf := make([]byte, 1024) + written := int64(0) + for written <= maxResponseBytes { + n, _ := w.Write(buf) + written += int64(n) + } + })) + t.Cleanup(srv.Close) + _, err := httpGetBytes(context.Background(), srv.Client(), srv.URL) + if err == nil { + t.Fatal("want cap-exceeded error, got nil") + } + if !strings.Contains(err.Error(), "cap") { + t.Errorf("err = %v, want one mentioning cap", err) + } +} + +func TestWalkProviderAgents_ExternalContextCancel(t *testing.T) { + t.Parallel() + // Server hangs forever — walker must surface ctx err, not nil. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + t.Cleanup(srv.Close) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already-cancelled + _, err := walkProviderAgents(ctx, srv.Client(), srv.URL, "darknetian.com", 5, 2) + if err == nil { + t.Fatal("want error on cancelled context, got nil") + } +} + +func TestWalkProviderAgents_FetchErrorReportsTriggeringTile(t *testing.T) { + t.Parallel() + // Tile 5 500s; tiles 0-4 return valid full-width bundles of stub + // envelopes that won't match the provider filter. The walker + // should surface the tile-5 error, not whichever lower-indexed + // tile happened to complete first. + full := make([][]byte, 256) + for i := range full { + full[i] = makeEnvelope( + fmt.Sprintf("ans://v1.stub%d.example.org", i), + fmt.Sprintf("stub%d.example.org", i), + "AGENT_REGISTERED", + ) + } + fullBundle := encodeEntryBundle(full) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "tile/entries/005") { + w.WriteHeader(http.StatusInternalServerError) + return + } + _, _ = w.Write(fullBundle) + })) + t.Cleanup(srv.Close) + // 6 tiles' worth of leaves so tile 5 exists. + _, err := walkProviderAgents(context.Background(), srv.Client(), + srv.URL, "darknetian.com", 6*256, 4) + if err == nil { + t.Fatal("want fetch err, got nil") + } + if !strings.Contains(err.Error(), "005") { + t.Errorf("err = %v, want one mentioning tile 005", err) + } +} + +func TestWalkProviderAgents_RejectsWrongTileSize(t *testing.T) { + t.Parallel() + // Tile claims to be a full (256-leaf) tile but the server returns + // only 3 leaves. The size guard MUST reject — otherwise a hostile + // TL can hide entries by serving truncated bundles. + short := encodeEntryBundle([][]byte{ + makeEnvelope("ans://v1.a.x", "a.x", "AGENT_REGISTERED"), + makeEnvelope("ans://v1.b.x", "b.x", "AGENT_REGISTERED"), + makeEnvelope("ans://v1.c.x", "c.x", "AGENT_REGISTERED"), + }) + srv := tileServer(t, map[string][]byte{ + "tile/entries/000": short, // path says full tile (no .p/N suffix) + }) + _, err := walkProviderAgents(context.Background(), srv.Client(), + srv.URL, "x", 256, 1) + if err == nil { + t.Fatal("want size-mismatch error, got nil") + } + if !strings.Contains(err.Error(), "want 256") { + t.Errorf("err = %v, want one mentioning expected size 256", err) + } +} + +func TestWalkProviderAgents_RejectsOversizedTile(t *testing.T) { + t.Parallel() + // Partial tile claims width 1 but the server returns 2 leaves. + // A hostile TL injecting extra leaves into a partial tile would + // otherwise slip through the checkpoint-signature check (the + // checkpoint binds tree shape, not tile contents). + over := encodeEntryBundle([][]byte{ + makeEnvelope("ans://v1.a.x", "a.x", "AGENT_REGISTERED"), + makeEnvelope("ans://v1.b.x", "b.x", "AGENT_REGISTERED"), + }) + srv := tileServer(t, map[string][]byte{ + "tile/entries/000.p/1": over, // path says partial=1 + }) + _, err := walkProviderAgents(context.Background(), srv.Client(), + srv.URL, "x", 1, 1) + if err == nil { + t.Fatal("want size-mismatch error, got nil") + } + if !strings.Contains(err.Error(), "want 1") { + t.Errorf("err = %v, want one mentioning expected size 1", err) + } +} + +func TestVerifyMatches_Empty(t *testing.T) { + t.Parallel() + got := verifyMatches(context.Background(), nil, "", nil, func([]byte) error { return nil }, 0) + if got != nil { + t.Fatalf("got %+v, want nil", got) + } +} + +func TestMakeReceiptVerifier_NoKeys(t *testing.T) { + t.Parallel() + v := makeReceiptVerifier(nil) + if err := v([]byte("anything")); err == nil { + t.Fatal("want error for empty key set, got nil") + } +} + +// signTestCheckpoint produces a synthetic C2SP-shaped signed note +// for use in the checkpoint-verify tests. Mirrors how the TL's +// C2SPECDSASigner constructs the signature line (keyhash:4 || sig), +// so verifyCheckpointNote exercises the same wire shape it does in +// production. +// +// Signature format: ASN.1 DER ECDSA per godaddy/ans PR #38 (the +// previous implementation emitted IEEE P1363 r||s, which the spec +// — api-spec-tl-v2.yaml § CheckpointSignature.rawSignature — never +// matched). VerifyC2SPECDSA still accepts P1363 as a legacy +// fallback, but tests should pin the production format. +func signTestCheckpoint(t *testing.T, priv *ecdsa.PrivateKey, origin string, size uint64, rootHash []byte) []byte { + t.Helper() + body := []byte(fmt.Sprintf("%s\n%d\n%s\n", + origin, size, base64.StdEncoding.EncodeToString(rootHash))) + digest := sha256.Sum256(body) + r, s, err := ecdsa.Sign(rand.Reader, priv, digest[:]) + if err != nil { + t.Fatalf("sign: %v", err) + } + sig, err := asn1.Marshal(struct{ R, S *big.Int }{r, s}) + if err != nil { + t.Fatalf("DER marshal: %v", err) + } + + khex, err := keyHashHex(&priv.PublicKey) + if err != nil { + t.Fatalf("keyhash: %v", err) + } + var kh4 [4]byte + khRaw, err := hexDecode4(khex) + if err != nil { + t.Fatalf("hex decode: %v", err) + } + copy(kh4[:], khRaw) + blob := append(kh4[:], sig...) + sigLine := fmt.Sprintf("— %s %s\n", origin, base64.StdEncoding.EncodeToString(blob)) + return append(body, append([]byte("\n"), []byte(sigLine)...)...) +} + +func hexDecode4(s string) ([]byte, error) { + if len(s) != 8 { + return nil, fmt.Errorf("want 8 hex chars, got %d", len(s)) + } + v, ok := new(big.Int).SetString(s, 16) + if !ok { + return nil, errors.New("bad hex") + } + var out [4]byte + b := v.Bytes() + copy(out[4-len(b):], b) + return out[:], nil +} + +func TestVerifyCheckpointNote_HappyPath(t *testing.T) { + t.Parallel() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("gen key: %v", err) + } + khex, _ := keyHashHex(&priv.PublicKey) + rootHash := sha256.Sum256([]byte("synthetic root")) + note := signTestCheckpoint(t, priv, "demo.example", 42, rootHash[:]) + + cp, err := verifyCheckpointNote(note, map[string]*ecdsa.PublicKey{khex: &priv.PublicKey}) + if err != nil { + t.Fatalf("verify: %v", err) + } + if cp.Origin != "demo.example" || cp.Size != 42 { + t.Errorf("got origin=%q size=%d, want demo.example/42", cp.Origin, cp.Size) + } + if string(cp.RootHash) != string(rootHash[:]) { + t.Errorf("rootHash mismatch") + } +} + +func TestVerifyCheckpointNote_TamperedBodyFails(t *testing.T) { + t.Parallel() + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + khex, _ := keyHashHex(&priv.PublicKey) + rootHash := sha256.Sum256([]byte("v1")) + note := signTestCheckpoint(t, priv, "demo.example", 42, rootHash[:]) + // Flip a digit in the size line — same length so signature + // remains the same length, but body is now wrong. + tampered := bytes.Replace(note, []byte("\n42\n"), []byte("\n99\n"), 1) + if _, err := verifyCheckpointNote(tampered, + map[string]*ecdsa.PublicKey{khex: &priv.PublicKey}); err == nil { + t.Fatal("want error for tampered body, got nil") + } +} + +func TestVerifyCheckpointNote_UnknownKey(t *testing.T) { + t.Parallel() + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + other, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + otherHex, _ := keyHashHex(&other.PublicKey) + rootHash := sha256.Sum256([]byte("rh")) + note := signTestCheckpoint(t, priv, "demo.example", 1, rootHash[:]) + // Only the wrong key is in the map — signature was made with + // `priv`, but `priv` isn't a known verifier. + if _, err := verifyCheckpointNote(note, + map[string]*ecdsa.PublicKey{otherHex: &other.PublicKey}); err == nil { + t.Fatal("want unknown-key error, got nil") + } +} + +func TestVerifyCheckpointNote_MalformedBody(t *testing.T) { + t.Parallel() + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + khex, _ := keyHashHex(&priv.PublicKey) + keys := map[string]*ecdsa.PublicKey{khex: &priv.PublicKey} + + cases := map[string][]byte{ + "no separator": []byte("just one line\n"), + "empty": nil, + "missing rootHash line": []byte("origin\n42\n\n— x AA==\n"), + "non-numeric size": []byte("origin\nNaN\nAAAA\n\n— x AA==\n"), + "bad base64 rootHash": []byte("origin\n1\n!!!notb64\n\n— x AA==\n"), + } + for name, body := range cases { + t.Run(name, func(t *testing.T) { + t.Parallel() + if _, err := verifyCheckpointNote(body, keys); err == nil { + t.Fatalf("want error for %s, got nil", name) + } + }) + } +} + +func TestVerifyMatches_LeafSubstitutionCaught(t *testing.T) { + t.Parallel() + const aID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + // The receipt returns a different payload than the LeafBytes we + // "walked" off the tile — simulates a TL that served a forged + // tile claiming a fake host under our provider but a real receipt + // for an unrelated agent. + receiptBody := []byte("REAL_RECEIPT_FOR_DIFFERENT_LEAF") + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/receipt") { + http.NotFound(w, r) + return + } + _, _ = w.Write(receiptBody) + })) + t.Cleanup(srv.Close) + + // Synthetic receipt payload != tile leaf bytes; the test + // verifier accepts any signature so we know the substitution + // guard — not the signature guard — is what fails. + forgedTileLeaf := []byte("FORGED_TILE_LEAF_BYTES") + matches := []AgentMatch{ + {AgentID: aID, AnsName: "ans://victim", LeafBytes: forgedTileLeaf}, + } + // extractPayloadStub: this test runs verifyMatches's substitution + // branch, which calls receipt.ExtractPayload on receiptBody. That + // body isn't a valid COSE_Sign1, so ExtractPayload will fail + // before the bytes.Equal — which is still the correct outcome + // (the guard rejects). Confirm we get an error. + stub := func(_ []byte) error { return nil } + results := verifyMatches(context.Background(), srv.Client(), srv.URL, matches, stub, 1) + if results[0].OK || results[0].Err == nil { + t.Fatalf("results[0] = %+v, want substitution err", results[0]) + } + // The error should mention either "payload" or "substitution" + // depending on whether ExtractPayload succeeded. + msg := results[0].Err.Error() + if !strings.Contains(msg, "payload") && !strings.Contains(msg, "substitution") { + t.Errorf("err = %v, want substitution-related message", results[0].Err) + } +} + +func TestVerifyMatches_NilLeafBytesSkipsSubstitutionCheck(t *testing.T) { + t.Parallel() + // Caller didn't capture LeafBytes — the substitution check must + // be skipped (back-compat), and the match should pass on the + // strength of the verifier alone. + const aID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + })) + t.Cleanup(srv.Close) + matches := []AgentMatch{{AgentID: aID, AnsName: "ans://legacy"}} + results := verifyMatches(context.Background(), srv.Client(), srv.URL, matches, + func(_ []byte) error { return nil }, 1) + if !results[0].OK { + t.Errorf("results[0] = %+v, want OK", results[0]) + } +}