From 8bb2a934bdbb38ceb98f95e8156e9f6f2cd0457a Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Sat, 13 Jun 2026 08:02:42 +0200 Subject: [PATCH 1/5] feat(graph): classify return-value consumption on call edges Stamp return_usage on every call edge at extraction time across Go, Python, JavaScript, TypeScript, Java, Rust, Ruby, and C#: discarded, assigned, partially_ignored, returned, goroutine, deferred, argument, condition. One shared parent-chain classifier driven by per-grammar node-kind tables, with closure and switch-expression boundaries honored so a call inside a block or match arm is not mislabeled by its enclosing statement; unknown shapes stay unstamped rather than mislabeled. Surfaced on find_usages (per-usage return_usage field + filter param) and verify_change (per-function return-usage distribution of real call sites, so a return-signature change shows how every call site consumes the value). --- cmd/gortex/wire_contract_test.go | 12 +- docs/wire-format.md | 26 +- internal/analysis/contracts.go | 73 ++- .../analysis/return_usage_summary_test.go | 149 +++++ internal/graph/edge.go | 55 ++ internal/mcp/gcx.go | 4 +- internal/mcp/return_usage_test.go | 249 ++++++++ internal/mcp/tools_core.go | 43 +- internal/mcp/tools_enhancements.go | 18 + internal/parser/languages/csharp.go | 25 +- internal/parser/languages/csharp_test.go | 2 +- internal/parser/languages/golang.go | 38 +- internal/parser/languages/golang_test.go | 4 +- internal/parser/languages/java.go | 31 +- internal/parser/languages/javascript.go | 66 ++- internal/parser/languages/python.go | 54 +- internal/parser/languages/python_test.go | 2 +- internal/parser/languages/return_usage.go | 558 ++++++++++++++++++ .../parser/languages/return_usage_test.go | 518 ++++++++++++++++ internal/parser/languages/ruby.go | 17 +- internal/parser/languages/rust.go | 25 +- internal/parser/languages/rust_test.go | 2 +- internal/parser/languages/typescript.go | 69 ++- internal/parser/languages/typescript_test.go | 2 +- 24 files changed, 1899 insertions(+), 143 deletions(-) create mode 100644 internal/analysis/return_usage_summary_test.go create mode 100644 internal/mcp/return_usage_test.go create mode 100644 internal/parser/languages/return_usage.go create mode 100644 internal/parser/languages/return_usage_test.go diff --git a/cmd/gortex/wire_contract_test.go b/cmd/gortex/wire_contract_test.go index 96049200..0f266ff9 100644 --- a/cmd/gortex/wire_contract_test.go +++ b/cmd/gortex/wire_contract_test.go @@ -143,12 +143,12 @@ func wireContractGolden(name string) string { // ProjectID, were added. return "3b8920ab88d05028e215d68d5917445e2e6d05bdad23aef6dcdf6c9920647823" case "graph.Edge": - // Bumped when Context was added — the per-reference role label - // (parameter_type / return_type / field / …) populated on demand by - // find_usages via RefContextOf. Additive: gob decodes older - // snapshots with Context blank, and it is recomputed at query time. - // (Previously bumped when Tier was added.) - return "ed897cce4720cd1482d8c217ba5ffb72d7f19d5d1c2d4015b9a98e9daa9d4b63" + // Bumped when ReturnUsage was added — the per-call-site return-value + // consumption label (discarded / assigned / returned / …) populated on + // demand at extraction and query time. Additive: gob decodes older + // snapshots with ReturnUsage blank, and it is recomputed. + // (Previously bumped when Context, then Tier, were added.) + return "f537793b5542de95a9a4f383e6ed02317ac416c529b187e02fa60dccef1112d0" case "snapshotHeader": // Bumped when the VectorIndex / VectorDims / VectorCount fields // were added (additive — gob decodes unknown fields as zero). diff --git a/docs/wire-format.md b/docs/wire-format.md index 0678ce98..2bea48cf 100644 --- a/docs/wire-format.md +++ b/docs/wire-format.md @@ -180,16 +180,22 @@ Exactly one row. ### `find_usages` -| field | type | description | -|------------|--------|-------------| -| from | string | caller symbol ID | -| to | string | called symbol ID (the query subject) | -| edge_kind | string | `calls`, `references`, `implements`, ... | -| origin | string | tier: `lsp_resolved`, `lsp_dispatch`, `ast_resolved`, `ast_inferred`, `text_matched` | -| confidence | float | 0..1 | -| from_name | string | caller short name | -| from_path | string | caller file path | -| from_line | int | caller start line | +| field | type | description | +|------------------|--------|-------------| +| from | string | caller symbol ID | +| to | string | called symbol ID (the query subject) | +| edge_kind | string | `calls`, `references`, `implements`, ... | +| context | string | reference role at the usage site: `parameter_type`, `return_type`, `field`, `value`, `type`, `attribute`, `generic_arg`, `call` | +| return_usage | string | how a call site consumes the return value: `discarded`, `assigned`, `partially_ignored`, `returned`, `goroutine`, `deferred`, `argument`, `condition`; empty when unclassified | +| origin | string | provenance: `lsp_resolved`, `lsp_dispatch`, `ast_resolved`, `ast_inferred`, `text_matched` | +| tier | string | coarse provenance label derived from origin | +| confidence | float | 0..1 | +| from_name | string | caller short name | +| from_path | string | usage-site file path | +| from_line | int | call-site line (falls back to the caller's start line) | +| from_is_test | bool | caller is a test symbol | +| from_test_role | string | `test`, `benchmark`, `fuzz`, `example` when applicable | +| from_test_runner | string | detected JS/TS test runner when applicable | ### `get_file_summary` diff --git a/internal/analysis/contracts.go b/internal/analysis/contracts.go index c2854a04..81711a79 100644 --- a/internal/analysis/contracts.go +++ b/internal/analysis/contracts.go @@ -27,12 +27,61 @@ type ContractViolation struct { // VerifyResult is the output of contract violation verification. type VerifyResult struct { - Violations []ContractViolation `json:"violations"` - CheckedCallers int `json:"checked_callers"` - CheckedImpls int `json:"checked_impls"` - Clean bool `json:"clean"` - Errors []string `json:"errors,omitempty"` - CrossRepoViolations bool `json:"cross_repo_violations,omitempty"` + Violations []ContractViolation `json:"violations"` + CheckedCallers int `json:"checked_callers"` + CheckedImpls int `json:"checked_impls"` + Clean bool `json:"clean"` + Errors []string `json:"errors,omitempty"` + CrossRepoViolations bool `json:"cross_repo_violations,omitempty"` + ReturnUsage []ReturnUsageSummary `json:"return_usage,omitempty"` +} + +// ReturnUsageSummary aggregates how the call sites of one changed +// function or method consume its return value — the "who actually uses +// the return?" answer an agent needs before changing a return +// signature. Counts come from the extractor-stamped return-usage label +// on each incoming call edge; call sites the classifier could not +// place are reported as unclassified rather than guessed. +type ReturnUsageSummary struct { + SymbolID string `json:"symbol_id"` + CallSites int `json:"call_sites"` + Counts map[string]int `json:"counts,omitempty"` + Unclassified int `json:"unclassified,omitempty"` +} + +// summarizeReturnUsage builds the return-usage distribution for one +// function/method's incoming call edges. Returns nil when the symbol +// has no call sites at all (nothing to report). +func summarizeReturnUsage(g graph.Store, node *graph.Node) *ReturnUsageSummary { + if node == nil || (node.Kind != graph.KindFunction && node.Kind != graph.KindMethod) { + return nil + } + summary := &ReturnUsageSummary{SymbolID: node.ID} + for _, e := range g.GetInEdges(node.ID) { + if e.Kind != graph.EdgeCalls { + continue + } + // Skip speculative dispatch edges: the read surfaces (find_usages + // and the rest) hide them by default, so counting them here would + // make this distribution disagree with the call sites a user + // actually sees. + if e.IsSpeculative() { + continue + } + summary.CallSites++ + if usage := graph.ReturnUsageOf(e); usage != "" { + if summary.Counts == nil { + summary.Counts = map[string]int{} + } + summary.Counts[usage]++ + } else { + summary.Unclassified++ + } + } + if summary.CallSites == 0 { + return nil + } + return summary } // parsedSignature holds the extracted parameter and return type info from a signature string. @@ -64,6 +113,18 @@ func VerifyChanges(g graph.Store, engine *query.Engine, changes []SignatureChang } } + // For a function/method, summarise how its call sites consume + // the return value — the sites that bind / return / branch on the + // result are the ones a return-type change touches. The discarded + // count is not a blanket all-clear: a discarded label also folds + // every-sink-blank multi-assignment (Go `_, _ = f()`), which still + // breaks on a return-arity change because the blank list must + // match the result count. Read it as "the value is unused here", + // not "this site is safe to change". + if summary := summarizeReturnUsage(g, node); summary != nil { + result.ReturnUsage = append(result.ReturnUsage, *summary) + } + // Check callers for parameter mismatches callerSG := engine.GetCallers(change.SymbolID, query.QueryOptions{Depth: 2, Limit: 500}) for _, callerNode := range callerSG.Nodes { diff --git a/internal/analysis/return_usage_summary_test.go b/internal/analysis/return_usage_summary_test.go new file mode 100644 index 00000000..f9a685a5 --- /dev/null +++ b/internal/analysis/return_usage_summary_test.go @@ -0,0 +1,149 @@ +package analysis + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/query" +) + +// buildReturnUsageGraph wires a target function with three classified +// call sites and one the extractor left unstamped. +func buildReturnUsageGraph(t *testing.T) (*graph.Graph, string) { + t.Helper() + g := graph.New() + targetID := "pkg/t.go::Target" + g.AddNode(&graph.Node{ + ID: targetID, Kind: graph.KindFunction, Name: "Target", + FilePath: "pkg/t.go", StartLine: 1, + Meta: map[string]any{"signature": "func() error"}, + }) + for i, usage := range []string{ + graph.ReturnUsageDiscarded, + graph.ReturnUsageDiscarded, + graph.ReturnUsageAssigned, + "", // unclassified site + } { + callerID := "pkg/c.go::caller" + string(rune('A'+i)) + g.AddNode(&graph.Node{ + ID: callerID, Kind: graph.KindFunction, Name: "caller", + FilePath: "pkg/c.go", StartLine: 10 * (i + 1), + }) + e := &graph.Edge{ + From: callerID, To: targetID, Kind: graph.EdgeCalls, + FilePath: "pkg/c.go", Line: 10*(i+1) + 1, + } + if usage != "" { + e.Meta = map[string]any{graph.MetaReturnUsage: usage} + } + g.AddEdge(e) + } + return g, targetID +} + +func TestVerifyChanges_ReturnUsageSummary(t *testing.T) { + g, targetID := buildReturnUsageGraph(t) + engine := query.NewEngine(g) + + result := VerifyChanges(g, engine, []SignatureChange{ + {SymbolID: targetID, NewSignature: "func() (int, error)"}, + }) + + require.Len(t, result.ReturnUsage, 1) + ru := result.ReturnUsage[0] + assert.Equal(t, targetID, ru.SymbolID) + assert.Equal(t, 4, ru.CallSites) + assert.Equal(t, 2, ru.Counts[graph.ReturnUsageDiscarded]) + assert.Equal(t, 1, ru.Counts[graph.ReturnUsageAssigned]) + assert.Equal(t, 1, ru.Unclassified) +} + +// Speculative dispatch edges are hidden by default on every read +// surface (find_usages and the rest), so the return-usage distribution +// must not count them either — otherwise verify_change disagrees with +// the call sites a user actually sees. +func TestVerifyChanges_ReturnUsageSkipsSpeculative(t *testing.T) { + g := graph.New() + targetID := "pkg/t.go::Target" + g.AddNode(&graph.Node{ + ID: targetID, Kind: graph.KindFunction, Name: "Target", + FilePath: "pkg/t.go", StartLine: 1, + Meta: map[string]any{"signature": "func() error"}, + }) + // One concrete (visible) assigned call site. + g.AddNode(&graph.Node{ + ID: "pkg/c.go::real", Kind: graph.KindFunction, Name: "real", FilePath: "pkg/c.go", StartLine: 10, + }) + g.AddEdge(&graph.Edge{ + From: "pkg/c.go::real", To: targetID, Kind: graph.EdgeCalls, + FilePath: "pkg/c.go", Line: 11, + Meta: map[string]any{graph.MetaReturnUsage: graph.ReturnUsageAssigned}, + }) + // One speculative dispatch call site — hidden from read surfaces by + // default, so it must not appear in the distribution. + g.AddNode(&graph.Node{ + ID: "pkg/c.go::spec", Kind: graph.KindFunction, Name: "spec", FilePath: "pkg/c.go", StartLine: 20, + }) + g.AddEdge(&graph.Edge{ + From: "pkg/c.go::spec", To: targetID, Kind: graph.EdgeCalls, + FilePath: "pkg/c.go", Line: 21, Origin: graph.OriginSpeculative, + Meta: map[string]any{ + graph.MetaReturnUsage: graph.ReturnUsageDiscarded, + graph.MetaSpeculative: true, + }, + }) + engine := query.NewEngine(g) + + result := VerifyChanges(g, engine, []SignatureChange{ + {SymbolID: targetID, NewSignature: "func() (int, error)"}, + }) + + require.Len(t, result.ReturnUsage, 1) + ru := result.ReturnUsage[0] + assert.Equal(t, 1, ru.CallSites, "speculative call site must not be counted") + assert.Equal(t, 1, ru.Counts[graph.ReturnUsageAssigned]) + assert.Zero(t, ru.Counts[graph.ReturnUsageDiscarded], "speculative discarded site must be excluded") + assert.Zero(t, ru.Unclassified) +} + +// A non-callable symbol must not produce a distribution: return-usage +// only means something for function/method return values. +func TestVerifyChanges_ReturnUsageSkipsNonFunctions(t *testing.T) { + g := graph.New() + typeID := "pkg/t.go::Config" + g.AddNode(&graph.Node{ + ID: typeID, Kind: graph.KindType, Name: "Config", FilePath: "pkg/t.go", + }) + g.AddNode(&graph.Node{ + ID: "pkg/c.go::user", Kind: graph.KindFunction, Name: "user", FilePath: "pkg/c.go", + }) + g.AddEdge(&graph.Edge{ + From: "pkg/c.go::user", To: typeID, Kind: graph.EdgeReferences, + FilePath: "pkg/c.go", Line: 4, + }) + engine := query.NewEngine(g) + + result := VerifyChanges(g, engine, []SignatureChange{ + {SymbolID: typeID, NewSignature: "struct{}"}, + }) + assert.Empty(t, result.ReturnUsage) +} + +// A function nobody calls reports no distribution rather than an empty +// one — there is nothing to break. +func TestVerifyChanges_ReturnUsageSkipsUncalled(t *testing.T) { + g := graph.New() + fnID := "pkg/t.go::Lonely" + g.AddNode(&graph.Node{ + ID: fnID, Kind: graph.KindFunction, Name: "Lonely", FilePath: "pkg/t.go", + Meta: map[string]any{"signature": "func()"}, + }) + engine := query.NewEngine(g) + + result := VerifyChanges(g, engine, []SignatureChange{ + {SymbolID: fnID, NewSignature: "func() error"}, + }) + assert.Empty(t, result.ReturnUsage) +} diff --git a/internal/graph/edge.go b/internal/graph/edge.go index 4cdd7f01..cc84d8f9 100644 --- a/internal/graph/edge.go +++ b/internal/graph/edge.go @@ -555,6 +555,15 @@ type Edge struct { // (`find_usages context:"parameter_type"`). Not part of the edge // identity / dedup key. Context string `json:"context,omitempty"` + // ReturnUsage is how the call site consumes the callee's return + // value — discarded, assigned, partially_ignored, returned, + // goroutine, deferred, argument, or condition. Like Context it is + // empty on the stored edge and populated on demand from + // Meta[MetaReturnUsage] (stamped by the language extractors at + // call-edge creation) when find_usages renders a usage, so agents + // can ask "who actually uses the return?" before changing a return + // signature. Not part of the edge identity / dedup key. + ReturnUsage string `json:"return_usage,omitempty"` // Meta is intentionally excluded from JSON. It holds internal // instrumentation (semantic_source, provider hints, etc.) that agents // don't consume but that adds measurable bytes to every edge in @@ -876,6 +885,52 @@ func RefContextOf(e *Edge, fromKind NodeKind) string { return "" } +// Return-usage labels — how a call site consumes the callee's return +// value. Stamped by the language extractors as Meta[MetaReturnUsage] +// on EdgeCalls edges at creation time, so the classification persists +// through every backend and survives reindexing. +const ( + // ReturnUsageDiscarded: the call is a bare expression statement (or + // every result is bound to a blank sink, e.g. Go `_, _ = f()`). + ReturnUsageDiscarded = "discarded" + // ReturnUsageAssigned: the result is bound to variables or fields. + ReturnUsageAssigned = "assigned" + // ReturnUsagePartiallyIgnored: a multi-result call where some (but + // not all) results are bound to blank sinks — Go `v, _ := f()`. + ReturnUsagePartiallyIgnored = "partially_ignored" + // ReturnUsageReturned: the call sits inside a return statement (or + // the implicit-return tail position of an expression-bodied lambda + // / Rust function / Ruby method). + ReturnUsageReturned = "returned" + // ReturnUsageGoroutine: the call is launched via a Go `go` + // statement — the return value is unobservable. + ReturnUsageGoroutine = "goroutine" + // ReturnUsageDeferred: the call is a Go `defer` statement. + ReturnUsageDeferred = "deferred" + // ReturnUsageArgument: the result feeds another call — either as a + // literal argument or as the receiver of a chained call. + ReturnUsageArgument = "argument" + // ReturnUsageCondition: the call sits inside an if / for / while / + // switch condition. + ReturnUsageCondition = "condition" +) + +// MetaReturnUsage is the edge Meta key carrying the return-usage label +// of a call site. Single source of truth for the extractors that stamp +// it and the read paths (find_usages, verify_change) that surface it. +const MetaReturnUsage = "return_usage" + +// ReturnUsageOf returns the extractor-stamped return-usage label of a +// call edge, or "" when the edge carries none (non-call edges, call +// shapes the classifier could not place). +func ReturnUsageOf(e *Edge) string { + if e == nil || e.Meta == nil { + return "" + } + s, _ := e.Meta[MetaReturnUsage].(string) + return s +} + // ConfidenceLabelFor returns EXTRACTED, INFERRED, or AMBIGUOUS for an edge // based on its kind and confidence value. // diff --git a/internal/mcp/gcx.go b/internal/mcp/gcx.go index 5e7de6cb..92961b2b 100644 --- a/internal/mcp/gcx.go +++ b/internal/mcp/gcx.go @@ -369,7 +369,7 @@ func encodeFindUsages(sg *query.SubGraph) ([]byte, error) { meta := []string{"edges", fmt.Sprintf("%d", len(sg.Edges))} meta = append(meta, zeroEdgeCaveatMeta(sg.Caveat)...) enc := newGCX(&buf, "find_usages", - []string{"from", "to", "edge_kind", "context", "origin", "tier", "confidence", "from_name", "from_path", "from_line", "from_is_test", "from_test_role", "from_test_runner"}, + []string{"from", "to", "edge_kind", "context", "return_usage", "origin", "tier", "confidence", "from_name", "from_path", "from_line", "from_is_test", "from_test_role", "from_test_runner"}, meta..., ) nodeIdx := indexNodes(sg.Nodes) @@ -398,7 +398,7 @@ func encodeFindUsages(sg *query.SubGraph) ([]byte, error) { tier = graph.ResolvedBy(e.Origin) } if err := enc.WriteRow( - e.From, e.To, string(e.Kind), e.Context, e.Origin, tier, e.Confidence, + e.From, e.To, string(e.Kind), e.Context, e.ReturnUsage, e.Origin, tier, e.Confidence, fname, fpath, fline, nodeIsTest(fn), nodeTestRole(fn), nodeTestRunner(fn), ); err != nil { return nil, err diff --git a/internal/mcp/return_usage_test.go b/internal/mcp/return_usage_test.go new file mode 100644 index 00000000..36cabf28 --- /dev/null +++ b/internal/mcp/return_usage_test.go @@ -0,0 +1,249 @@ +package mcp + +import ( + "context" + "encoding/json" + "strings" + "testing" + + wire "github.com/gortexhq/gcx-go" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/parser/languages" + "github.com/zzet/gortex/internal/query" + "github.com/zzet/gortex/internal/search" +) + +// returnUsageServer builds a server whose graph has a function `Fetch` +// called from four sites: one discarding the result, one assigning it, +// one returning it, and one the classifier left unstamped. +func returnUsageServer(t *testing.T) (*Server, string) { + t.Helper() + g := graph.New() + fetch := &graph.Node{ + ID: "pkg/fetch.go::Fetch", Kind: graph.KindFunction, Name: "Fetch", + FilePath: "pkg/fetch.go", StartLine: 3, + Meta: map[string]any{"signature": "func() (int, error)"}, + } + callers := []*graph.Node{ + {ID: "pkg/a.go::drop", Kind: graph.KindFunction, Name: "drop", FilePath: "pkg/a.go", StartLine: 2}, + {ID: "pkg/a.go::keep", Kind: graph.KindFunction, Name: "keep", FilePath: "pkg/a.go", StartLine: 8}, + {ID: "pkg/b.go::relay", Kind: graph.KindFunction, Name: "relay", FilePath: "pkg/b.go", StartLine: 4}, + {ID: "pkg/b.go::opaque", Kind: graph.KindFunction, Name: "opaque", FilePath: "pkg/b.go", StartLine: 12}, + } + g.AddNode(fetch) + for _, n := range callers { + g.AddNode(n) + } + g.AddEdge(&graph.Edge{ + From: "pkg/a.go::drop", To: fetch.ID, Kind: graph.EdgeCalls, + FilePath: "pkg/a.go", Line: 3, + Meta: map[string]any{graph.MetaReturnUsage: graph.ReturnUsageDiscarded}, + }) + g.AddEdge(&graph.Edge{ + From: "pkg/a.go::keep", To: fetch.ID, Kind: graph.EdgeCalls, + FilePath: "pkg/a.go", Line: 9, + Meta: map[string]any{graph.MetaReturnUsage: graph.ReturnUsageAssigned}, + }) + g.AddEdge(&graph.Edge{ + From: "pkg/b.go::relay", To: fetch.ID, Kind: graph.EdgeCalls, + FilePath: "pkg/b.go", Line: 5, + Meta: map[string]any{graph.MetaReturnUsage: graph.ReturnUsageReturned}, + }) + g.AddEdge(&graph.Edge{ + From: "pkg/b.go::opaque", To: fetch.ID, Kind: graph.EdgeCalls, + FilePath: "pkg/b.go", Line: 13, + }) + + eng := query.NewEngine(g) + eng.SetSearch(search.NewBM25()) + return NewServer(eng, g, nil, nil, zap.NewNop(), nil), fetch.ID +} + +func findUsagesEdges(t *testing.T, srv *Server, args map[string]any) []map[string]any { + t.Helper() + req := mcplib.CallToolRequest{} + req.Params.Name = "find_usages" + req.Params.Arguments = args + res, err := srv.handleFindUsages(context.Background(), req) + require.NoError(t, err) + require.False(t, res.IsError) + var resp struct { + Edges []map[string]any `json:"edges"` + } + require.NoError(t, json.Unmarshal([]byte(res.Content[0].(mcplib.TextContent).Text), &resp)) + return resp.Edges +} + +func TestFindUsages_ReturnUsageLabels(t *testing.T) { + srv, id := returnUsageServer(t) + edges := findUsagesEdges(t, srv, map[string]any{"id": id}) + require.Len(t, edges, 4) + + byFrom := map[string]string{} + for _, e := range edges { + usage, _ := e["return_usage"].(string) + byFrom[e["from"].(string)] = usage + } + assert.Equal(t, graph.ReturnUsageDiscarded, byFrom["pkg/a.go::drop"]) + assert.Equal(t, graph.ReturnUsageAssigned, byFrom["pkg/a.go::keep"]) + assert.Equal(t, graph.ReturnUsageReturned, byFrom["pkg/b.go::relay"]) + assert.Empty(t, byFrom["pkg/b.go::opaque"], "unstamped edge carries no label") +} + +func TestFindUsages_ReturnUsageFilter(t *testing.T) { + srv, id := returnUsageServer(t) + edges := findUsagesEdges(t, srv, map[string]any{"id": id, "return_usage": "discarded"}) + require.Len(t, edges, 1) + assert.Equal(t, "pkg/a.go::drop", edges[0]["from"]) + assert.Equal(t, graph.ReturnUsageDiscarded, edges[0]["return_usage"]) +} + +func TestFindUsages_ReturnUsageGroupedByFile(t *testing.T) { + srv, id := returnUsageServer(t) + groups := findUsagesGroups(t, srv, map[string]any{"id": id}) + found := map[string]bool{} + for _, g := range groups { + for _, u := range g.(map[string]any)["uses"].([]any) { + if usage, ok := u.(map[string]any)["return_usage"].(string); ok { + found[usage] = true + } + } + } + assert.True(t, found[graph.ReturnUsageDiscarded]) + assert.True(t, found[graph.ReturnUsageAssigned]) + assert.True(t, found[graph.ReturnUsageReturned]) +} + +func TestFindUsages_ReturnUsageGCXColumn(t *testing.T) { + srv, id := returnUsageServer(t) + req := mcplib.CallToolRequest{} + req.Params.Name = "find_usages" + req.Params.Arguments = map[string]any{"id": id, "format": "gcx"} + res, err := srv.handleFindUsages(context.Background(), req) + require.NoError(t, err) + require.False(t, res.IsError) + + payload := res.Content[0].(mcplib.TextContent).Text + dec := wire.NewDecoder(strings.NewReader(payload)) + h, err := dec.Header() + require.NoError(t, err) + require.Contains(t, h.Fields, "return_usage") + rows, err := dec.All() + require.NoError(t, err) + require.Len(t, rows, 4) + usages := map[string]string{} + for _, r := range rows { + usages[r["from"]] = r["return_usage"] + } + assert.Equal(t, graph.ReturnUsageDiscarded, usages["pkg/a.go::drop"]) + assert.Equal(t, graph.ReturnUsageReturned, usages["pkg/b.go::relay"]) +} + +func TestVerifyChange_ReturnUsageDistribution(t *testing.T) { + srv, id := returnUsageServer(t) + req := mcplib.CallToolRequest{} + req.Params.Name = "verify_change" + req.Params.Arguments = map[string]any{ + "changes": `[{"symbol_id":"` + id + `","new_signature":"func() (string, error)"}]`, + } + res, err := srv.handleVerifyChange(context.Background(), req) + require.NoError(t, err) + require.False(t, res.IsError) + + var resp struct { + ReturnUsage []struct { + SymbolID string `json:"symbol_id"` + CallSites int `json:"call_sites"` + Counts map[string]int `json:"counts"` + Unclassified int `json:"unclassified"` + } `json:"return_usage"` + } + require.NoError(t, json.Unmarshal([]byte(res.Content[0].(mcplib.TextContent).Text), &resp)) + require.Len(t, resp.ReturnUsage, 1) + ru := resp.ReturnUsage[0] + assert.Equal(t, id, ru.SymbolID) + assert.Equal(t, 4, ru.CallSites) + assert.Equal(t, 1, ru.Counts[graph.ReturnUsageDiscarded]) + assert.Equal(t, 1, ru.Counts[graph.ReturnUsageAssigned]) + assert.Equal(t, 1, ru.Counts[graph.ReturnUsageReturned]) + assert.Equal(t, 1, ru.Unclassified) +} + +func TestVerifyChange_ReturnUsageCompactLine(t *testing.T) { + srv, id := returnUsageServer(t) + req := mcplib.CallToolRequest{} + req.Params.Name = "verify_change" + req.Params.Arguments = map[string]any{ + "changes": `[{"symbol_id":"` + id + `","new_signature":"func() (string, error)"}]`, + "compact": true, + } + res, err := srv.handleVerifyChange(context.Background(), req) + require.NoError(t, err) + require.False(t, res.IsError) + + text := res.Content[0].(mcplib.TextContent).Text + assert.Contains(t, text, + "return_usage "+id+" call_sites:4 assigned:1 discarded:1 returned:1 unclassified:1") +} + +// TestFindUsages_ReturnUsageEndToEnd drives the full chain: the Go +// extractor classifies real call sites, the edges land in a graph (with +// the unresolved targets bound to the callee the way the resolver +// does), and find_usages surfaces each site's label. +func TestFindUsages_ReturnUsageEndToEnd(t *testing.T) { + src := []byte(`package main + +func helper() int { + return 1 +} + +func drop() { + helper() +} + +func keep() { + v := helper() + _ = v +} + +func relay() int { + return helper() +} +`) + result, err := languages.NewGoExtractor().Extract("main.go", src) + require.NoError(t, err) + defer result.Tree.Release() + + g := graph.New() + for _, n := range result.Nodes { + g.AddNode(n) + } + for _, e := range result.Edges { + // Bind the extractor's unresolved target onto the local + // definition — the same join the resolver performs. + if e.To == "unresolved::helper" { + e.To = "main.go::helper" + } + g.AddEdge(e) + } + + eng := query.NewEngine(g) + eng.SetSearch(search.NewBM25()) + srv := NewServer(eng, g, nil, nil, zap.NewNop(), nil) + + edges := findUsagesEdges(t, srv, map[string]any{"id": "main.go::helper"}) + require.NotEmpty(t, edges) + byFrom := map[string]string{} + for _, e := range edges { + usage, _ := e["return_usage"].(string) + byFrom[e["from"].(string)] = usage + } + assert.Equal(t, graph.ReturnUsageDiscarded, byFrom["main.go::drop"]) + assert.Equal(t, graph.ReturnUsageAssigned, byFrom["main.go::keep"]) + assert.Equal(t, graph.ReturnUsageReturned, byFrom["main.go::relay"]) +} diff --git a/internal/mcp/tools_core.go b/internal/mcp/tools_core.go index f5ab5751..5625be76 100644 --- a/internal/mcp/tools_core.go +++ b/internal/mcp/tools_core.go @@ -979,6 +979,7 @@ func (s *Server) registerCoreTools() { mcp.WithBoolean("exclude_tests", mcp.Description("Drop references originating in test functions (set true to see only production usages)")), mcp.WithString("group_by", mcp.Description("Set to \"file\" to bucket the usages by the file each reference originates in -- each group carries the per-file use count and the enclosing symbol of every reference. Omit for the default flat result.")), mcp.WithString("context", mcp.Description("Filter usages by their reference context — the role the symbol plays at each site: parameter_type, return_type, field, value, type, attribute, generic_arg, or call. Every returned usage also carries its classified context. Omit for all usages.")), + mcp.WithString("return_usage", mcp.Description("Filter call-site usages by how they consume the callee's return value: discarded, assigned, partially_ignored, returned, goroutine, deferred, argument, or condition. Every returned call usage also carries its classification when the extractor recorded one. Use before changing a function's return signature to see who actually uses the return. Omit for all usages.")), ), s.handleFindUsages, ) @@ -2212,6 +2213,10 @@ func (s *Server) handleFindUsages(ctx context.Context, req mcp.CallToolRequest) // / field / value / type / attribute / call) and optionally filter to // one context — `find_usages context:"parameter_type"`. annotateAndFilterUsageContext(sg, strings.ToLower(strings.TrimSpace(req.GetString("context", "")))) + // Surface the extractor-stamped return-usage classification on each + // call usage and optionally filter to one consumption shape — + // `find_usages return_usage:"discarded"`. + annotateAndFilterReturnUsage(sg, strings.ToLower(strings.TrimSpace(req.GetString("return_usage", "")))) if len(sg.Edges) == 0 { sg.Caveat = graph.CaveatForZeroEdge(s.graph, id) } @@ -2259,6 +2264,31 @@ func annotateAndFilterUsageContext(sg *query.SubGraph, contextFilter string) { sg.Edges = kept } +// annotateAndFilterReturnUsage copies the extractor-stamped +// return-usage label (Meta[MetaReturnUsage]) onto each usage edge's +// JSON-visible ReturnUsage field and, when usageFilter is non-empty, +// drops every usage whose label does not match — the engine behind +// `find_usages return_usage:"discarded"`. Edges without a label (non- +// call edges, unclassifiable sites) never match a filter. +func annotateAndFilterReturnUsage(sg *query.SubGraph, usageFilter string) { + if sg == nil { + return + } + for _, e := range sg.Edges { + e.ReturnUsage = graph.ReturnUsageOf(e) + } + if usageFilter == "" { + return + } + kept := sg.Edges[:0] + for _, e := range sg.Edges { + if e.ReturnUsage == usageFilter { + kept = append(kept, e) + } + } + sg.Edges = kept +} + // usageFileGroup is one file's worth of references from a // group_by:"file" find_usages response. type usageFileGroup struct { @@ -2270,11 +2300,12 @@ type usageFileGroup struct { // usageGroupItem is one reference inside a usageFileGroup -- the // line it sits on plus the enclosing symbol. type usageGroupItem struct { - Line int `json:"line"` - EdgeKind string `json:"edge_kind"` - Context string `json:"context,omitempty"` - SymbolID string `json:"symbol_id,omitempty"` - SymbolName string `json:"symbol_name,omitempty"` + Line int `json:"line"` + EdgeKind string `json:"edge_kind"` + Context string `json:"context,omitempty"` + ReturnUsage string `json:"return_usage,omitempty"` + SymbolID string `json:"symbol_id,omitempty"` + SymbolName string `json:"symbol_name,omitempty"` } // groupUsagesByFile buckets a find_usages SubGraph by the file each @@ -2302,7 +2333,7 @@ func groupUsagesByFile(sg *query.SubGraph) map[string]any { g = &usageFileGroup{File: file} groups[file] = g } - item := usageGroupItem{Line: e.Line, EdgeKind: string(e.Kind), Context: e.Context} + item := usageGroupItem{Line: e.Line, EdgeKind: string(e.Kind), Context: e.Context, ReturnUsage: e.ReturnUsage} if from != nil { item.SymbolID = from.ID item.SymbolName = from.Name diff --git a/internal/mcp/tools_enhancements.go b/internal/mcp/tools_enhancements.go index 1b8be072..8b07c480 100644 --- a/internal/mcp/tools_enhancements.go +++ b/internal/mcp/tools_enhancements.go @@ -377,6 +377,24 @@ func (s *Server) handleVerifyChange(ctx context.Context, req mcp.CallToolRequest for _, v := range result.Violations { fmt.Fprintf(&b, "%s %s %s:%d %s\n", v.Kind, v.SymbolID, v.FilePath, v.Line, v.Description) } + // One line per changed function: how its call sites consume the + // return value, so a return-signature change shows exactly which + // sites bind / return / branch on the result. + for _, ru := range result.ReturnUsage { + fmt.Fprintf(&b, "return_usage %s call_sites:%d", ru.SymbolID, ru.CallSites) + labels := make([]string, 0, len(ru.Counts)) + for label := range ru.Counts { + labels = append(labels, label) + } + sort.Strings(labels) + for _, label := range labels { + fmt.Fprintf(&b, " %s:%d", label, ru.Counts[label]) + } + if ru.Unclassified > 0 { + fmt.Fprintf(&b, " unclassified:%d", ru.Unclassified) + } + b.WriteString("\n") + } if result.Clean { fmt.Fprintf(&b, "clean: checked %d callers, %d implementors\n", result.CheckedCallers, result.CheckedImpls) } diff --git a/internal/parser/languages/csharp.go b/internal/parser/languages/csharp.go index d8188fd5..4a435869 100644 --- a/internal/parser/languages/csharp.go +++ b/internal/parser/languages/csharp.go @@ -106,6 +106,10 @@ type csharpDeferredCall struct { receiver string line int isMember bool + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on the EdgeCalls emitted for this site. + returnUsage string } // csharpDeferredLocal buffers a local variable declaration for the @@ -190,17 +194,19 @@ func (e *CSharpExtractor) Extract(filePath string, src []byte) (*parser.Extracti case m.Captures["callm.expr"] != nil: expr := m.Captures["callm.expr"] calls = append(calls, csharpDeferredCall{ - name: m.Captures["callm.method"].Text, - receiver: m.Captures["callm.receiver"].Text, - line: expr.StartLine + 1, - isMember: true, + name: m.Captures["callm.method"].Text, + receiver: m.Captures["callm.receiver"].Text, + line: expr.StartLine + 1, + isMember: true, + returnUsage: classifyReturnUsage(expr.Node, src, csharpReturnUsageSpec), }) case m.Captures["call.expr"] != nil: expr := m.Captures["call.expr"] calls = append(calls, csharpDeferredCall{ - name: m.Captures["call.name"].Text, - line: expr.StartLine + 1, + name: m.Captures["call.name"].Text, + line: expr.StartLine + 1, + returnUsage: classifyReturnUsage(expr.Node, src, csharpReturnUsageSpec), }) case m.Captures["lvar.def"] != nil: @@ -274,13 +280,16 @@ func (e *CSharpExtractor) Extract(filePath string, src []byte) (*parser.Extracti edge.Meta = map[string]any{"receiver_type": chainType} } } + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } // .NET surfaces a symbol walk misses: DI registrations + COM diff --git a/internal/parser/languages/csharp_test.go b/internal/parser/languages/csharp_test.go index 38e337bf..5f826159 100644 --- a/internal/parser/languages/csharp_test.go +++ b/internal/parser/languages/csharp_test.go @@ -305,7 +305,7 @@ func TestCSharpExtractor_TypeEnv_Unknown(t *testing.T) { } } require.NotNil(t, processCall) - assert.Nil(t, processCall.Meta, "unknown type should not produce Meta") + assert.NotContains(t, processCall.Meta, "receiver_type", "unknown type should not produce a receiver_type hint") } func TestCSharpExtractor_TypeEnv_Chain(t *testing.T) { diff --git a/internal/parser/languages/golang.go b/internal/parser/languages/golang.go index a19020a9..97a1a00a 100644 --- a/internal/parser/languages/golang.go +++ b/internal/parser/languages/golang.go @@ -178,6 +178,11 @@ type goDeferredCall struct { line int // 1-based line of call_expression isSelector bool spawn bool // call is launched via `go` — emit EdgeSpawns alongside EdgeCalls + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified from the call node's + // parent chain at capture time and stamped as edge Meta on every + // EdgeCalls emitted for this site. Empty when unclassifiable. + returnUsage string // gRPC server registration. Set when this call is the generated // `RegisterServer(registrar, impl)` helper: grpcRegService // is the service name and grpcRegArgNode is the second-argument AST @@ -353,9 +358,10 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe expr := m.Captures["call.expr"] callName := m.Captures["call.name"].Text dc := goDeferredCall{ - callName: callName, - line: expr.StartLine + 1, - spawn: isGoroutineSpawn(expr.Node), + callName: callName, + line: expr.StartLine + 1, + spawn: isGoroutineSpawn(expr.Node), + returnUsage: classifyReturnUsage(expr.Node, src, goReturnUsageSpec), } if svc, argNode, ok := grpcRegisterArgNode(expr.Node, callName); ok { dc.grpcRegService, dc.grpcRegArgNode = svc, argNode @@ -367,11 +373,12 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe method := m.Captures["callm.method"].Text receiver := m.Captures["callm.receiver"].Text dc := goDeferredCall{ - method: method, - receiver: receiver, - line: expr.StartLine + 1, - isSelector: true, - spawn: isGoroutineSpawn(expr.Node), + method: method, + receiver: receiver, + line: expr.StartLine + 1, + isSelector: true, + spawn: isGoroutineSpawn(expr.Node), + returnUsage: classifyReturnUsage(expr.Node, src, goReturnUsageSpec), } if svc, argNode, ok := grpcRegisterArgNode(expr.Node, method); ok { dc.grpcRegService, dc.grpcRegArgNode = svc, argNode @@ -717,7 +724,7 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe if c.isSelector { if svc, ok := goGRPCStubService(c, grpcStubVars); ok { target := "unresolved::grpc::" + svc + "::" + c.method - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: target, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{ @@ -725,7 +732,9 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe "grpc_service": svc, "grpc_method": c.method, }, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) emitGoSpawnEdge(c, callerID, target, filePath, result) continue } @@ -767,11 +776,13 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe if c.tempEnvDefault { meta["temporal_name_origin"] = "env_default" } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: target, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: meta, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) emitGoSpawnEdge(c, callerID, target, filePath, result) continue } @@ -788,6 +799,7 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe applyGoTemporalHandlerMeta(edge, c) applyGoTemporalSignalQueryMeta(edge, c) applyGoTemporalStartMeta(edge, c) + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) emitGoSpawnEdge(c, callerID, target, filePath, result) continue @@ -803,6 +815,7 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe applyGoTemporalHandlerMeta(edge, c) applyGoTemporalSignalQueryMeta(edge, c) applyGoTemporalStartMeta(edge, c) + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) emitGoSpawnEdge(c, callerID, target, filePath, result) continue @@ -854,6 +867,7 @@ func (e *GoExtractor) Extract(filePath string, src []byte) (*parser.ExtractionRe applyGoTemporalRegisterMeta(edge, c) applyGoTemporalSignalQueryMeta(edge, c) applyGoTemporalStartMeta(edge, c) + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) emitGoSpawnEdge(c, callerID, target, filePath, result) } diff --git a/internal/parser/languages/golang_test.go b/internal/parser/languages/golang_test.go index 418c647a..618b30fd 100644 --- a/internal/parser/languages/golang_test.go +++ b/internal/parser/languages/golang_test.go @@ -403,7 +403,7 @@ func main() { } } require.NotNil(t, saveCall, "expected a call edge to Save") - assert.Nil(t, saveCall.Meta, "unknown type should not produce Meta") + assert.NotContains(t, saveCall.Meta, "receiver_type", "unknown type should not produce a receiver_type hint") } // --- Tier 2: Chain resolution tests --- @@ -474,7 +474,7 @@ func main() { } } require.NotNil(t, finishCall) - assert.Nil(t, finishCall.Meta, "unresolvable chain should not produce Meta") + assert.NotContains(t, finishCall.Meta, "receiver_type", "unresolvable chain should not produce a receiver_type hint") } func TestGoExtractor_ReturnType(t *testing.T) { diff --git a/internal/parser/languages/java.go b/internal/parser/languages/java.go index 1d1d711a..70d6866f 100644 --- a/internal/parser/languages/java.go +++ b/internal/parser/languages/java.go @@ -98,6 +98,10 @@ type javaDeferredCall struct { // "signal"/"query" method names from false-matching. tempSignalKind string tempSignalName string + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on the EdgeCalls emitted for this site. + returnUsage string } // javaDeferredVar buffers a variable declaration for the post-pass @@ -186,10 +190,11 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction expr := m.Captures["callm.expr"] method := m.Captures["callm.method"].Text dc := javaDeferredCall{ - name: method, - receiver: m.Captures["callm.receiver"].Text, - line: expr.StartLine + 1, - isSelector: true, + name: method, + receiver: m.Captures["callm.receiver"].Text, + line: expr.StartLine + 1, + isSelector: true, + returnUsage: classifyReturnUsage(expr.Node, src, javaReturnUsageSpec), } if wf := javaTemporalStartWorkflowName(expr.Node, method, src); wf != "" { dc.tempStartWorkflow = wf @@ -205,8 +210,9 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction // edges, so we mirror that here. expr := m.Captures["call.expr"] calls = append(calls, javaDeferredCall{ - name: m.Captures["call.name"].Text, - line: expr.StartLine + 1, + name: m.Captures["call.name"].Text, + line: expr.StartLine + 1, + returnUsage: classifyReturnUsage(expr.Node, src, javaReturnUsageSpec), }) } }) @@ -288,6 +294,7 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction } } } + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) // Temporal workflow START (consumer side): emit a via=temporal.start @@ -295,7 +302,7 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction // it to the registered workflow — which may be implemented in a Go // repo — so get_callers on that workflow surfaces this Java service. if c.tempStartWorkflow != "" { - result.Edges = append(result.Edges, &graph.Edge{ + startEdge := &graph.Edge{ From: callerID, To: "unresolved::temporal::workflow::" + c.tempStartWorkflow, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{ @@ -303,7 +310,9 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction "temporal_kind": "workflow", "temporal_name": c.tempStartWorkflow, }, - }) + } + stampReturnUsage(startEdge, c.returnUsage) + result.Edges = append(result.Edges, startEdge) } // Outbound signal-send / query-call on an untyped WorkflowStub, // symmetric with the Go side (#81). Gated on the receiver's inferred @@ -314,7 +323,7 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction if c.tempSignalKind == "query" { via = "temporal.query-call" } - result.Edges = append(result.Edges, &graph.Edge{ + signalEdge := &graph.Edge{ From: callerID, To: "unresolved::*." + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{ @@ -322,7 +331,9 @@ func (e *JavaExtractor) Extract(filePath string, src []byte) (*parser.Extraction "temporal_kind": c.tempSignalKind, "temporal_name": c.tempSignalName, }, - }) + } + stampReturnUsage(signalEdge, c.returnUsage) + result.Edges = append(result.Edges, signalEdge) } } diff --git a/internal/parser/languages/javascript.go b/internal/parser/languages/javascript.go index 9ba81b39..bcea7578 100644 --- a/internal/parser/languages/javascript.go +++ b/internal/parser/languages/javascript.go @@ -92,6 +92,10 @@ type jsDeferredCall struct { // expr is the call_expression node, kept for member calls so the // post-pass can inspect arguments for pub/sub topic detection. expr *sitter.Node + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on every EdgeCalls emitted for this site. + returnUsage string } type jsDeferredVar struct { @@ -183,10 +187,11 @@ func (e *JavaScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr case m.Captures["callm.expr"] != nil: expr := m.Captures["callm.expr"] dc := jsDeferredCall{ - name: m.Captures["callm.method"].Text, - line: expr.StartLine + 1, - isMember: true, - expr: expr.Node, + name: m.Captures["callm.method"].Text, + line: expr.StartLine + 1, + isMember: true, + expr: expr.Node, + returnUsage: classifyReturnUsage(expr.Node, src, jsTSReturnUsageSpec), } if r := m.Captures["callm.receiver"]; r != nil { dc.receiver = r.Text @@ -196,8 +201,9 @@ func (e *JavaScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr case m.Captures["call.expr"] != nil: expr := m.Captures["call.expr"] calls = append(calls, jsDeferredCall{ - name: m.Captures["call.name"].Text, - line: expr.StartLine + 1, + name: m.Captures["call.name"].Text, + line: expr.StartLine + 1, + returnUsage: classifyReturnUsage(expr.Node, src, jsTSReturnUsageSpec), }) case m.Captures["var.def"] != nil: @@ -276,58 +282,72 @@ func (e *JavaScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr // package can't capture this edge in the name-only fallback. if members, ok := objLiteralMembers[c.receiver]; ok { if memberID, ok := members[c.name]; ok { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: memberID, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Origin: graph.OriginASTResolved, Confidence: 0.92, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } } // Store-factory chained call: `useStore.getState().action()`. if binding, ok := jsParseGetStateChain(c.receiver); ok { if memberID := objLiteralMembers[binding][c.name]; memberID != "" { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: memberID, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Origin: graph.OriginASTResolved, Confidence: 0.9, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::*." + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{"via": "store-factory", "store_binding": binding, "store_action": c.name}, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::*." + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } // Store-factory destructured call: `const {a}=useStore.getState(); a()`. if binding, ok := destructured[c.name]; ok { if memberID := objLiteralMembers[binding][c.name]; memberID != "" { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: memberID, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Origin: graph.OriginASTResolved, Confidence: 0.9, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{"via": "store-factory", "store_binding": binding, "store_action": c.name}, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } // --- Event pub/sub edges --- @@ -365,11 +385,13 @@ func (e *JavaScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr if callerID == "" { continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: rnNativePlaceholder(module, c.name), Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{"via": rnNativeVia, "rn_module": module, "rn_method": c.name}, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } // Test-runner classification (Mocha / Bun-test / Jest / Vitest / diff --git a/internal/parser/languages/python.go b/internal/parser/languages/python.go index 77c2a0e6..9c4144ba 100644 --- a/internal/parser/languages/python.go +++ b/internal/parser/languages/python.go @@ -96,6 +96,10 @@ type pyDeferredCall struct { // the literal method name when the subscript is a string literal. dynShape string dynKey string + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on every EdgeCalls emitted for this site. + returnUsage string } // pyStringLiteralValue returns the unquoted value of a Python string-literal @@ -169,19 +173,21 @@ func (e *PythonExtractor) Extract(filePath string, src []byte) (*parser.Extracti case m.Captures["callattr.expr"] != nil: expr := m.Captures["callattr.expr"] calls = append(calls, pyDeferredCall{ - name: m.Captures["callattr.method"].Text, - receiver: m.Captures["callattr.receiver"].Text, - line: expr.StartLine + 1, - isAttr: true, - expr: expr.Node, + name: m.Captures["callattr.method"].Text, + receiver: m.Captures["callattr.receiver"].Text, + line: expr.StartLine + 1, + isAttr: true, + expr: expr.Node, + returnUsage: classifyReturnUsage(expr.Node, src, pyReturnUsageSpec), }) case m.Captures["call.expr"] != nil: expr := m.Captures["call.expr"] calls = append(calls, pyDeferredCall{ - name: m.Captures["call.name"].Text, - line: expr.StartLine + 1, - expr: expr.Node, + name: m.Captures["call.name"].Text, + line: expr.StartLine + 1, + expr: expr.Node, + returnUsage: classifyReturnUsage(expr.Node, src, pyReturnUsageSpec), }) case m.Captures["subcall.expr"] != nil: @@ -190,9 +196,10 @@ func (e *PythonExtractor) Extract(filePath string, src []byte) (*parser.Extracti // edge by itself unless that pass is enabled. expr := m.Captures["subcall.expr"] dc := pyDeferredCall{ - line: expr.StartLine + 1, - expr: expr.Node, - dynShape: "computed_member", + line: expr.StartLine + 1, + expr: expr.Node, + dynShape: "computed_member", + returnUsage: classifyReturnUsage(expr.Node, src, pyReturnUsageSpec), } if r := m.Captures["subcall.receiver"]; r != nil { dc.receiver = r.Text @@ -251,20 +258,24 @@ func (e *PythonExtractor) Extract(filePath string, src []byte) (*parser.Extracti if c.receiver != "" { meta["dyn_receiver"] = c.receiver } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::*", Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: meta, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } if c.isAttr { // Module-qualified call (requests.get, np.array, os.path.join): // attach the import path so resolver can classify externally. if importPath, ok := lookupPyImport(c.receiver, imports); ok { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::extern::" + importPath + "::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } @@ -279,6 +290,7 @@ func (e *PythonExtractor) Extract(filePath string, src []byte) (*parser.Extracti edge.Meta = map[string]any{"receiver_type": chainType} } } + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) continue } @@ -291,15 +303,19 @@ func (e *PythonExtractor) Extract(filePath string, src []byte) (*parser.Extracti // `import numpy as np; np.array(...)` both attribute to // numpy after the resolver post-pass. if importPath, ok := lookupPyImport(c.name, imports); ok { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::extern::" + importPath + "::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } else { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } // FastAPI dependency injection: Depends(target) — emit a direct diff --git a/internal/parser/languages/python_test.go b/internal/parser/languages/python_test.go index 6c1d663c..b8b89025 100644 --- a/internal/parser/languages/python_test.go +++ b/internal/parser/languages/python_test.go @@ -177,7 +177,7 @@ def main(): } } require.NotNil(t, processCall) - assert.Nil(t, processCall.Meta, "unknown type should not produce Meta") + assert.NotContains(t, processCall.Meta, "receiver_type", "unknown type should not produce a receiver_type hint") } func TestPythonExtractor_FastAPIDepends(t *testing.T) { diff --git a/internal/parser/languages/return_usage.go b/internal/parser/languages/return_usage.go new file mode 100644 index 00000000..106b9b21 --- /dev/null +++ b/internal/parser/languages/return_usage.go @@ -0,0 +1,558 @@ +package languages + +import ( + "strings" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" +) + +// Return-usage classification: given a call node, walk its parent chain +// and decide how the call site consumes the callee's return value +// (graph.ReturnUsage* labels). One engine serves every language; the +// per-language differences — which node kinds mean "assignment", +// "return", "argument list", "condition", and which kinds the walk +// passes through — live in a returnUsageSpec table per grammar. +// +// The walk is conservative: a parent kind no table covers stops the +// classification and the call edge carries no label, rather than a +// wrong one. Two deliberate label foldings: a call whose result is the +// receiver of a chained call (`f().g()`) is classified as argument +// (the value feeds another call, same as a literal argument), and an +// assignment whose every sink is blank (`_, _ = f()`) is classified as +// discarded — partially_ignored is reserved for the mixed case. + +// returnUsageSpec drives classifyReturnUsage for one language. All kind +// sets refer to tree-sitter node kinds seen on the call's parent chain. +type returnUsageSpec struct { + // goroutine / deferred: statement kinds that launch or defer the + // call (Go-only — other languages leave these nil and never produce + // the corresponding labels). + goroutine map[string]bool + deferred map[string]bool + // returns: return statements plus expression-bodied callable kinds + // (arrow functions, lambdas, Rust fn tails) whose value becomes the + // enclosing callable's return value. + returns map[string]bool + // assign: assignment-like kind → field name of its sink side. An + // empty field name means "no countable sinks" (e.g. a Go channel + // send) and classifies directly as assigned. + assign map[string]string + // sinkLists: kinds whose children are individual binding targets + // (Go expression_list, Python pattern_list, Rust tuple_pattern, …). + sinkLists map[string]bool + // argument: argument-list kinds. argOwners guards grammars (Ruby) + // that reuse the list kind under non-call constructs — when the + // list's parent is not an owner the walk continues instead. + argument map[string]bool + argOwners map[string]bool + // condition: conditional-statement kind → field names whose child + // means "inside the condition". An empty field list matches any + // direct child that is not a block (Go's while-style for statement + // carries its condition without a field name). + condition map[string][]string + // conditionExcludeKinds names child node kinds the empty-field + // positional condition fallback must reject in addition to blocks — + // the branch/arm containers a conditional shares its child list with + // (C# switch_expression mixes its governing expression and its + // fieldless switch_expression_arm children at the same level). Only + // the governing expression is the condition; the arms are not. + conditionExcludeKinds map[string]bool + // chain: member-access kind → field name of the receiver/object + // position. A call in that position feeds another expression and + // classifies as argument; any other position continues the walk. + chain map[string]string + // discard: bare expression-statement kinds. + discard map[string]bool + // body: statement-container kinds for grammars without an + // expression-statement wrapper (Ruby). A call in tail position + // continues the walk (implicit-return languages route it to the + // enclosing returns kind); any other position is discarded. + body map[string]bool + // closures: closure-boundary kinds (Ruby block / do_block) that the + // walk must NOT cross. A call that reaches one is the closure's tail + // value — returned from the closure, not from the enclosing method or + // bound by the enclosing assignment. Classifying it as returned and + // stopping keeps a block-internal call from inheriting the label of + // whatever consumes the block (`x.map { f() }` must not mark f as + // assigned to the map's receiver, nor as the method's return). Other + // grammars register their closure forms (arrow_function, lambda, + // closure_expression) in returns, which already halts the walk at the + // boundary; Ruby's blocks have no such kind, so they live here. + closures map[string]bool + // transparent: kinds the walk passes straight through (parenthesised + // and binary expressions, await, casts, literal containers, …). + transparent map[string]bool + // blocks: block/body kinds — used by the empty-field condition rule + // to avoid classifying a branch body as a condition. + blocks map[string]bool +} + +// returnUsageMaxHops bounds the parent walk. Real classifications +// resolve within a handful of hops; the cap only guards degenerate +// trees (deeply nested transparent expressions). +const returnUsageMaxHops = 32 + +// classifyReturnUsage walks up from a call node and returns the +// graph.ReturnUsage* label for the call site, or "" when the parent +// chain doesn't match any classification the spec covers. +func classifyReturnUsage(call *sitter.Node, src []byte, spec *returnUsageSpec) string { + if call == nil || spec == nil { + return "" + } + child := call + parent := call.Parent() + for hops := 0; parent != nil && hops < returnUsageMaxHops; hops++ { + kind := parent.Type() + switch { + case spec.goroutine[kind]: + return graph.ReturnUsageGoroutine + case spec.deferred[kind]: + return graph.ReturnUsageDeferred + case spec.returns[kind]: + return graph.ReturnUsageReturned + case spec.closures[kind]: + // Closure boundary (Ruby block / do_block): a call reaches + // one only as the closure's tail value, so it is returned from + // the closure. Stop here — the walk must not continue into the + // statement that consumes the closure, or a block-internal + // call would inherit that statement's label. + return graph.ReturnUsageReturned + default: + } + if field, ok := spec.assign[kind]; ok { + return classifyAssign(parent, child, field, src, spec) + } + if spec.argument[kind] { + owner := parent.Parent() + if len(spec.argOwners) == 0 || (owner != nil && spec.argOwners[owner.Type()]) { + return graph.ReturnUsageArgument + } + // Argument-list kind under a non-call owner (Ruby `return + // f()`): treat as transparent and keep walking. + child, parent = parent, owner + continue + } + if fields, ok := spec.condition[kind]; ok { + if matchesConditionField(parent, child, fields, spec) { + return graph.ReturnUsageCondition + } + // We arrived from a non-condition position (a branch body, + // an init clause): keep walking — the statement's own fate + // decides (e.g. `let x = if c { f() } else …` → assigned). + child, parent = parent, parent.Parent() + continue + } + if field, ok := spec.chain[kind]; ok { + if c := parent.ChildByFieldName(field); c != nil && sameSpan(c, child) { + return graph.ReturnUsageArgument + } + child, parent = parent, parent.Parent() + continue + } + if spec.discard[kind] { + return graph.ReturnUsageDiscarded + } + if spec.body[kind] { + if isLastNamedChild(parent, child) { + // Tail position in an implicit-return language: the + // container is transparent; the enclosing callable (in + // spec.returns) or outer statement decides. + child, parent = parent, parent.Parent() + continue + } + return graph.ReturnUsageDiscarded + } + if spec.transparent[kind] { + child, parent = parent, parent.Parent() + continue + } + return "" + } + return "" +} + +// classifyAssign decides between assigned / partially_ignored / +// discarded for a call on the value side of an assignment-like node. +// A call inside the sink side itself (`m[f()] = v`) is not a binding +// of the call's result and stays unclassified. +func classifyAssign(parent, child *sitter.Node, sinkField string, src []byte, spec *returnUsageSpec) string { + if sinkField == "" { + return graph.ReturnUsageAssigned + } + total, blank := 0, 0 + for i := 0; i < int(parent.ChildCount()); i++ { + if parent.FieldNameForChild(i) != sinkField { + continue + } + c := parent.Child(i) + if c == nil { + continue + } + if sameSpan(c, child) || containsSpan(c, child) { + return "" + } + t, b := countSinkLeaves(c, src, spec) + total += t + blank += b + } + switch { + case total == 0: + // The sink field is named but absent in this tree — a binding + // form with no targets at all, e.g. Go `for range f()`. Nothing + // captures the result; it is consumed without being bound. (The + // genuinely sinkless forms that should read as assigned — a Go + // channel send — take the sinkField == "" path above and never + // reach here.) + return graph.ReturnUsageDiscarded + case blank == 0: + return graph.ReturnUsageAssigned + case blank == total: + return graph.ReturnUsageDiscarded + default: + return graph.ReturnUsagePartiallyIgnored + } +} + +// countSinkLeaves counts the individual binding targets under a sink +// node and how many of them are blank (`_`). List/tuple kinds count +// each element; anything else is a single sink. Rust's `_` wildcard is +// an unnamed node, so the element loop keeps unnamed children whose +// text is exactly "_". +func countSinkLeaves(n *sitter.Node, src []byte, spec *returnUsageSpec) (total, blank int) { + isBlank := func(c *sitter.Node) bool { + return strings.TrimSpace(c.Content(src)) == "_" + } + if spec.sinkLists[n.Type()] { + for i := 0; i < int(n.ChildCount()); i++ { + c := n.Child(i) + if c == nil { + continue + } + if !c.IsNamed() && !isBlank(c) { + continue + } + total++ + if isBlank(c) { + blank++ + } + } + return total, blank + } + if isBlank(n) { + return 1, 1 + } + return 1, 0 +} + +// matchesConditionField reports whether child sits in one of the +// statement's condition fields. An empty field list applies the +// positional fallback: any direct non-block child is the condition +// (Go's `for f() { … }` carries the condition without a field). +func matchesConditionField(parent, child *sitter.Node, fields []string, spec *returnUsageSpec) bool { + if len(fields) == 0 { + ck := child.Type() + return !spec.blocks[ck] && !spec.conditionExcludeKinds[ck] + } + for _, f := range fields { + if c := parent.ChildByFieldName(f); c != nil && sameSpan(c, child) { + return true + } + } + return false +} + +// sameSpan reports whether two nodes are the same node, compared by +// byte span (wrapper *Node pointers are not identity-stable). +func sameSpan(a, b *sitter.Node) bool { + return a.StartByte() == b.StartByte() && a.EndByte() == b.EndByte() +} + +// containsSpan reports whether outer strictly contains inner. +func containsSpan(outer, inner *sitter.Node) bool { + return outer.StartByte() <= inner.StartByte() && inner.EndByte() <= outer.EndByte() && + (outer.StartByte() != inner.StartByte() || outer.EndByte() != inner.EndByte()) +} + +// isLastNamedChild reports whether child is parent's last named child. +func isLastNamedChild(parent, child *sitter.Node) bool { + n := int(parent.NamedChildCount()) + if n == 0 { + return false + } + last := parent.NamedChild(n - 1) + return last != nil && sameSpan(last, child) +} + +// stampReturnUsage records the return-usage label on a call edge's +// Meta. No-op for an empty label so unclassifiable sites carry no key. +func stampReturnUsage(e *graph.Edge, usage string) { + if e == nil || usage == "" { + return + } + if e.Meta == nil { + e.Meta = map[string]any{} + } + e.Meta[graph.MetaReturnUsage] = usage +} + +func kindSet(kinds ...string) map[string]bool { + m := make(map[string]bool, len(kinds)) + for _, k := range kinds { + m[k] = true + } + return m +} + +// --- Per-language tables -------------------------------------------- + +var goReturnUsageSpec = &returnUsageSpec{ + goroutine: kindSet("go_statement"), + deferred: kindSet("defer_statement"), + returns: kindSet("return_statement"), + assign: map[string]string{ + "assignment_statement": "left", + "short_var_declaration": "left", + "var_spec": "name", + "range_clause": "left", + // A channel send binds the value into the channel; there is no + // countable sink list. + "send_statement": "", + }, + sinkLists: kindSet("expression_list"), + argument: kindSet("argument_list", "special_argument_list"), + argOwners: kindSet("call_expression"), + condition: map[string][]string{ + "if_statement": {"condition"}, + "for_statement": {}, // while-style condition has no field + "for_clause": {"condition"}, + "expression_switch_statement": {"value"}, + "type_switch_statement": {"value"}, + }, + chain: map[string]string{ + "selector_expression": "operand", + "call_expression": "function", // curried `f()()` + }, + discard: kindSet("expression_statement"), + transparent: kindSet( + "expression_list", "parenthesized_expression", "binary_expression", + "unary_expression", "index_expression", "slice_expression", + "type_assertion_expression", "type_conversion_expression", + "literal_element", "keyed_element", "literal_value", "composite_literal", + ), + blocks: kindSet("block"), +} + +var pyReturnUsageSpec = &returnUsageSpec{ + returns: kindSet("return_statement", "lambda"), + assign: map[string]string{ + "assignment": "left", + "augmented_assignment": "left", + }, + sinkLists: kindSet("pattern_list", "tuple_pattern"), + argument: kindSet("argument_list"), + argOwners: kindSet("call"), + condition: map[string][]string{ + "if_statement": {"condition"}, + "elif_clause": {"condition"}, + "while_statement": {"condition"}, + "for_statement": {"right"}, + }, + chain: map[string]string{ + "attribute": "object", + "call": "function", // curried `f()()` + }, + discard: kindSet("expression_statement"), + transparent: kindSet( + "await", "parenthesized_expression", "binary_operator", + "boolean_operator", "comparison_operator", "not_operator", + "unary_operator", "conditional_expression", "subscript", + "expression_list", "tuple", "list", "set", "dictionary", "pair", + ), + blocks: kindSet("block"), +} + +// jsTSReturnUsageSpec covers the javascript, typescript, and tsx +// grammars — the kinds the walk touches are shared; the TS-only +// wrapper expressions (as / satisfies / non-null) are transparent +// no-ops under the JS grammar. +var jsTSReturnUsageSpec = &returnUsageSpec{ + returns: kindSet("return_statement", "arrow_function"), + assign: map[string]string{ + "variable_declarator": "name", + "assignment_expression": "left", + "augmented_assignment_expression": "left", + "public_field_definition": "name", + }, + sinkLists: kindSet("array_pattern"), + argument: kindSet("arguments"), + argOwners: kindSet("call_expression", "new_expression"), + condition: map[string][]string{ + "if_statement": {"condition"}, + "while_statement": {"condition"}, + "do_statement": {"condition"}, + "for_statement": {"condition"}, + "for_in_statement": {"right"}, + "switch_statement": {"value"}, + }, + chain: map[string]string{ + "member_expression": "object", + "call_expression": "function", // curried `f()()` + }, + discard: kindSet("expression_statement"), + transparent: kindSet( + "parenthesized_expression", "binary_expression", "unary_expression", + "ternary_expression", "await_expression", "subscript_expression", + "template_substitution", "spread_element", "object", "pair", "array", + "sequence_expression", + // TypeScript-only wrappers; absent from the JS grammar. + "as_expression", "satisfies_expression", "non_null_expression", + "type_assertion", + ), + blocks: kindSet("statement_block"), +} + +var javaReturnUsageSpec = &returnUsageSpec{ + returns: kindSet("return_statement", "lambda_expression"), + assign: map[string]string{ + "variable_declarator": "name", + "assignment_expression": "left", + }, + argument: kindSet("argument_list"), + argOwners: kindSet("method_invocation", "object_creation_expression", "explicit_constructor_invocation"), + condition: map[string][]string{ + "if_statement": {"condition"}, + "while_statement": {"condition"}, + "do_statement": {"condition"}, + "for_statement": {"condition"}, + "enhanced_for_statement": {"value"}, + "switch_expression": {"condition"}, + }, + chain: map[string]string{ + "method_invocation": "object", + "field_access": "object", + }, + discard: kindSet("expression_statement"), + transparent: kindSet( + "parenthesized_expression", "binary_expression", "unary_expression", + "ternary_expression", "cast_expression", "array_access", + "array_initializer", "argument", "instanceof_expression", + ), + blocks: kindSet("block"), +} + +var rustReturnUsageSpec = &returnUsageSpec{ + // function_item / closure_expression terminate the tail-expression + // walk: a call in block-tail position reaches them through the + // transparent block and is the implicit return value. + returns: kindSet("return_expression", "function_item", "closure_expression"), + assign: map[string]string{ + "let_declaration": "pattern", + "assignment_expression": "left", + }, + sinkLists: kindSet("tuple_pattern"), + argument: kindSet("arguments"), + argOwners: kindSet("call_expression"), + condition: map[string][]string{ + "if_expression": {"condition"}, + "while_expression": {"condition"}, + "match_expression": {"value"}, + "for_expression": {"value"}, + }, + chain: map[string]string{ + "field_expression": "value", + "call_expression": "function", // curried `f()()` + }, + discard: kindSet("expression_statement"), + transparent: kindSet( + "block", "parenthesized_expression", "binary_expression", + "unary_expression", "reference_expression", "try_expression", + "await_expression", "let_condition", "match_arm", "match_block", + "else_clause", "tuple_expression", "array_expression", + "struct_expression", "field_initializer", "field_initializer_list", + "index_expression", "range_expression", "type_cast_expression", + ), + blocks: kindSet("block"), +} + +var rubyReturnUsageSpec = &returnUsageSpec{ + // method / singleton_method terminate the tail-position walk + // through body containers — Ruby's implicit return. + returns: kindSet("return", "method", "singleton_method"), + assign: map[string]string{ + "assignment": "left", + "operator_assignment": "left", + }, + sinkLists: kindSet("left_assignment_list"), + argument: kindSet("argument_list"), + argOwners: kindSet("call"), + condition: map[string][]string{ + "if": {"condition"}, + "unless": {"condition"}, + "elsif": {"condition"}, + "while": {"condition"}, + "until": {"condition"}, + "case": {"value"}, + "if_modifier": {"condition"}, + "unless_modifier": {"condition"}, + "while_modifier": {"condition"}, + "until_modifier": {"condition"}, + }, + chain: map[string]string{ + "call": "receiver", + }, + // Ruby has no expression-statement wrapper: statements sit directly + // in body containers, and the tail expression is the implicit + // return value. + body: kindSet("body_statement", "then", "else", "do", "block_body", "program"), + // block / do_block are closure boundaries, not transparent wrappers: + // a call in a block's tail position is the block's value, not the + // enclosing method's return nor a binding of whatever consumes the + // block. Crossing them mislabels `x.map { f() }` (f would read as + // assigned to x.map's receiver) and `return x.each { f() }` (f would + // read as the method's return). + closures: kindSet("block", "do_block"), + transparent: kindSet( + "parenthesized_statements", "binary", "unary", "conditional", + "begin", "argument_list_with_parens", + "right_assignment_list", "element_reference", "pair", "array", "hash", + ), + blocks: kindSet("then", "do", "block_body", "body_statement"), +} + +var csharpReturnUsageSpec = &returnUsageSpec{ + returns: kindSet("return_statement", "lambda_expression", "arrow_expression_clause"), + assign: map[string]string{ + "variable_declarator": "name", + "assignment_expression": "left", + }, + argument: kindSet("argument_list"), + argOwners: kindSet("invocation_expression", "object_creation_expression", "implicit_object_creation_expression"), + condition: map[string][]string{ + "if_statement": {"condition"}, + "while_statement": {"condition"}, + "do_statement": {"condition"}, + "for_statement": {"condition"}, + "foreach_statement": {"right"}, + "switch_statement": {"value"}, + // A switch_expression carries no `value` field: its governing + // expression is the first child, ahead of the `switch` keyword, + // and the arms are fieldless switch_expression_arm siblings. The + // empty field list selects it positionally; conditionExcludeKinds + // keeps an arm from ever reading as the condition. + "switch_expression": {}, + }, + conditionExcludeKinds: kindSet("switch_expression_arm"), + chain: map[string]string{ + "member_access_expression": "expression", + }, + discard: kindSet("expression_statement"), + transparent: kindSet( + "argument", "parenthesized_expression", "binary_expression", + "prefix_unary_expression", "postfix_unary_expression", + "cast_expression", "conditional_expression", "await_expression", + "element_access_expression", "interpolation", + "conditional_access_expression", + ), + blocks: kindSet("block"), +} diff --git a/internal/parser/languages/return_usage_test.go b/internal/parser/languages/return_usage_test.go new file mode 100644 index 00000000..cca1571b --- /dev/null +++ b/internal/parser/languages/return_usage_test.go @@ -0,0 +1,518 @@ +package languages + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zzet/gortex/internal/graph" +) + +// callReturnUsage returns the stamped return-usage label of the first +// EdgeCalls edge targeting callee. The empty string means the edge +// exists but carries no label. +func callReturnUsage(t *testing.T, edges []*graph.Edge, callee string) string { + t.Helper() + for _, e := range edges { + if e.Kind != graph.EdgeCalls { + continue + } + if strings.HasSuffix(e.To, "::"+callee) || strings.HasSuffix(e.To, "*."+callee) { + return graph.ReturnUsageOf(e) + } + } + t.Fatalf("no call edge to %q found", callee) + return "" +} + +func TestGoReturnUsage(t *testing.T) { + src := []byte(`package main + +func caller() (int, error) { + f1() + x := f2() + var v = f3() + y = f4() + a, _ := f5() + _, _ = f6() + go f7() + defer f8() + g(f9()) + if f10() { + } + for f11() { + } + switch f12() { + } + f13().Method() + ch <- f14() + for _, item := range f15() { + } + if err := f16(); err != nil { + } + m := T{Field: f17()} + return f18() +} +`) + e := NewGoExtractor() + result, err := e.Extract("main.go", src) + require.NoError(t, err) + defer result.Tree.Release() + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, + "f3": graph.ReturnUsageAssigned, + "f4": graph.ReturnUsageAssigned, + "f5": graph.ReturnUsagePartiallyIgnored, + "f6": graph.ReturnUsageDiscarded, // every sink blank — value thrown away + "f7": graph.ReturnUsageGoroutine, + "f8": graph.ReturnUsageDeferred, + "f9": graph.ReturnUsageArgument, + "f10": graph.ReturnUsageCondition, + "f11": graph.ReturnUsageCondition, + "f12": graph.ReturnUsageCondition, + "f13": graph.ReturnUsageArgument, // chained receiver feeds .Method() + "f14": graph.ReturnUsageAssigned, // channel send binds the value + "f15": graph.ReturnUsagePartiallyIgnored, + "f16": graph.ReturnUsageAssigned, + "f17": graph.ReturnUsageAssigned, // composite-literal field, then bound + "f18": graph.ReturnUsageReturned, + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +// A Go range clause with no binding targets (`for range f()`) consumes +// the result by iterating it but captures nothing — discarded, not +// assigned. The zero-sink fold must not read it as a binding. +func TestGoReturnUsageRangeNoVars(t *testing.T) { + src := []byte(`package main + +func caller() { + for range f1() { + } + for k := range f2() { + _ = k + } +} +`) + e := NewGoExtractor() + result, err := e.Extract("main.go", src) + require.NoError(t, err) + defer result.Tree.Release() + + assert.Equal(t, graph.ReturnUsageDiscarded, callReturnUsage(t, result.Edges, "f1"), + "for range with no variables captures nothing") + assert.Equal(t, graph.ReturnUsageAssigned, callReturnUsage(t, result.Edges, "f2"), + "for k := range binds the induction variable") +} + +func TestPythonReturnUsage(t *testing.T) { + src := []byte(`def caller(self): + f1() + x = f2() + x += f3() + a, _ = f4() + _ = f5() + g(f6()) + if f7(): + pass + while f8(): + pass + for item in f9(): + pass + f10().chained() + h = lambda: f11() + self.m1() + y = self.m2() + return f12() +`) + e := NewPythonExtractor() + result, err := e.Extract("main.py", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, + "f3": graph.ReturnUsageAssigned, + "f4": graph.ReturnUsagePartiallyIgnored, + "f5": graph.ReturnUsageDiscarded, + "f6": graph.ReturnUsageArgument, + "f7": graph.ReturnUsageCondition, + "f8": graph.ReturnUsageCondition, + "f9": graph.ReturnUsageCondition, + "f10": graph.ReturnUsageArgument, + "f11": graph.ReturnUsageReturned, // lambda body is its return value + "m1": graph.ReturnUsageDiscarded, + "m2": graph.ReturnUsageAssigned, + "f12": graph.ReturnUsageReturned, + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +func TestJavaScriptReturnUsage(t *testing.T) { + src := []byte(`function caller() { + f1(); + const x = f2(); + y = f3(); + y += f4(); + g(f5()); + if (f6()) { } + while (f7()) { } + for (let i = 0; f8(); i++) { } + for (const a of f9()) { } + switch (f10()) { } + f11().chained(); + const h = () => f12(); + new C(f13()); + obj.m1(); + const z = obj.m2(); + return f14(); +} +`) + e := NewJavaScriptExtractor() + result, err := e.Extract("main.js", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, + "f3": graph.ReturnUsageAssigned, + "f4": graph.ReturnUsageAssigned, + "f5": graph.ReturnUsageArgument, + "f6": graph.ReturnUsageCondition, + "f7": graph.ReturnUsageCondition, + "f8": graph.ReturnUsageCondition, + "f9": graph.ReturnUsageCondition, + "f10": graph.ReturnUsageCondition, + "f11": graph.ReturnUsageArgument, + "f12": graph.ReturnUsageReturned, // concise arrow body is its return value + "f13": graph.ReturnUsageArgument, + "m1": graph.ReturnUsageDiscarded, + "m2": graph.ReturnUsageAssigned, + "f14": graph.ReturnUsageReturned, + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +func TestTypeScriptReturnUsage(t *testing.T) { + src := []byte(`function caller(): number { + f1(); + const x = f2() as number; + const y = f3()!; + if (f4()) { } + g(f5()); + svc.m1(); + const z = svc.m2(); + return f6(); +} +`) + e := NewTypeScriptExtractor() + result, err := e.Extract("main.ts", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, // through the `as` cast + "f3": graph.ReturnUsageAssigned, // through the non-null assertion + "f4": graph.ReturnUsageCondition, + "f5": graph.ReturnUsageArgument, + "m1": graph.ReturnUsageDiscarded, + "m2": graph.ReturnUsageAssigned, + "f6": graph.ReturnUsageReturned, + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +func TestJavaReturnUsage(t *testing.T) { + src := []byte(`class A { + void caller() { + f1(); + int x = f2(); + x = f3(); + g(f4()); + if (f5() > 0) { } + while (f6()) { } + for (int i = 0; f7(); i++) { } + for (var a : f8()) { } + switch (f9()) { } + f10().chained(); + Supplier s = () -> f11(); + return f12(); + } +} +`) + e := NewJavaExtractor() + result, err := e.Extract("A.java", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, + "f3": graph.ReturnUsageAssigned, + "f4": graph.ReturnUsageArgument, + "f5": graph.ReturnUsageCondition, + "f6": graph.ReturnUsageCondition, + "f7": graph.ReturnUsageCondition, + "f8": graph.ReturnUsageCondition, + "f9": graph.ReturnUsageCondition, + "f10": graph.ReturnUsageArgument, + "f11": graph.ReturnUsageReturned, // expression-bodied lambda + "f12": graph.ReturnUsageReturned, + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +func TestRustReturnUsage(t *testing.T) { + src := []byte(`fn caller() -> i32 { + f1(); + let x = f2(); + let (a, _) = f3(); + let _ = f4(); + x = f5(); + g(f6()); + if f7() > 0 { } + if let Some(v) = f8() { } + while f9() { } + match f10() { _ => {} } + for i in f11() { } + f12().chained(); + let c = || f13(); + return f14(); +} + +fn tail_caller() -> i32 { + f15() +} + +fn branch_tail() -> i32 { + let x = if true { f16() } else { 0 }; + x +} +`) + e := NewRustExtractor() + result, err := e.Extract("main.rs", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, + "f3": graph.ReturnUsagePartiallyIgnored, + "f4": graph.ReturnUsageDiscarded, // wildcard binding throws the value away + "f5": graph.ReturnUsageAssigned, + "f6": graph.ReturnUsageArgument, + "f7": graph.ReturnUsageCondition, + "f8": graph.ReturnUsageCondition, // if-let binds inside the condition + "f9": graph.ReturnUsageCondition, + "f10": graph.ReturnUsageCondition, + "f11": graph.ReturnUsageCondition, + "f12": graph.ReturnUsageArgument, + "f13": graph.ReturnUsageReturned, // closure body is its return value + "f14": graph.ReturnUsageReturned, + "f15": graph.ReturnUsageReturned, // implicit tail return + "f16": graph.ReturnUsageAssigned, // branch value flows into the let + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +func TestRubyReturnUsage(t *testing.T) { + src := []byte(`def caller + f1() + x = f2() + a, _ = f3() + x += f4() + g(f5()) + if f6() + nil + end + while f7() + nil + end + case f8() + when 1 then nil + end + f9().chained + foo if f10() + return f11() +end + +def tail_caller + f12() +end +`) + e := NewRubyExtractor() + result, err := e.Extract("main.rb", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "f1": graph.ReturnUsageDiscarded, + "f2": graph.ReturnUsageAssigned, + "f3": graph.ReturnUsagePartiallyIgnored, + "f4": graph.ReturnUsageAssigned, + "f5": graph.ReturnUsageArgument, + "f6": graph.ReturnUsageCondition, + "f7": graph.ReturnUsageCondition, + "f8": graph.ReturnUsageCondition, + "f9": graph.ReturnUsageArgument, + "f10": graph.ReturnUsageCondition, // statement modifier + "f11": graph.ReturnUsageReturned, + "f12": graph.ReturnUsageReturned, // implicit tail return + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +// A call inside a Ruby block / do_block must be classified relative to +// the block, never to the statement that consumes the block. Crossing +// the closure boundary would mislabel a block-internal call as assigned +// (to the receiver of a `.map` assignment), as a condition (the if the +// block result feeds), or as the enclosing method's return. +func TestRubyReturnUsageBlockBoundary(t *testing.T) { + src := []byte(`def caller + items.each do |i| + f1(i) + end + y = items.map { |i| f2(i) } + if items.any? { |i| f3(i) } + nil + end + return items.map do |i| + f4(i) + end +end +`) + e := NewRubyExtractor() + result, err := e.Extract("main.rb", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + // Each block call is the block's tail value: returned from the + // block, not inheriting the consuming statement's label. + "f1": graph.ReturnUsageReturned, // bare `.each` block, not the method body + "f2": graph.ReturnUsageReturned, // `.map` block, not assigned to y + "f3": graph.ReturnUsageReturned, // `.any?` block, not the if condition + "f4": graph.ReturnUsageReturned, // block under `return`, returned from the block + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +// A call that is NOT the tail of a Ruby block body is a bare statement +// inside the block — discarded — and likewise must not leak the +// enclosing statement's label across the closure boundary. +func TestRubyReturnUsageBlockNonTail(t *testing.T) { + src := []byte(`def caller + y = items.map do |i| + f_mid(i) + g(i) + end +end +`) + e := NewRubyExtractor() + result, err := e.Extract("main.rb", src) + require.NoError(t, err) + + assert.Equal(t, graph.ReturnUsageDiscarded, callReturnUsage(t, result.Edges, "f_mid"), + "non-tail block statement is discarded, not assigned to y") + assert.Equal(t, graph.ReturnUsageReturned, callReturnUsage(t, result.Edges, "g"), + "block tail is returned from the block, not assigned to y") +} + +func TestCSharpReturnUsage(t *testing.T) { + src := []byte(`class A { + void Caller() { + F1(); + int x = F2(); + var y = F3(); + x = F4(); + G(F5()); + if (F6() > 0) { } + while (F7()) { } + for (int i = 0; F8(); i++) { } + foreach (var a in F9()) { } + switch (F10()) { } + F11().Chained(); + Func h = () => F12(); + return F13(); + } + + int Shorthand() => F14(); +} +`) + e := NewCSharpExtractor() + result, err := e.Extract("A.cs", src) + require.NoError(t, err) + + for callee, want := range map[string]string{ + "F1": graph.ReturnUsageDiscarded, + "F2": graph.ReturnUsageAssigned, + "F3": graph.ReturnUsageAssigned, + "F4": graph.ReturnUsageAssigned, + "F5": graph.ReturnUsageArgument, + "F6": graph.ReturnUsageCondition, + "F7": graph.ReturnUsageCondition, + "F8": graph.ReturnUsageCondition, + "F9": graph.ReturnUsageCondition, + "F10": graph.ReturnUsageCondition, + "F11": graph.ReturnUsageArgument, + "F12": graph.ReturnUsageReturned, // expression-bodied lambda + "F13": graph.ReturnUsageReturned, + "F14": graph.ReturnUsageReturned, // expression-bodied member + } { + assert.Equal(t, want, callReturnUsage(t, result.Edges, callee), "callee %s", callee) + } +} + +// The C# switch_expression carries no `value` field — its governing +// expression is positional, ahead of the `switch` keyword. The +// classifier must still place a call in that position as a condition, +// and must not mistake an arm-body call for the condition. +func TestCSharpReturnUsageSwitchExpression(t *testing.T) { + src := []byte(`class A { + int Caller(int n) { + return F1() switch { + 1 => F2(), + _ => 0, + }; + } +} +`) + e := NewCSharpExtractor() + result, err := e.Extract("A.cs", src) + require.NoError(t, err) + + assert.Equal(t, graph.ReturnUsageCondition, callReturnUsage(t, result.Edges, "F1"), + "the governing expression of a switch expression reads as a condition") + // F2 sits in an arm body — its value becomes the switch-expression + // result, a shape the classifier does not model; it must at least + // never be mislabeled as the condition. + assert.NotEqual(t, graph.ReturnUsageCondition, callReturnUsage(t, result.Edges, "F2"), + "an arm-body call is not the switch condition") +} + +// A language without go/defer must never see those labels, and an +// unclassifiable parent chain must leave the edge unstamped rather +// than mislabeled. +func TestReturnUsageNeverFabricatesLabels(t *testing.T) { + src := []byte(`def caller(): + raise f1() +`) + e := NewPythonExtractor() + result, err := e.Extract("main.py", src) + require.NoError(t, err) + + got := callReturnUsage(t, result.Edges, "f1") + assert.Empty(t, got, "raise is not a covered consumption shape — no label") + for _, edge := range result.Edges { + usage := graph.ReturnUsageOf(edge) + assert.NotEqual(t, graph.ReturnUsageGoroutine, usage) + assert.NotEqual(t, graph.ReturnUsageDeferred, usage) + } +} diff --git a/internal/parser/languages/ruby.go b/internal/parser/languages/ruby.go index c91a3f20..d6e1d1d6 100644 --- a/internal/parser/languages/ruby.go +++ b/internal/parser/languages/ruby.go @@ -69,6 +69,10 @@ type rubyDeferredCall struct { name string line int hasRecv bool + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on the EdgeCalls emitted for this site. + returnUsage string } func (e *RubyExtractor) Extract(filePath string, src []byte) (*parser.ExtractionResult, error) { @@ -128,9 +132,10 @@ func (e *RubyExtractor) Extract(filePath string, src []byte) (*parser.Extraction } } calls = append(calls, rubyDeferredCall{ - name: name, - line: expr.StartLine + 1, - hasRecv: hasRecv, + name: name, + line: expr.StartLine + 1, + hasRecv: hasRecv, + returnUsage: classifyReturnUsage(expr.Node, src, rubyReturnUsageSpec), }) case m.Captures["const.def"] != nil: @@ -149,10 +154,12 @@ func (e *RubyExtractor) Extract(filePath string, src []byte) (*parser.Extraction if c.hasRecv { target = "unresolved::*." + c.name } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: target, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } // Rails-style callback dispatch — preserves legacy behaviour exactly. diff --git a/internal/parser/languages/rust.go b/internal/parser/languages/rust.go index 12b8404e..13eb356f 100644 --- a/internal/parser/languages/rust.go +++ b/internal/parser/languages/rust.go @@ -93,6 +93,10 @@ type rustDeferredCall struct { path string // full scoped_identifier text for path calls (e.g. "Foo::new", "crate::util::helper"); "" otherwise line int isSelector bool + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on the EdgeCalls emitted for this site. + returnUsage string } // rustDeferredLet buffers a let_declaration for the post-pass type-env @@ -179,17 +183,19 @@ func (e *RustExtractor) Extract(filePath string, src []byte) (*parser.Extraction case m.Captures["callm.expr"] != nil: expr := m.Captures["callm.expr"] calls = append(calls, rustDeferredCall{ - name: m.Captures["callm.method"].Text, - receiver: m.Captures["callm.receiver"].Text, - line: expr.StartLine + 1, - isSelector: true, + name: m.Captures["callm.method"].Text, + receiver: m.Captures["callm.receiver"].Text, + line: expr.StartLine + 1, + isSelector: true, + returnUsage: classifyReturnUsage(expr.Node, src, rustReturnUsageSpec), }) case m.Captures["callp.expr"] != nil: expr := m.Captures["callp.expr"] c := rustDeferredCall{ - name: m.Captures["callp.name"].Text, - line: expr.StartLine + 1, + name: m.Captures["callp.name"].Text, + line: expr.StartLine + 1, + returnUsage: classifyReturnUsage(expr.Node, src, rustReturnUsageSpec), } // The query only captures the scoped_identifier's trailing // segment, so the qualifier (Foo / Self / crate / super / @@ -207,8 +213,9 @@ func (e *RustExtractor) Extract(filePath string, src []byte) (*parser.Extraction case m.Captures["call.expr"] != nil: expr := m.Captures["call.expr"] calls = append(calls, rustDeferredCall{ - name: m.Captures["call.name"].Text, - line: expr.StartLine + 1, + name: m.Captures["call.name"].Text, + line: expr.StartLine + 1, + returnUsage: classifyReturnUsage(expr.Node, src, rustReturnUsageSpec), }) } }) @@ -281,6 +288,7 @@ func (e *RustExtractor) Extract(filePath string, src []byte) (*parser.Extraction } edge.Meta["rust_recv"] = c.receiver } + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) continue } @@ -294,6 +302,7 @@ func (e *RustExtractor) Extract(filePath string, src []byte) (*parser.Extraction if c.path != "" && strings.Contains(c.path, "::") { edge.Meta = map[string]any{"rust_path": c.path} } + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) } diff --git a/internal/parser/languages/rust_test.go b/internal/parser/languages/rust_test.go index 6909a194..3f0d9356 100644 --- a/internal/parser/languages/rust_test.go +++ b/internal/parser/languages/rust_test.go @@ -296,7 +296,7 @@ fn main() { } } require.NotNil(t, processCall) - assert.Nil(t, processCall.Meta, "unknown type should not produce Meta") + assert.NotContains(t, processCall.Meta, "receiver_type", "unknown type should not produce a receiver_type hint") } func TestRsExtractor_TypeEnv_Chain(t *testing.T) { diff --git a/internal/parser/languages/typescript.go b/internal/parser/languages/typescript.go index dd4d7976..38003d10 100644 --- a/internal/parser/languages/typescript.go +++ b/internal/parser/languages/typescript.go @@ -137,6 +137,10 @@ type deferredCall struct { // expr is the call_expression node, kept for member calls so the // post-pass can inspect arguments for pub/sub topic detection. expr *sitter.Node + // returnUsage is how the call site consumes the return value + // (graph.ReturnUsage* label), classified at capture time and + // stamped as edge Meta on every EdgeCalls emitted for this site. + returnUsage string } // deferredVar holds a lexical_declaration match whose emission is @@ -254,18 +258,20 @@ func (e *TypeScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr case m.Captures["call.expr"] != nil: expr := m.Captures["call.expr"] calls = append(calls, deferredCall{ - name: m.Captures["call.name"].Text, - line: expr.StartLine + 1, + name: m.Captures["call.name"].Text, + line: expr.StartLine + 1, + returnUsage: classifyReturnUsage(expr.Node, src, jsTSReturnUsageSpec), }) case m.Captures["callm.expr"] != nil: expr := m.Captures["callm.expr"] calls = append(calls, deferredCall{ - method: m.Captures["callm.method"].Text, - receiver: m.Captures["callm.receiver"].Text, - line: expr.StartLine + 1, - isMember: true, - expr: expr.Node, + method: m.Captures["callm.method"].Text, + receiver: m.Captures["callm.receiver"].Text, + line: expr.StartLine + 1, + isMember: true, + expr: expr.Node, + returnUsage: classifyReturnUsage(expr.Node, src, jsTSReturnUsageSpec), }) case m.Captures["tvar.def"] != nil: @@ -350,24 +356,30 @@ func (e *TypeScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr if !c.isMember { if binding, ok := destructured[c.name]; ok { if memberID := objLiteralMembers[binding][c.name]; memberID != "" { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: memberID, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Origin: graph.OriginASTResolved, Confidence: 0.9, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{"via": "store-factory", "store_binding": binding, "store_action": c.name}, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::" + c.name, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } // Object-literal member call (`api.process()` where @@ -378,38 +390,46 @@ func (e *TypeScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr // this edge in the name-only fallback. if members, ok := objLiteralMembers[c.receiver]; ok { if memberID, ok := members[c.method]; ok { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: memberID, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Origin: graph.OriginASTResolved, Confidence: 0.92, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } } // Namespace/default import receiver (e.g. `fs.readFile`): attach // the module path so the resolver can classify externally. if importPath, ok := imports[c.receiver]; ok { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::extern::" + importPath + "::" + c.method, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } // Store-factory chained call: `useStore.getState().action()`. if binding, ok := jsParseGetStateChain(c.receiver); ok { if memberID := objLiteralMembers[binding][c.method]; memberID != "" { - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: memberID, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Origin: graph.OriginASTResolved, Confidence: 0.9, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: "unresolved::*." + c.method, Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{"via": "store-factory", "store_binding": binding, "store_action": c.method}, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) continue } edge := &graph.Edge{ @@ -423,6 +443,7 @@ func (e *TypeScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr edge.Meta = map[string]any{"receiver_type": chainType} } } + stampReturnUsage(edge, c.returnUsage) result.Edges = append(result.Edges, edge) } @@ -484,11 +505,13 @@ func (e *TypeScriptExtractor) Extract(filePath string, src []byte) (*parser.Extr if callerID == "" { continue } - result.Edges = append(result.Edges, &graph.Edge{ + edge := &graph.Edge{ From: callerID, To: rnNativePlaceholder(module, c.method), Kind: graph.EdgeCalls, FilePath: filePath, Line: c.line, Meta: map[string]any{"via": rnNativeVia, "rn_module": module, "rn_method": c.method}, - }) + } + stampReturnUsage(edge, c.returnUsage) + result.Edges = append(result.Edges, edge) } // --- React Native Fabric / Codegen component spec --- diff --git a/internal/parser/languages/typescript_test.go b/internal/parser/languages/typescript_test.go index e7599a6c..b8dee94a 100644 --- a/internal/parser/languages/typescript_test.go +++ b/internal/parser/languages/typescript_test.go @@ -259,7 +259,7 @@ function main() { } } require.NotNil(t, processCall) - assert.Nil(t, processCall.Meta, "unknown type should not produce Meta") + assert.NotContains(t, processCall.Meta, "receiver_type", "unknown type should not produce a receiver_type hint") } func TestTSExtractor_TypeEnv_Chain(t *testing.T) { From 3dee4780d55f530b79f8fa0de81d5adfae9de432 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Sat, 13 Jun 2026 08:04:12 +0200 Subject: [PATCH 2/5] feat(cfg): per-function control-flow graphs + reaching-definitions fixpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New internal/cfg package: on-demand per-function CFGs from the tree-sitter AST for Go, Python, JavaScript, TypeScript, Java, Rust, and Ruby — basic blocks with per-statement def/use sets, labeled edges (branches, loops, labeled break/continue, switch/match fallthrough, try/except/finally), and a bitset GEN/KILL reaching-definitions fixpoint producing statement-granular def-to-use chains, with a Mermaid renderer. Exposed as the get_cfg MCP tool and analyze kind=def_use. internal/dataflow gained a CFG-backed refiner on flow_between and taint_paths: same-function value_flow hops are confirmed or pruned based on whether the def reaches the use, and pruned paths sink in the ranking. --- internal/agents/claudecode/content.go | 3 +- internal/cfg/builder.go | 567 +++++++++++ internal/cfg/cfg.go | 279 ++++++ internal/cfg/cfg_test.go | 716 ++++++++++++++ internal/cfg/defuse.go | 224 +++++ internal/cfg/lang.go | 1287 +++++++++++++++++++++++++ internal/cfg/lang_test.go | 898 +++++++++++++++++ internal/cfg/mermaid.go | 88 ++ internal/cfg/reaching.go | 236 +++++ internal/dataflow/dataflow.go | 73 +- internal/dataflow/refine.go | 269 ++++++ internal/dataflow/refine_test.go | 318 ++++++ internal/mcp/gcx.go | 12 + internal/mcp/scope_init.go | 4 + internal/mcp/server.go | 1 + internal/mcp/tools_cfg.go | 388 ++++++++ internal/mcp/tools_cfg_test.go | 330 +++++++ internal/mcp/tools_dataflow.go | 55 +- internal/mcp/tools_enhancements.go | 6 +- 19 files changed, 5732 insertions(+), 22 deletions(-) create mode 100644 internal/cfg/builder.go create mode 100644 internal/cfg/cfg.go create mode 100644 internal/cfg/cfg_test.go create mode 100644 internal/cfg/defuse.go create mode 100644 internal/cfg/lang.go create mode 100644 internal/cfg/lang_test.go create mode 100644 internal/cfg/mermaid.go create mode 100644 internal/cfg/reaching.go create mode 100644 internal/dataflow/refine.go create mode 100644 internal/dataflow/refine_test.go create mode 100644 internal/mcp/tools_cfg.go create mode 100644 internal/mcp/tools_cfg_test.go diff --git a/internal/agents/claudecode/content.go b/internal/agents/claudecode/content.go index ca47403c..75885f57 100644 --- a/internal/agents/claudecode/content.go +++ b/internal/agents/claudecode/content.go @@ -348,6 +348,7 @@ These wrap the discovery + impact + memory surfaces into ordered playbooks so po |------|-------------------| | flow_between | Ranked dataflow paths between two symbol IDs. Walks ` + "`value_flow`" + ` (intra-procedural) ∪ ` + "`arg_of`" + ` (caller arg → callee param) ∪ ` + "`returns_to`" + ` (callee → assignment). Pass ` + "`max_depth`" + ` (default 8) and ` + "`max_paths`" + ` (default 10). | | taint_paths | Pattern-driven dataflow sweep — every flow from a matching source to a matching sink. Patterns: bare token = name substring; ` + "`exact:Foo`" + `; ` + "`path:dir/`" + `; ` + "`kind:method`" + ` (clauses combine with AND). Sinks expand functions to their params automatically. | +| get_cfg | Per-function control-flow graph for one function/method ID — basic blocks, labeled edges (seq/true/false/loop_back/break/continue/return/case/exception/finally), per-statement def/use sets, and statement-granular reaching-definition chains. Optional ` + "`mermaid`" + ` rendering. Go / Python / JS / TS / Java / Rust / Ruby. | ### Structural Code Search | Tool | What it gives you | @@ -392,7 +393,7 @@ These wrap the discovery + impact + memory surfaces into ordered playbooks so po ### Code Quality | Tool | What it gives you | |------|-------------------| -| analyze | Unified graph-analysis dispatcher (60 kinds). Structural: dead_code, hotspots, cycles, would_create_cycle, clusters, concepts, role, connectivity_health, edge_audit, constructors_missing_fields. Quality / security: health_score, impact, sast, hygiene, unsafe_patterns, named, review (idiomatic / correctness rulepack — NPE, thread-safety check-then-act, N+1, logic-error; Go + Python — with a graph-grounded false-positive-reduction pass). Churn / ownership: todos, stale_code, ownership, fixes_history, blame. Coverage / releases: coverage, coverage_gaps, coverage_summary, releases. Schema / SQL: orphan_tables, unreferenced_tables, sql_call_sites, sql_rebuild, dbt_models, models. Flags / interop: stale_flags, cgo_users, wasm_users. Edge-driven: channel_ops, race_writes, unclosed_channels, goroutine_spawns, field_writers, annotation_users, config_readers, env_var_users, event_emitters, log_events, string_emitters, error_surface, external_calls, tests_as_edges. Web / infra: routes, components, k8s_resources, images, kustomize, pubsub. Cross-repo: cross_repo. Provenance / resolution: synthesizers (framework-dispatch edges grouped by pass), resolution_outcomes (why a call/ref edge stayed unresolved). Extensible: domain | +| analyze | Unified graph-analysis dispatcher (60 kinds). Structural: dead_code, hotspots, cycles, would_create_cycle, clusters, concepts, role, connectivity_health, edge_audit, constructors_missing_fields. Quality / security: health_score, impact, sast, hygiene, unsafe_patterns, named, review (idiomatic / correctness rulepack — NPE, thread-safety check-then-act, N+1, logic-error; Go + Python — with a graph-grounded false-positive-reduction pass). Churn / ownership: todos, stale_code, ownership, fixes_history, blame. Coverage / releases: coverage, coverage_gaps, coverage_summary, releases. Schema / SQL: orphan_tables, unreferenced_tables, sql_call_sites, sql_rebuild, dbt_models, models. Flags / interop: stale_flags, cgo_users, wasm_users. Edge-driven: channel_ops, race_writes, unclosed_channels, goroutine_spawns, field_writers, annotation_users, config_readers, env_var_users, event_emitters, log_events, string_emitters, error_surface, external_calls, tests_as_edges. Web / infra: routes, components, k8s_resources, images, kustomize, pubsub. Cross-repo: cross_repo. Dataflow: def_use (per-function reaching-definition def→use chains over the on-demand CFG; pairs with the get_cfg tool). Provenance / resolution: synthesizers (framework-dispatch edges grouped by pass), resolution_outcomes (why a call/ref edge stayed unresolved). Extensible: domain | | analyze kind=dead_code | Symbols with zero incoming edges (excludes entry points, tests, exports) | | analyze kind=hotspots | Over-coupled symbols ranked by fan-in, fan-out, and community crossings | | analyze kind=cycles | Tarjan's SCC with severity classification | diff --git a/internal/cfg/builder.go b/internal/cfg/builder.go new file mode 100644 index 00000000..c9a92801 --- /dev/null +++ b/internal/cfg/builder.go @@ -0,0 +1,567 @@ +package cfg + +import ( + "strings" + + sitter "github.com/zzet/gortex/internal/parser/tsitter" +) + +// maxStmtText caps the recorded statement text so giant one-liners +// don't bloat tool responses. +const maxStmtText = 120 + +// frame is one entry of the break/continue resolution stack. Loops +// push a frame with both targets; switch statements in languages +// where `break` exits the switch push a frame with only breakTo. +type frame struct { + label string + continueTo *Block + breakTo *Block + isLoop bool +} + +// builder drives CFG construction. cur is the block receiving the +// next statement; nil means the current position is past a +// terminator (return/break/…) — the next statement starts a fresh, +// unreachable block so its defs/uses still surface. +type builder struct { + spec *langSpec + src []byte + lineOffset int + cfg *CFG + + cur *Block + frames []frame + pendingLabel string + pendingFallthrough bool + + edgeSeen map[edgeKey]bool +} + +type edgeKey struct { + from, to int + label EdgeLabel +} + +func (b *builder) newBlock(label string) *Block { + bl := &Block{ID: len(b.cfg.Blocks), Label: label} + b.cfg.Blocks = append(b.cfg.Blocks, bl) + return bl +} + +func (b *builder) edge(from, to *Block, label EdgeLabel) { + if from == nil || to == nil { + return + } + k := edgeKey{from.ID, to.ID, label} + if b.edgeSeen[k] { + return + } + b.edgeSeen[k] = true + b.cfg.Edges = append(b.cfg.Edges, Edge{From: from.ID, To: to.ID, Label: label}) +} + +// moveTo links cur to bl sequentially and makes bl current. +func (b *builder) moveTo(bl *Block) { + if b.cur != nil { + b.edge(b.cur, bl, LabelSeq) + } + b.cur = bl +} + +// ensureCur guarantees a current block, opening an unreachable one +// when the previous statement terminated control flow. +func (b *builder) ensureCur() { + if b.cur == nil { + b.cur = b.newBlock("unreachable") + } +} + +func (b *builder) pushFrame(f frame) { b.frames = append(b.frames, f) } +func (b *builder) popFrame() { b.frames = b.frames[:len(b.frames)-1] } + +// takeLabel consumes the label set by an enclosing labeled +// statement, if any. +func (b *builder) takeLabel() string { + l := b.pendingLabel + b.pendingLabel = "" + return l +} + +// record appends a synthetic statement with explicit position/text. +func (b *builder) record(startLine, endLine int, text, kind string) *Statement { + st := &Statement{ + Index: len(b.cfg.Stmts), + Block: b.cur.ID, + StartLine: startLine + b.lineOffset, + EndLine: endLine + b.lineOffset, + Text: text, + Kind: kind, + } + b.cfg.Stmts = append(b.cfg.Stmts, st) + b.cur.Stmts = append(b.cur.Stmts, st) + return st +} + +// recordNode appends a statement positioned at n without running +// def/use extraction (callers fill Defs/Uses themselves). +func (b *builder) recordNode(n *sitter.Node, kind string) *Statement { + return b.record(int(n.StartPoint().Row)+1, int(n.EndPoint().Row)+1, stmtText(n, b.src), kind) +} + +// addStmt appends a statement for n with def/use extraction. +func (b *builder) addStmt(n *sitter.Node, kind string) *Statement { + if n == nil { + return nil + } + st := b.recordNode(n, kind) + st.Defs, st.Uses = extractDefUse(b.spec, b.src, n, false) + return st +} + +// leaf records n as a plain statement in the current block. +func (b *builder) leaf(n *sitter.Node, kind string) { + b.ensureCur() + b.addStmt(n, kind) +} + +// stmtText renders the statement's first source line, trimmed and +// capped; multi-line statements get an ellipsis. +func stmtText(n *sitter.Node, src []byte) string { + text := n.Content(src) + if i := strings.IndexByte(text, '\n'); i >= 0 { + text = text[:i] + " …" + } + text = strings.TrimSpace(text) + if len(text) > maxStmtText { + text = text[:maxStmtText] + "…" + } + return text +} + +// buildStmt processes one statement node: control constructs are +// consumed by the language dispatch table, everything else is a leaf. +func (b *builder) buildStmt(n *sitter.Node) { + if n == nil { + return + } + if n.Type() == "comment" { + return + } + if b.spec.dispatch(b, n) { + return + } + b.leaf(n, "") +} + +// buildSeq processes every named child of n as a statement. +func (b *builder) buildSeq(n *sitter.Node) { + for i := 0; i < int(n.NamedChildCount()); i++ { + if c := n.NamedChild(i); c != nil { + b.buildStmt(c) + } + } +} + +// --------------------------------------------------------------------------- +// if / else +// --------------------------------------------------------------------------- + +// buildIf wires the classic diamond. alt may itself be an if (else- +// if chains) — the recursive buildStmt handles it through dispatch. +func (b *builder) buildIf(init, cond, then, alt *sitter.Node) { + b.ensureCur() + if init != nil { + b.buildStmt(init) + } + if cond != nil { + b.addStmt(cond, "cond") + } + head := b.cur + after := b.newBlock("if_end") + + thenBlock := b.newBlock("then") + b.edge(head, thenBlock, LabelTrue) + b.cur = thenBlock + if then != nil { + b.buildStmt(then) + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + + if alt != nil { + elseBlock := b.newBlock("else") + b.edge(head, elseBlock, LabelFalse) + b.cur = elseBlock + b.buildStmt(alt) + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + } else { + b.edge(head, after, LabelFalse) + } + b.cur = after +} + +// buildIfChain handles grammars that stack elif/else clauses as +// sibling `alternative` fields (Python) instead of nesting them. +func (b *builder) buildIfChain(cond, cons *sitter.Node, alts []*sitter.Node) { + b.ensureCur() + if cond != nil { + b.addStmt(cond, "cond") + } + head := b.cur + after := b.newBlock("if_end") + + thenBlock := b.newBlock("then") + b.edge(head, thenBlock, LabelTrue) + b.cur = thenBlock + if cons != nil { + b.buildStmt(cons) + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + + if len(alts) == 0 { + b.edge(head, after, LabelFalse) + b.cur = after + return + } + + elseBlock := b.newBlock("else") + b.edge(head, elseBlock, LabelFalse) + b.cur = elseBlock + first := alts[0] + if first.Type() == "elif_clause" { + b.buildIfChain(first.ChildByFieldName("condition"), first.ChildByFieldName("consequence"), alts[1:]) + } else { + // else_clause terminates the chain. + b.buildStmt(first) + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + b.cur = after +} + +// --------------------------------------------------------------------------- +// loops +// --------------------------------------------------------------------------- + +// loopParts feeds buildLoop. Exactly one of {cond, headerStmt, +// infinite} shapes the header: +// - cond: pre/post-test condition loop (while / for / do-while) +// - headerStmt: for-in/range loop; the header statement defines the +// loop variables and reads the iterable. When +// headerStmtOnlyHeaderFields is set the node also contains the +// body, so def/use extraction is restricted to the header fields. +// - infinite: no condition (Go `for {}`, Rust `loop {}`). +type loopParts struct { + init, cond, update, body *sitter.Node + headerStmt *sitter.Node + headerStmtOnlyHeaderFields bool + postTest bool + infinite bool + // elseNode is a Python for/while-else clause. It runs only on a + // normal (non-break) loop exit, so the builder routes the header's + // False edge through the else block while `break` jumps past it to + // a dedicated join block. Nil for every other language. + elseNode *sitter.Node +} + +func (b *builder) buildLoop(p loopParts) { + b.ensureCur() + if p.init != nil { + b.buildStmt(p.init) + } + label := b.takeLabel() + + if p.postTest { + bodyBlock := b.newBlock("loop_body") + b.moveTo(bodyBlock) + header := b.newBlock("loop_header") + after := b.newBlock("loop_end") + b.pushFrame(frame{label: label, continueTo: header, breakTo: after, isLoop: true}) + if p.body != nil { + b.buildStmt(p.body) + } + b.popFrame() + if b.cur != nil { + b.edge(b.cur, header, LabelSeq) + } + b.cur = header + if p.cond != nil { + b.addStmt(p.cond, "cond") + } + b.edge(header, bodyBlock, LabelLoopBack) + b.edge(header, after, LabelFalse) + b.cur = after + return + } + + header := b.newBlock("loop_header") + b.moveTo(header) + if p.headerStmt != nil { + b.addLoopHeaderStmt(p) + } else if p.cond != nil { + b.addStmt(p.cond, "cond") + } + after := b.newBlock("loop_end") + bodyBlock := b.newBlock("loop_body") + // With a for/while-else clause, the normal exit (header False) runs + // the else before reaching the join, but `break` must skip it. + // Route break edges at a separate join block; without an else the + // join is the loop_end itself so behaviour is unchanged. + breakTarget := after + if p.elseNode != nil { + breakTarget = b.newBlock("loop_join") + } + if p.infinite { + b.edge(header, bodyBlock, LabelSeq) + } else { + b.edge(header, bodyBlock, LabelTrue) + b.edge(header, after, LabelFalse) + } + + var updateBlock *Block + contTarget := header + if p.update != nil { + updateBlock = b.newBlock("loop_update") + contTarget = updateBlock + } + b.pushFrame(frame{label: label, continueTo: contTarget, breakTo: breakTarget, isLoop: true}) + b.cur = bodyBlock + if p.body != nil { + b.buildStmt(p.body) + } + b.popFrame() + if updateBlock != nil { + if b.cur != nil { + b.edge(b.cur, updateBlock, LabelSeq) + } + b.cur = updateBlock + b.buildStmt(p.update) + if b.cur != nil { + b.edge(b.cur, header, LabelLoopBack) + } + } else if b.cur != nil { + b.edge(b.cur, header, LabelLoopBack) + } + b.cur = after + if p.elseNode != nil { + // The else clause runs on the False (no-break) exit; break + // edges already bypass it by targeting the join directly. + b.buildStmt(p.elseNode) + if b.cur != nil { + b.edge(b.cur, breakTarget, LabelSeq) + } + b.cur = breakTarget + } +} + +// addLoopHeaderStmt records the for-in header: loop variables are +// definitions, the iterable is a use. When the header node embeds +// the body (Python for, JS for-in, Java enhanced-for, Rust for, Ruby +// for) only the header fields are inspected. +func (b *builder) addLoopHeaderStmt(p loopParts) { + n := p.headerStmt + if !p.headerStmtOnlyHeaderFields { + b.addStmt(n, "loop") + return + } + defsNode, usesNode := forInHeaderFields(n) + startLine := int(n.StartPoint().Row) + 1 + endLine := startLine + if usesNode != nil { + endLine = int(usesNode.EndPoint().Row) + 1 + } + text := stmtText(n, b.src) + st := b.record(startLine, endLine, text, "loop") + if defsNode != nil { + defs, _ := extractDefUse(b.spec, b.src, defsNode, true) + st.Defs = defs + } + if usesNode != nil { + _, uses := extractDefUse(b.spec, b.src, usesNode, false) + st.Uses = uses + } +} + +// forInHeaderFields probes the field-name pairs the supported +// grammars use for " in " headers. +func forInHeaderFields(n *sitter.Node) (defs, uses *sitter.Node) { + if l := n.ChildByFieldName("left"); l != nil { + return l, n.ChildByFieldName("right") + } + if p := n.ChildByFieldName("pattern"); p != nil { + return p, n.ChildByFieldName("value") + } + if name := n.ChildByFieldName("name"); name != nil { + return name, n.ChildByFieldName("value") + } + return nil, nil +} + +// --------------------------------------------------------------------------- +// jumps +// --------------------------------------------------------------------------- + +// findFrame resolves a break/continue target. continue skips frames +// without a continue target (switches); a label restricts the match. +func (b *builder) findFrame(label string, needContinue bool) *frame { + for i := len(b.frames) - 1; i >= 0; i-- { + f := &b.frames[i] + if needContinue && f.continueTo == nil { + continue + } + if label != "" && f.label != label { + continue + } + return f + } + return nil +} + +func (b *builder) buildBreak(n *sitter.Node, label string) { + b.ensureCur() + b.addStmt(n, "break") + if f := b.findFrame(label, false); f != nil { + b.edge(b.cur, f.breakTo, LabelBreak) + } else { + // break outside any loop/switch — treat as function exit so + // the flow graph stays connected. + b.edge(b.cur, b.cfg.Blocks[b.cfg.Exit], LabelBreak) + } + b.cur = nil +} + +func (b *builder) buildContinue(n *sitter.Node, label string) { + b.ensureCur() + b.addStmt(n, "continue") + if f := b.findFrame(label, true); f != nil { + b.edge(b.cur, f.continueTo, LabelContinue) + } else { + b.edge(b.cur, b.cfg.Blocks[b.cfg.Exit], LabelContinue) + } + b.cur = nil +} + +// buildReturn handles return/raise/throw: the statement reads its +// expression and control transfers to the exit block. +func (b *builder) buildReturn(n *sitter.Node, kind string, label EdgeLabel) { + b.ensureCur() + b.addStmt(n, kind) + b.edge(b.cur, b.cfg.Blocks[b.cfg.Exit], label) + b.cur = nil +} + +// --------------------------------------------------------------------------- +// try / except / finally +// --------------------------------------------------------------------------- + +// handlerPart is one catch/except/rescue clause. headerNode carries +// the exception filter and binding; headerDefs marks the node as a +// pure binding (its identifiers are definitions, e.g. `catch (e)`). +type handlerPart struct { + headerNode *sitter.Node + headerDefs bool + bodyNode *sitter.Node +} + +// tryParts feeds buildTry. The protected body is either one block +// node or an explicit statement list (Ruby's method-level rescue). +type tryParts struct { + bodyNode *sitter.Node + bodyStmts []*sitter.Node + handlers []handlerPart + elseNode *sitter.Node + finallyNode *sitter.Node +} + +// buildTry wires the protected region: every block created while +// building the body gets an exception edge to every handler — the +// conservative may-throw model (an exception can surface at any +// point of the region, so handler entry merges the region's defs). +// The region opens with an empty marker block so the region's IN +// state also reaches the handlers (an exception can fire before the +// first protected statement completes). Within one basic block the +// model stays block-granular: a def made and re-killed inside the +// same region block is not separately visible to the handler. +func (b *builder) buildTry(p tryParts) { + b.ensureCur() + tryBlock := b.newBlock("try") + b.moveTo(tryBlock) + regionStart := tryBlock.ID + bodyBlock := b.newBlock("try_body") + b.moveTo(bodyBlock) + if p.bodyNode != nil { + b.buildStmt(p.bodyNode) + } + for _, st := range p.bodyStmts { + b.buildStmt(st) + } + tryEnd := b.cur + regionEnd := len(b.cfg.Blocks) + + after := b.newBlock("try_end") + handlerEnds := make([]*Block, 0, len(p.handlers)) + for _, h := range p.handlers { + hb := b.newBlock("handler") + for id := regionStart; id < regionEnd; id++ { + b.edge(b.cfg.Blocks[id], hb, LabelException) + } + b.cur = hb + if h.headerNode != nil { + st := b.recordNode(h.headerNode, "catch") + st.Defs, st.Uses = extractDefUse(b.spec, b.src, h.headerNode, h.headerDefs) + } + if h.bodyNode != nil { + b.buildStmt(h.bodyNode) + } + handlerEnds = append(handlerEnds, b.cur) + } + + mainEnd := tryEnd + if p.elseNode != nil && tryEnd != nil { + eb := b.newBlock("try_else") + b.edge(tryEnd, eb, LabelSeq) + b.cur = eb + b.buildStmt(p.elseNode) + mainEnd = b.cur + } + + if p.finallyNode != nil { + fb := b.newBlock("finally") + if mainEnd != nil { + b.edge(mainEnd, fb, LabelSeq) + } + for _, he := range handlerEnds { + if he != nil { + b.edge(he, fb, LabelFinally) + } + } + // An exception that matches no handler (or a handler-less + // try/finally) still runs the finalizer on its way out, so + // the protected region always feeds the finally directly. + for id := regionStart; id < regionEnd; id++ { + b.edge(b.cfg.Blocks[id], fb, LabelException) + } + b.cur = fb + b.buildStmt(p.finallyNode) + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + } else { + if mainEnd != nil { + b.edge(mainEnd, after, LabelSeq) + } + for _, he := range handlerEnds { + if he != nil { + b.edge(he, after, LabelSeq) + } + } + } + b.cur = after +} diff --git a/internal/cfg/cfg.go b/internal/cfg/cfg.go new file mode 100644 index 00000000..d2c04725 --- /dev/null +++ b/internal/cfg/cfg.go @@ -0,0 +1,279 @@ +// Package cfg builds intra-procedural control-flow graphs from a +// single function's source text, on demand, and runs a classic +// GEN/KILL reaching-definitions fixpoint over them. +// +// The package is deliberately query-time-only: nothing here runs at +// index time and nothing touches the whole graph. A caller hands in +// one function's source (typically sliced out of a file by the +// symbol's line range), names the language, and gets back: +// +// - basic blocks, each holding the ordered statements it executes +// with their line spans and per-statement def/use variable sets; +// - labeled edges between blocks (seq / true / false / loop_back / +// break / continue / return / case / exception / finally); +// - via ReachingDefinitions, statement-granular def→use chains: +// for every variable read, the set of definitions that can reach +// it along some path. +// +// Seven languages are covered by per-language control-construct +// tables: Go, Python, JavaScript, TypeScript, Java, Rust, and Ruby. +// Parsing reuses the parser package's pooled tree-sitter parsers +// (errored parsers are closed, never pooled — see parser.ParseFile). +// +// Scope model. Definitions and uses are tracked by variable NAME +// within the function: parameters are entry-block definitions, +// assignments / declarations / augmented assigns define, identifier +// reads use. Nested function literals (closures, lambdas, inner +// defs) are treated as opaque — their bodies execute at an unknown +// later time, so neither their assignments nor their reads are +// attributed to the enclosing function's statements. Reads of names +// never defined in the function (globals, package symbols) simply +// produce no def→use chain. +package cfg + +import ( + "errors" + "fmt" + "strings" + + "github.com/zzet/gortex/internal/parser" + sitter "github.com/zzet/gortex/internal/parser/tsitter" +) + +// EdgeLabel classifies a control-flow edge between two basic blocks. +type EdgeLabel string + +const ( + LabelSeq EdgeLabel = "seq" + LabelTrue EdgeLabel = "true" + LabelFalse EdgeLabel = "false" + LabelLoopBack EdgeLabel = "loop_back" + LabelBreak EdgeLabel = "break" + LabelContinue EdgeLabel = "continue" + LabelReturn EdgeLabel = "return" + LabelCase EdgeLabel = "case" + LabelException EdgeLabel = "exception" + LabelFinally EdgeLabel = "finally" +) + +// Statement is one executable statement (or synthetic pseudo- +// statement: parameter binding, branch condition, case label) inside +// a basic block. Lines are 1-based and already shifted by +// Options.LineOffset so they are file-absolute when the caller +// passes the function's position. +type Statement struct { + Index int `json:"index"` + Block int `json:"block"` + StartLine int `json:"start_line"` + EndLine int `json:"end_line"` + Text string `json:"text"` + Kind string `json:"kind,omitempty"` + Defs []string `json:"defs,omitempty"` + Uses []string `json:"uses,omitempty"` +} + +// Block is a basic block: a maximal straight-line statement sequence +// with control transfers only at the end. +type Block struct { + ID int + Label string + Stmts []*Statement +} + +// Edge is one labeled control-flow edge. +type Edge struct { + From int `json:"from"` + To int `json:"to"` + Label EdgeLabel `json:"label"` +} + +// CFG is a per-function control-flow graph. Blocks[Entry] holds the +// synthetic parameter definitions; Blocks[Exit] is the empty sink +// every return/fall-off-the-end edge targets. +type CFG struct { + FuncName string + Language string + Entry int + Exit int + Blocks []*Block + Edges []Edge + Stmts []*Statement +} + +// Options tunes Build. +type Options struct { + // LineOffset is added to every (1-based) snippet line so the CFG + // reports file-absolute lines. Pass the function node's + // StartLine-1 when src is sliced out of a larger file. + LineOffset int + // FuncName overrides the name discovered from the AST. + FuncName string +} + +// Build parses src as one function in the given language and +// constructs its control-flow graph. src must contain (or start +// with) a single function/method definition — the first function- +// like node found in the parse tree is used. Languages whose +// methods don't parse standalone (Java, JS/TS class methods) are +// retried inside a synthetic class wrapper. +func Build(src []byte, language string, opts Options) (*CFG, error) { + spec := specFor(language) + if spec == nil { + return nil, fmt.Errorf("cfg: unsupported language %q", language) + } + prepared := src + if spec.dedent { + prepared = dedent(src) + } + + tree, err := parser.ParseFile(prepared, spec.grammar()) + if err != nil { + return nil, fmt.Errorf("cfg: parse: %w", err) + } + fn := findFuncRoot(tree.RootNode(), spec) + wrapOffset := 0 + if fn == nil && spec.classWrap { + tree.Close() + wrapped := append([]byte("class __gortexcfg__ {\n"), prepared...) + wrapped = append(wrapped, []byte("\n}")...) + tree, err = parser.ParseFile(wrapped, spec.grammar()) + if err != nil { + return nil, fmt.Errorf("cfg: parse (wrapped): %w", err) + } + fn = findFuncRoot(tree.RootNode(), spec) + wrapOffset = -1 + prepared = wrapped + } + defer tree.Close() + if fn == nil { + return nil, errors.New("cfg: no function definition found in source") + } + + c := &CFG{Language: spec.name, FuncName: opts.FuncName} + if c.FuncName == "" { + if nameNode := fn.ChildByFieldName("name"); nameNode != nil { + c.FuncName = nameNode.Content(prepared) + } + } + + b := &builder{ + spec: spec, + src: prepared, + lineOffset: opts.LineOffset + wrapOffset, + cfg: c, + edgeSeen: map[edgeKey]bool{}, + } + entry := b.newBlock("entry") + exit := b.newBlock("exit") + c.Entry, c.Exit = entry.ID, exit.ID + + // Parameters become block-0 definitions: one synthetic statement + // per parameter so every chain points at its own binding site. + b.cur = entry + for _, p := range spec.params(fn, prepared) { + st := b.record(p.line, p.line, p.name, "param") + st.Defs = []string{p.name} + } + + body := spec.bodyOf(fn) + if body == nil { + // A bodiless declaration (interface method, abstract method) + // still yields a degenerate-but-valid CFG. + b.edge(entry, exit, LabelSeq) + return c, nil + } + + first := b.newBlock("body") + b.edge(entry, first, LabelSeq) + b.cur = first + b.buildStmt(body) + if b.cur != nil { + b.edge(b.cur, exit, LabelSeq) + } + return c, nil +} + +// findFuncRoot returns the first function-like node in breadth-first +// order, so the outermost definition wins when functions nest. +func findFuncRoot(root *sitter.Node, spec *langSpec) *sitter.Node { + if root == nil { + return nil + } + queue := []*sitter.Node{root} + for len(queue) > 0 { + n := queue[0] + queue = queue[1:] + if spec.funcKinds[n.Type()] { + return n + } + for i := 0; i < int(n.NamedChildCount()); i++ { + if c := n.NamedChild(i); c != nil { + queue = append(queue, c) + } + } + } + return nil +} + +// dedent strips the longest common leading whitespace prefix from +// every non-blank line. Needed for indentation-sensitive snippets +// (a Python method sliced out of a class body would otherwise parse +// as an indentation error). +func dedent(src []byte) []byte { + lines := strings.Split(string(src), "\n") + prefix := "" + first := true + for _, ln := range lines { + if strings.TrimSpace(ln) == "" { + continue + } + indent := ln[:len(ln)-len(strings.TrimLeft(ln, " \t"))] + if first { + prefix = indent + first = false + continue + } + for !strings.HasPrefix(ln, prefix) { + prefix = prefix[:len(prefix)-1] + } + if prefix == "" { + return src + } + } + if prefix == "" { + return src + } + for i, ln := range lines { + lines[i] = strings.TrimPrefix(ln, prefix) + } + return []byte(strings.Join(lines, "\n")) +} + +// StatementAt returns the statement covering the given (file- +// absolute) line that defines name, or nil. Used by the dataflow +// refinement layer to anchor graph binding nodes onto CFG +// statements. +func (c *CFG) StatementAt(line int, definedVar string) *Statement { + var best *Statement + for _, st := range c.Stmts { + if line < st.StartLine || line > st.EndLine { + continue + } + found := false + for _, d := range st.Defs { + if d == definedVar { + found = true + break + } + } + if !found { + continue + } + // Prefer the tightest span when statements nest (loop + // headers cover their condition line, etc.). + if best == nil || (st.EndLine-st.StartLine) < (best.EndLine-best.StartLine) { + best = st + } + } + return best +} diff --git a/internal/cfg/cfg_test.go b/internal/cfg/cfg_test.go new file mode 100644 index 00000000..dfc152aa --- /dev/null +++ b/internal/cfg/cfg_test.go @@ -0,0 +1,716 @@ +package cfg + +import ( + "strings" + "testing" +) + +// mustBuild builds a CFG and fails the test on error. +func mustBuild(t *testing.T, src, lang string) *CFG { + t.Helper() + c, err := Build([]byte(src), lang, Options{}) + if err != nil { + t.Fatalf("Build(%s): %v", lang, err) + } + return c +} + +// stmtByText finds the first statement whose text contains sub. +func stmtByText(t *testing.T, c *CFG, sub string) *Statement { + t.Helper() + for _, st := range c.Stmts { + if strings.Contains(st.Text, sub) { + return st + } + } + t.Fatalf("no statement containing %q; have: %v", sub, stmtTexts(c)) + return nil +} + +func stmtTexts(c *CFG) []string { + out := make([]string, len(c.Stmts)) + for i, st := range c.Stmts { + out[i] = st.Text + } + return out +} + +// hasEdge reports whether an edge with the label connects the blocks +// holding the two statements (or block IDs when from/to are ints). +func hasEdgeLabel(c *CFG, label EdgeLabel) bool { + for _, e := range c.Edges { + if e.Label == label { + return true + } + } + return false +} + +func edgeBetween(c *CFG, from, to int, label EdgeLabel) bool { + for _, e := range c.Edges { + if e.From == from && e.To == to && e.Label == label { + return true + } + } + return false +} + +// chainFor returns the def→use chain for (use statement, var). +func chainFor(t *testing.T, r *ReachingResult, stmt int, v string) UseChain { + t.Helper() + for _, ch := range r.Chains { + if ch.Stmt == stmt && ch.Var == v { + return ch + } + } + t.Fatalf("no chain for stmt=%d var=%q; chains: %+v", stmt, v, r.Chains) + return UseChain{} +} + +func hasChain(r *ReachingResult, stmt int, v string) bool { + for _, ch := range r.Chains { + if ch.Stmt == stmt && ch.Var == v { + return true + } + } + return false +} + +func containsInt(xs []int, x int) bool { + for _, v := range xs { + if v == x { + return true + } + } + return false +} + +// --------------------------------------------------------------------------- +// construction basics +// --------------------------------------------------------------------------- + +func TestBuildGoIfElseDiamond(t *testing.T) { + c := mustBuild(t, ` +func f(a int) int { + x := 1 + if a > 0 { + x = 2 + } else { + x = 3 + } + return x +} +`, "go") + + if c.FuncName != "f" { + t.Errorf("FuncName = %q, want f", c.FuncName) + } + cond := stmtByText(t, c, "a > 0") + if cond.Kind != "cond" { + t.Errorf("condition kind = %q, want cond", cond.Kind) + } + // The diamond: cond block branches true and false, both sides + // rejoin before the return. + trueTo, falseTo := -1, -1 + for _, e := range c.Edges { + if e.From == cond.Block && e.Label == LabelTrue { + trueTo = e.To + } + if e.From == cond.Block && e.Label == LabelFalse { + falseTo = e.To + } + } + if trueTo < 0 || falseTo < 0 { + t.Fatalf("missing branch edges from cond block %d: %+v", cond.Block, c.Edges) + } + thenStmt := stmtByText(t, c, "x = 2") + elseStmt := stmtByText(t, c, "x = 3") + if thenStmt.Block != trueTo { + t.Errorf("then statement in block %d, want %d", thenStmt.Block, trueTo) + } + if elseStmt.Block != falseTo { + t.Errorf("else statement in block %d, want %d", elseStmt.Block, falseTo) + } + ret := stmtByText(t, c, "return x") + if ret.Block == thenStmt.Block || ret.Block == elseStmt.Block { + t.Errorf("return must live in the join block, not a branch arm") + } + if !edgeBetween(c, ret.Block, c.Exit, LabelReturn) { + t.Errorf("missing return edge to exit") + } +} + +func TestBuildGoLoopBreakContinue(t *testing.T) { + c := mustBuild(t, ` +func f(n int) int { + s := 0 + for i := 0; i < n; i++ { + if i == 3 { + continue + } + if i == 7 { + break + } + s += i + } + return s +} +`, "go") + + for _, want := range []EdgeLabel{LabelLoopBack, LabelBreak, LabelContinue, LabelTrue, LabelFalse} { + if !hasEdgeLabel(c, want) { + t.Errorf("missing %s edge; edges: %+v", want, c.Edges) + } + } + // continue must target the update block (i++ still runs), not + // skip it. + contStmt := stmtByText(t, c, "continue") + upd := stmtByText(t, c, "i++") + if !edgeBetween(c, contStmt.Block, upd.Block, LabelContinue) { + t.Errorf("continue should target the loop update block %d; edges: %+v", upd.Block, c.Edges) + } + // The update's defs: i (and a use of i). + if len(upd.Defs) != 1 || upd.Defs[0] != "i" { + t.Errorf("update defs = %v, want [i]", upd.Defs) + } + if len(upd.Uses) != 1 || upd.Uses[0] != "i" { + t.Errorf("update uses = %v, want [i]", upd.Uses) + } +} + +func TestBuildGoSwitchFallthrough(t *testing.T) { + c := mustBuild(t, ` +func f(x int) int { + y := 0 + switch x { + case 1: + y = 1 + fallthrough + case 2: + y = 2 + default: + y = 3 + } + return y +} +`, "go") + + if !hasEdgeLabel(c, LabelCase) { + t.Fatalf("missing case edges") + } + // fallthrough: the block holding `y = 1` flows into the block + // holding `y = 2` sequentially. + s1 := stmtByText(t, c, "y = 1") + s2 := stmtByText(t, c, "y = 2") + if !edgeBetween(c, s1.Block, s2.Block, LabelSeq) { + t.Errorf("missing fallthrough seq edge %d->%d; edges: %+v", s1.Block, s2.Block, c.Edges) + } + // With a default case there must be no unmatched-subject edge. + cond := stmtByText(t, c, "x") + if cond.Kind != "cond" { + cond = c.Stmts[1] + } + for _, e := range c.Edges { + if e.From == cond.Block && e.Label == LabelFalse { + t.Errorf("switch with default must not emit a false edge") + } + } +} + +func TestBuildGoLabeledBreak(t *testing.T) { + c := mustBuild(t, ` +func f() int { + s := 0 +outer: + for i := 0; i < 3; i++ { + for j := 0; j < 3; j++ { + if j == 2 { + break outer + } + s++ + } + } + return s +} +`, "go") + + br := stmtByText(t, c, "break outer") + // The labeled break must exit the OUTER loop: its target block + // must be the block holding `return s` (outer loop_end flows + // there) — concretely, the break edge must not target the inner + // loop's end. + var breakTo = -1 + for _, e := range c.Edges { + if e.From == br.Block && e.Label == LabelBreak { + breakTo = e.To + } + } + if breakTo < 0 { + t.Fatalf("no break edge from %d", br.Block) + } + ret := stmtByText(t, c, "return s") + // outer loop_end may be empty and flow to the return's block; + // accept either the return block itself or a block that reaches + // it via one seq hop. + ok := breakTo == ret.Block || edgeBetween(c, breakTo, ret.Block, LabelSeq) + if !ok { + t.Errorf("labeled break targets block %d, expected the outer loop end (return block %d)", breakTo, ret.Block) + } +} + +func TestBuildGoDeferNoted(t *testing.T) { + c := mustBuild(t, ` +func f() { + defer cleanup() + work() +} +`, "go") + d := stmtByText(t, c, "defer cleanup()") + if d.Kind != "defer" { + t.Errorf("defer kind = %q, want defer", d.Kind) + } + w := stmtByText(t, c, "work()") + if w.Block != d.Block { + t.Errorf("defer must not split the basic block: defer in %d, work in %d", d.Block, w.Block) + } +} + +func TestBuildGoInfiniteLoop(t *testing.T) { + c := mustBuild(t, ` +func f() { + for { + if done() { + break + } + } +} +`, "go") + if !hasEdgeLabel(c, LabelLoopBack) || !hasEdgeLabel(c, LabelBreak) { + t.Fatalf("infinite loop needs loop_back and break edges: %+v", c.Edges) + } +} + +// blockLabelByID returns a block's label. +func blockLabelByID(c *CFG, id int) string { return c.Blocks[id].Label } + +// breakEdgeTarget returns the block a break statement's break edge +// targets, or -1. +func breakEdgeTarget(c *CFG, br *Statement) int { + for _, e := range c.Edges { + if e.From == br.Block && e.Label == LabelBreak { + return e.To + } + } + return -1 +} + +// A `break` inside a Go select must exit the select, not the +// enclosing loop. With the fix the break targets the switch_end +// block; without it the break leaked to the loop / function exit and +// the statement after the select became unreachable. +func TestBuildGoSelectBreakExitsSelect(t *testing.T) { + c := mustBuild(t, `func f(ch chan int) int { + s := 0 + for { + select { + case v := <-ch: + s += v + break + } + s++ + } + return s +}`, "go") + br := stmtByText(t, c, "break") + to := breakEdgeTarget(c, br) + if to < 0 { + t.Fatalf("no break edge from block %d; edges: %+v", br.Block, c.Edges) + } + if got := blockLabelByID(c, to); got != "switch_end" { + t.Errorf("select break targets %q block %d, want the switch_end; edges: %+v", got, to, c.Edges) + } + // `s++` after the select must stay reachable from the select's + // merge point — the loop-carried `s` must reach the final use. + r := c.ReachingDefinitions() + inc := stmtByText(t, c, "s++") + chS := chainFor(t, r, inc.Index, "s") + if len(chS.Defs) == 0 { + t.Errorf("s++ must see a prior def of s: %v", chS.Defs) + } +} + +// A labeled Go switch: `break L` must resolve to the switch end so +// the post-switch use sees the in-case definition. Before the fix the +// labeled break leaked to the function exit, so `s = 1` never reached +// `return s`. +func TestBuildGoLabeledSwitchBreak(t *testing.T) { + c := mustBuild(t, `func f(x int) int { + s := 0 + L: + switch x { + case 1: + s = 1 + break L + } + return s +}`, "go") + br := stmtByText(t, c, "break L") + to := breakEdgeTarget(c, br) + if to < 0 { + t.Fatalf("no break edge from block %d", br.Block) + } + if got := blockLabelByID(c, to); got != "switch_end" { + t.Errorf("labeled switch break targets %q, want switch_end; edges: %+v", got, c.Edges) + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return s") + ch := chainFor(t, r, ret.Index, "s") + d1 := stmtByText(t, c, "s = 1") + if !containsInt(ch.Defs, d1.Index) { + t.Errorf("in-case def `s = 1` must reach the return via the labeled break: %v", ch.Defs) + } +} + +// A Go type switch `switch v := x.(type)` must define the alias v; +// every use of v inside the cases must chain back to it. +func TestBuildGoTypeSwitchAliasDefines(t *testing.T) { + c := mustBuild(t, `func f(x interface{}) int { + switch v := x.(type) { + case int: + return v + default: + return 0 + } +}`, "go") + r := c.ReachingDefinitions() + use := stmtByText(t, c, "return v") + ch := chainFor(t, r, use.Index, "v") + if len(ch.Defs) == 0 { + t.Fatalf("type-switch alias v must produce a chain at its use: %v", ch.Defs) + } + def := c.Stmts[ch.Defs[0]] + if def.Kind != "cond" { + t.Errorf("v's def should be the type-switch cond statement, got kind %q (%q)", def.Kind, def.Text) + } + foundV := false + for _, d := range def.Defs { + if d == "v" { + foundV = true + } + } + if !foundV { + t.Errorf("type-switch cond must define v: %v", def.Defs) + } + // The switched expression x is a read, not a binding. + if containsString(def.Defs, "x") { + t.Errorf("the switched value x must not be a definition: %v", def.Defs) + } +} + +func containsString(xs []string, x string) bool { + for _, v := range xs { + if v == x { + return true + } + } + return false +} + +// --------------------------------------------------------------------------- +// reaching definitions — textbook shapes +// --------------------------------------------------------------------------- + +// Redefinition kills: the second assignment must be the only def +// reaching the final use. +func TestReachingRedefinitionKills(t *testing.T) { + c := mustBuild(t, ` +func f() int { + x := 1 + x = 2 + return x +} +`, "go") + r := c.ReachingDefinitions() + def1 := stmtByText(t, c, "x := 1") + def2 := stmtByText(t, c, "x = 2") + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + if containsInt(ch.Defs, def1.Index) { + t.Errorf("killed definition %d still reaches the use: %v", def1.Index, ch.Defs) + } + if !containsInt(ch.Defs, def2.Index) { + t.Errorf("live definition %d does not reach the use: %v", def2.Index, ch.Defs) + } +} + +// Branch merge unions: both arm definitions reach the post-join use. +func TestReachingBranchMergeUnion(t *testing.T) { + c := mustBuild(t, ` +func f(a bool) int { + x := 0 + if a { + x = 1 + } else { + x = 2 + } + return x +} +`, "go") + r := c.ReachingDefinitions() + d1 := stmtByText(t, c, "x = 1") + d2 := stmtByText(t, c, "x = 2") + d0 := stmtByText(t, c, "x := 0") + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + if !containsInt(ch.Defs, d1.Index) || !containsInt(ch.Defs, d2.Index) { + t.Errorf("branch-arm defs must union at the join: %v (want %d and %d)", ch.Defs, d1.Index, d2.Index) + } + if containsInt(ch.Defs, d0.Index) { + t.Errorf("pre-branch def %d is killed on every path and must not reach: %v", d0.Index, ch.Defs) + } +} + +// One-armed if: the initial def survives the merge alongside the arm +// def. +func TestReachingOneArmedIfKeepsBoth(t *testing.T) { + c := mustBuild(t, ` +func f(a bool) int { + x := 0 + if a { + x = 1 + } + return x +} +`, "go") + r := c.ReachingDefinitions() + d0 := stmtByText(t, c, "x := 0") + d1 := stmtByText(t, c, "x = 1") + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + if !containsInt(ch.Defs, d0.Index) || !containsInt(ch.Defs, d1.Index) { + t.Errorf("one-armed if must keep both defs at the join: %v", ch.Defs) + } +} + +// Loop-carried defs: a def at the loop bottom reaches the use at the +// loop top on the next iteration. +func TestReachingLoopCarried(t *testing.T) { + c := mustBuild(t, ` +func f(n int) int { + s := 0 + for i := 0; i < n; i++ { + s = s + i + } + return s +} +`, "go") + r := c.ReachingDefinitions() + d0 := stmtByText(t, c, "s := 0") + dLoop := stmtByText(t, c, "s = s + i") + // The loop body's use of s sees both the init def and its own + // previous-iteration def. + ch := chainFor(t, r, dLoop.Index, "s") + if !containsInt(ch.Defs, d0.Index) { + t.Errorf("init def must reach the first iteration: %v", ch.Defs) + } + if !containsInt(ch.Defs, dLoop.Index) { + t.Errorf("loop-carried def must reach the next iteration: %v", ch.Defs) + } + // The condition's use of i sees the init AND the increment. + cond := stmtByText(t, c, "i < n") + chI := chainFor(t, r, cond.Index, "i") + init := stmtByText(t, c, "i := 0") + inc := stmtByText(t, c, "i++") + if !containsInt(chI.Defs, init.Index) || !containsInt(chI.Defs, inc.Index) { + t.Errorf("loop condition must see init and increment defs of i: %v", chI.Defs) + } + // And the final use of s unions init + loop defs. + ret := stmtByText(t, c, "return s") + chRet := chainFor(t, r, ret.Index, "s") + if !containsInt(chRet.Defs, d0.Index) || !containsInt(chRet.Defs, dLoop.Index) { + t.Errorf("post-loop use must union zero-trip and loop defs: %v", chRet.Defs) + } +} + +// Parameters are block-0 definitions reaching every unshadowed use. +func TestReachingParamsReachUses(t *testing.T) { + c := mustBuild(t, ` +func f(a int, b int) int { + x := a + b + return x +} +`, "go") + r := c.ReachingDefinitions() + assign := stmtByText(t, c, "x := a + b") + chA := chainFor(t, r, assign.Index, "a") + if len(chA.Defs) != 1 { + t.Fatalf("param a should have exactly one def: %v", chA.Defs) + } + paramStmt := c.Stmts[chA.Defs[0]] + if paramStmt.Kind != "param" { + t.Errorf("a's def should be the synthetic param statement, got kind %q", paramStmt.Kind) + } + if paramStmt.Block != c.Entry { + t.Errorf("param defs live in the entry block %d, got %d", c.Entry, paramStmt.Block) + } +} + +// A use with no defining statement (package global) produces no chain. +func TestReachingGlobalsHaveNoChain(t *testing.T) { + c := mustBuild(t, ` +func f() int { + return globalCounter +} +`, "go") + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return globalCounter") + if hasChain(r, ret.Index, "globalCounter") { + t.Errorf("global reads must not produce chains") + } +} + +// Early return prunes: a def after the return in the same branch +// can't reach uses before it. +func TestReachingEarlyReturnIsolation(t *testing.T) { + c := mustBuild(t, ` +func f(a bool) int { + x := 1 + if a { + return x + } + x = 2 + return x +} +`, "go") + r := c.ReachingDefinitions() + d1 := stmtByText(t, c, "x := 1") + d2 := stmtByText(t, c, "x = 2") + stmts := []*Statement{} + for _, st := range c.Stmts { + if strings.Contains(st.Text, "return x") { + stmts = append(stmts, st) + } + } + if len(stmts) != 2 { + t.Fatalf("want two return statements, got %d", len(stmts)) + } + early, late := stmts[0], stmts[1] + chEarly := chainFor(t, r, early.Index, "x") + if containsInt(chEarly.Defs, d2.Index) { + t.Errorf("def after the early return must not reach it: %v", chEarly.Defs) + } + if !containsInt(chEarly.Defs, d1.Index) { + t.Errorf("initial def must reach the early return: %v", chEarly.Defs) + } + chLate := chainFor(t, r, late.Index, "x") + if !containsInt(chLate.Defs, d2.Index) || containsInt(chLate.Defs, d1.Index) { + t.Errorf("late return sees only the redefinition: %v", chLate.Defs) + } +} + +// Statement granularity: two uses of the same variable in different +// statements get independent chains. +func TestReachingStatementGranularity(t *testing.T) { + c := mustBuild(t, ` +func f() int { + x := 1 + y := x + x = 2 + z := x + return y + z +} +`, "go") + r := c.ReachingDefinitions() + d1 := stmtByText(t, c, "x := 1") + d2 := stmtByText(t, c, "x = 2") + useY := stmtByText(t, c, "y := x") + useZ := stmtByText(t, c, "z := x") + chY := chainFor(t, r, useY.Index, "x") + chZ := chainFor(t, r, useZ.Index, "x") + if !containsInt(chY.Defs, d1.Index) || containsInt(chY.Defs, d2.Index) { + t.Errorf("y := x must see only the first def: %v", chY.Defs) + } + if !containsInt(chZ.Defs, d2.Index) || containsInt(chZ.Defs, d1.Index) { + t.Errorf("z := x must see only the second def: %v", chZ.Defs) + } +} + +// Closures are opaque: assignments inside a func literal don't +// define in the enclosing frame. +func TestClosureOpaque(t *testing.T) { + c := mustBuild(t, ` +func f() int { + x := 1 + g := func() { + x = 99 + } + g() + return x +} +`, "go") + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + d1 := stmtByText(t, c, "x := 1") + if len(ch.Defs) != 1 || ch.Defs[0] != d1.Index { + t.Errorf("closure write must not count as an enclosing-scope def: %v", ch.Defs) + } +} + +// --------------------------------------------------------------------------- +// options / rendering +// --------------------------------------------------------------------------- + +func TestLineOffsetShiftsLines(t *testing.T) { + src := `func f() int { + x := 1 + return x +}` + c, err := Build([]byte(src), "go", Options{LineOffset: 99}) + if err != nil { + t.Fatal(err) + } + st := stmtByText(t, c, "x := 1") + if st.StartLine != 101 { + t.Errorf("StartLine = %d, want 101 (snippet line 2 + offset 99)", st.StartLine) + } +} + +func TestMermaidRendering(t *testing.T) { + c := mustBuild(t, ` +func f(a int) int { + if a > 0 { + return 1 + } + return 2 +} +`, "go") + m := c.Mermaid() + if !strings.HasPrefix(m, "flowchart TD") { + t.Errorf("mermaid must start with flowchart TD: %q", m) + } + if !strings.Contains(m, "-->|true|") || !strings.Contains(m, "-->|false|") { + t.Errorf("mermaid must label branch edges: %s", m) + } + if !strings.Contains(m, "entry") || !strings.Contains(m, "exit") { + t.Errorf("mermaid must render entry/exit: %s", m) + } +} + +func TestUnsupportedLanguage(t *testing.T) { + if _, err := Build([]byte("x"), "cobol", Options{}); err == nil { + t.Fatal("expected error for unsupported language") + } + if SupportedLanguage("cobol") { + t.Fatal("cobol must not be supported") + } + if !SupportedLanguage("go") || !SupportedLanguage("ruby") { + t.Fatal("go and ruby must be supported") + } +} + +func TestNoFunctionInSource(t *testing.T) { + if _, err := Build([]byte("var x = 1"), "go", Options{}); err == nil { + t.Fatal("expected error when source holds no function") + } +} diff --git a/internal/cfg/defuse.go b/internal/cfg/defuse.go new file mode 100644 index 00000000..06078441 --- /dev/null +++ b/internal/cfg/defuse.go @@ -0,0 +1,224 @@ +package cfg + +import ( + sitter "github.com/zzet/gortex/internal/parser/tsitter" +) + +// extractDefUse walks one statement's subtree and returns the +// variable names it defines and reads, in source order, deduplicated. +// asDef seeds the walk in definition context (used for binding-only +// nodes: loop variables, catch parameters, match patterns). +// +// Nested function literals are opaque: their bodies run at an +// unknown later time, so neither their writes nor their reads are +// attributed to this statement. A nested definition that binds a +// name in the enclosing scope (Python `def inner`, JS `function g`) +// still defines that name. +func extractDefUse(spec *langSpec, src []byte, n *sitter.Node, asDef bool) (defs, uses []string) { + x := &duExtractor{spec: spec, src: src, seenDef: map[string]bool{}, seenUse: map[string]bool{}} + x.walk(n, asDef) + return x.defs, x.uses +} + +type duExtractor struct { + spec *langSpec + src []byte + defs []string + uses []string + seenDef map[string]bool + seenUse map[string]bool +} + +func (x *duExtractor) addDef(name string) { + if name == "" || name == "_" || x.seenDef[name] { + return + } + x.seenDef[name] = true + x.defs = append(x.defs, name) +} + +func (x *duExtractor) addUse(name string) { + if name == "" || name == "_" || x.seenUse[name] { + return + } + x.seenUse[name] = true + x.uses = append(x.uses, name) +} + +// walk visits n in read context (asDef=false) or binding context. +func (x *duExtractor) walk(n *sitter.Node, asDef bool) { + if n == nil { + return + } + t := n.Type() + + if x.spec.nestedFuncs[t] { + // A named nested definition binds its name in this scope. + if nameNode := n.ChildByFieldName("name"); nameNode != nil && x.spec.identKinds[nameNode.Type()] { + x.addDef(nameNode.Content(x.src)) + } + return + } + if x.spec.skipKinds[t] { + return + } + if x.spec.identKinds[t] { + if asDef { + x.addDef(n.Content(x.src)) + } else { + x.addUse(n.Content(x.src)) + } + return + } + if rule, ok := x.spec.assigns[t]; ok { + x.handleAssign(n, rule) + return + } + if rule, ok := x.spec.updates[t]; ok { + x.handleUpdate(n, rule) + return + } + x.walkChildren(n, asDef) +} + +func (x *duExtractor) walkChildren(n *sitter.Node, asDef bool) { + t := n.Type() + skips := x.spec.skipFields[t] + cnt := int(n.ChildCount()) + for i := 0; i < cnt; i++ { + c := n.Child(i) + if c == nil || !c.IsNamed() { + continue + } + if skips != nil { + if f := n.FieldNameForChild(i); f != "" && skips[f] { + continue + } + } + x.walk(c, asDef) + } +} + +// handleAssign processes an assignment-shaped node: the LHS field +// holds binding targets, every other child is read. +func (x *duExtractor) handleAssign(n *sitter.Node, rule assignRule) { + alsoUse := false + switch rule.mode { + case augAlways: + alsoUse = true + case augIfOp: + alsoUse = hasAugmentedOperator(n) + } + skips := x.spec.skipFields[n.Type()] + cnt := int(n.ChildCount()) + for i := 0; i < cnt; i++ { + c := n.Child(i) + if c == nil || !c.IsNamed() { + continue + } + f := n.FieldNameForChild(i) + if skips != nil && f != "" && skips[f] { + continue + } + if f == rule.lhsField { + x.walkLHS(c, alsoUse) + continue + } + // Type annotations and other non-value fields are pruned by + // skipKinds inside the recursive walk. + x.walk(c, false) + } +} + +// handleUpdate processes increment/decrement-shaped nodes whose +// target is both read and written. +func (x *duExtractor) handleUpdate(n *sitter.Node, rule updateRule) { + var target *sitter.Node + if rule.field != "" { + target = n.ChildByFieldName(rule.field) + } + if target == nil { + target = n.NamedChild(0) + } + if target != nil { + x.walkLHS(target, true) + } +} + +// walkLHS classifies an assignment target: a bare identifier is a +// definition (plus a use for augmented assigns); pattern containers +// recurse; anything else (member access, index expression) reads its +// base — writing x.f or x[i] mutates the object, not the binding. +func (x *duExtractor) walkLHS(n *sitter.Node, alsoUse bool) { + if n == nil { + return + } + t := n.Type() + if x.spec.identKinds[t] { + name := n.Content(x.src) + x.addDef(name) + if alsoUse { + x.addUse(name) + } + return + } + if x.spec.patternContainers[t] { + skips := x.spec.skipFields[t] + cnt := int(n.ChildCount()) + for i := 0; i < cnt; i++ { + c := n.Child(i) + if c == nil || !c.IsNamed() { + continue + } + if skips != nil { + if f := n.FieldNameForChild(i); f != "" && skips[f] { + continue + } + } + x.walkLHS(c, alsoUse) + // Default values inside destructuring patterns are + // handled by the assignment_pattern rule during the + // recursive walk; nothing extra needed here. + } + return + } + if rule, ok := x.spec.assigns[t]; ok && t == "assignment_pattern" { + // Destructuring default: `{a = 1}` — a is a def, 1 is read. + if l := n.ChildByFieldName(rule.lhsField); l != nil { + x.walkLHS(l, alsoUse) + } + if r := n.ChildByFieldName("right"); r != nil { + x.walk(r, false) + } + return + } + // Non-binding target: reads flow normally. + x.walk(n, false) +} + +// hasAugmentedOperator reports whether an assignment node carries a +// compound operator token (+=, -=, &&=, …) rather than plain "=". +func hasAugmentedOperator(n *sitter.Node) bool { + if op := n.ChildByFieldName("operator"); op != nil { + return op.Type() != "=" + } + cnt := int(n.ChildCount()) + for i := 0; i < cnt; i++ { + c := n.Child(i) + if c == nil || c.IsNamed() { + continue + } + t := c.Type() + if t == "=" { + return false + } + if len(t) >= 2 && t[len(t)-1] == '=' { + switch t { + case "==", "!=", "<=", ">=": + continue + } + return true + } + } + return false +} diff --git a/internal/cfg/lang.go b/internal/cfg/lang.go new file mode 100644 index 00000000..53925ea4 --- /dev/null +++ b/internal/cfg/lang.go @@ -0,0 +1,1287 @@ +package cfg + +import ( + "strings" + + sitter "github.com/zzet/gortex/internal/parser/tsitter" + golang "github.com/zzet/gortex/internal/parser/tsitter/golang" + javalang "github.com/zzet/gortex/internal/parser/tsitter/java" + jslang "github.com/zzet/gortex/internal/parser/tsitter/javascript" + pylang "github.com/zzet/gortex/internal/parser/tsitter/python" + rubylang "github.com/zzet/gortex/internal/parser/tsitter/ruby" + rustlang "github.com/zzet/gortex/internal/parser/tsitter/rust" + tsxlang "github.com/zzet/gortex/internal/parser/tsitter/tsx" + tslang "github.com/zzet/gortex/internal/parser/tsitter/typescript" +) + +// assignMode says whether an assignment-shaped construct also reads +// its targets before writing them. +type assignMode int + +const ( + augNever assignMode = iota // plain assignment / declaration + augAlways // augmented assign (x += 1): def + use + augIfOp // augmented iff the operator token isn't bare "=" +) + +// assignRule describes one assignment-shaped node kind: which field +// holds the write targets and whether the targets are also read. +type assignRule struct { + lhsField string + mode assignMode +} + +// updateRule describes increment/decrement-shaped nodes whose single +// target is both read and written. An empty field means "first named +// child". +type updateRule struct { + field string +} + +// param is one parameter binding discovered in a function header. +type param struct { + name string + line int +} + +// langSpec is the per-language table driving CFG construction and +// def/use extraction. The shared builder owns all block/edge +// mechanics; the spec only names the AST shapes. +type langSpec struct { + name string + grammar func() *sitter.Language + classWrap bool // retry parse inside `class __gortexcfg__ { … }` + dedent bool // strip common indentation before parsing + + funcKinds map[string]bool + // dispatch consumes control constructs; returning false makes the + // builder record the node as a leaf statement. + dispatch func(b *builder, n *sitter.Node) bool + + identKinds map[string]bool + assigns map[string]assignRule + updates map[string]updateRule + skipFields map[string]map[string]bool + skipKinds map[string]bool + nestedFuncs map[string]bool + patternContainers map[string]bool + paramSkipFields map[string]bool + paramSkipKinds map[string]bool +} + +// bodyOf locates a function node's body, falling back from the +// `body` field to the language's known body node kind (Ruby methods +// carry an unfielded body_statement in some grammar shapes). +func (s *langSpec) bodyOf(fn *sitter.Node) *sitter.Node { + if b := fn.ChildByFieldName("body"); b != nil { + return b + } + if s.name == "ruby" { + return childOfType(fn, "body_statement") + } + return nil +} + +// params collects the function's parameter bindings (including the +// Go receiver) as block-0 definitions. +func (s *langSpec) params(fn *sitter.Node, src []byte) []param { + var out []param + collect := func(container *sitter.Node) { + if container == nil { + return + } + var walk func(n *sitter.Node) + walk = func(n *sitter.Node) { + if s.identKinds[n.Type()] { + name := n.Content(src) + if name != "" && name != "_" { + out = append(out, param{name: name, line: int(n.StartPoint().Row) + 1}) + } + return + } + if s.paramSkipKinds[n.Type()] { + return + } + cnt := int(n.ChildCount()) + for i := 0; i < cnt; i++ { + c := n.Child(i) + if c == nil || !c.IsNamed() { + continue + } + if f := n.FieldNameForChild(i); f != "" && s.paramSkipFields[f] { + continue + } + walk(c) + } + } + walk(container) + } + if s.name == "go" { + collect(fn.ChildByFieldName("receiver")) + } + if p := fn.ChildByFieldName("parameters"); p != nil { + collect(p) + } else if p := fn.ChildByFieldName("parameter"); p != nil { + // JS arrow functions with a single unparenthesized parameter. + collect(p) + } + return out +} + +// specFor maps a graph language label to its spec. Returns nil for +// languages without a control-construct table. +func specFor(language string) *langSpec { + switch strings.ToLower(language) { + case "go", "golang": + return goSpec + case "python", "py": + return pySpec + case "javascript", "js", "jsx": + return jsSpec + case "typescript", "ts": + return tsSpec + case "tsx": + return tsxSpec + case "java": + return javaSpec + case "rust", "rs": + return rustSpec + case "ruby", "rb": + return rubySpec + } + return nil +} + +// SupportedLanguage reports whether Build can construct a CFG for +// the given graph language label. +func SupportedLanguage(language string) bool { return specFor(language) != nil } + +// childOfType returns the first named child with the given type. +func childOfType(n *sitter.Node, kind string) *sitter.Node { + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c != nil && c.Type() == kind { + return c + } + } + return nil +} + +// childrenByField returns every child stored under the given field +// name, in order. tree-sitter allows repeated fields (Python's +// if_statement stacks elif/else clauses under `alternative`). +func childrenByField(n *sitter.Node, field string) []*sitter.Node { + var out []*sitter.Node + cnt := int(n.ChildCount()) + for i := 0; i < cnt; i++ { + if n.FieldNameForChild(i) == field { + if c := n.Child(i); c != nil { + out = append(out, c) + } + } + } + return out +} + +// --------------------------------------------------------------------------- +// Go +// --------------------------------------------------------------------------- + +var goSpec = &langSpec{ + name: "go", + grammar: golang.GetLanguage, + funcKinds: map[string]bool{ + "function_declaration": true, "method_declaration": true, "func_literal": true, + }, + identKinds: map[string]bool{"identifier": true}, + assigns: map[string]assignRule{ + "short_var_declaration": {lhsField: "left", mode: augNever}, + "assignment_statement": {lhsField: "left", mode: augIfOp}, + "var_spec": {lhsField: "name", mode: augNever}, + "const_spec": {lhsField: "name", mode: augNever}, + "range_clause": {lhsField: "left", mode: augNever}, + "receive_statement": {lhsField: "left", mode: augNever}, + }, + updates: map[string]updateRule{ + "inc_statement": {}, "dec_statement": {}, + }, + skipFields: map[string]map[string]bool{ + "selector_expression": {"field": true}, + "keyed_element": {"key": true}, + }, + skipKinds: map[string]bool{ + "field_identifier": true, "type_identifier": true, "package_identifier": true, + "label_name": true, + }, + nestedFuncs: map[string]bool{"func_literal": true}, + patternContainers: map[string]bool{"expression_list": true}, + paramSkipFields: map[string]bool{"type": true}, + paramSkipKinds: map[string]bool{"type_identifier": true, "qualified_type": true}, + dispatch: goDispatch, +} + +func goDispatch(b *builder, n *sitter.Node) bool { + switch n.Type() { + case "block", "statement_list": + b.buildSeq(n) + case "if_statement": + b.buildIf(n.ChildByFieldName("initializer"), n.ChildByFieldName("condition"), + n.ChildByFieldName("consequence"), n.ChildByFieldName("alternative")) + case "for_statement": + body := n.ChildByFieldName("body") + if fc := childOfType(n, "for_clause"); fc != nil { + b.buildLoop(loopParts{ + init: fc.ChildByFieldName("initializer"), + cond: fc.ChildByFieldName("condition"), + update: fc.ChildByFieldName("update"), + body: body, + }) + return true + } + if rc := childOfType(n, "range_clause"); rc != nil { + b.buildLoop(loopParts{headerStmt: rc, body: body}) + return true + } + // `for cond { … }` — the condition is the lone named child + // that isn't the body; bare `for { … }` has none. + var cond *sitter.Node + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c != nil && !c.Equal(body) && c.Type() != "comment" { + cond = c + break + } + } + b.buildLoop(loopParts{cond: cond, body: body, infinite: cond == nil}) + case "expression_switch_statement", "type_switch_statement", "select_statement": + b.buildGoSwitch(n) + case "labeled_statement": + if lbl := childOfType(n, "label_name"); lbl != nil { + b.pendingLabel = lbl.Content(b.src) + } + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c != nil && c.Type() != "label_name" { + b.buildStmt(c) + } + } + b.pendingLabel = "" + case "break_statement": + b.buildBreak(n, goJumpLabel(n, b.src)) + case "continue_statement": + b.buildContinue(n, goJumpLabel(n, b.src)) + case "return_statement": + b.buildReturn(n, "return", LabelReturn) + case "fallthrough_statement": + b.leaf(n, "fallthrough") + b.pendingFallthrough = true + case "defer_statement": + // Defers run at function exit; they do not alter intra-block + // flow, so the statement is recorded in place and flagged. + b.leaf(n, "defer") + case "go_statement": + b.leaf(n, "go") + default: + return false + } + return true +} + +// goJumpLabel extracts the optional label off a break/continue. +func goJumpLabel(n *sitter.Node, src []byte) string { + if lbl := childOfType(n, "label_name"); lbl != nil { + return lbl.Content(src) + } + return "" +} + +// buildGoSwitch covers expression/type switches and select: every +// case is dispatched from the head, cases do not fall through unless +// an explicit fallthrough statement was seen. +func (b *builder) buildGoSwitch(n *sitter.Node) { + b.ensureCur() + if init := n.ChildByFieldName("initializer"); init != nil { + b.buildStmt(init) + } + if alias := n.ChildByFieldName("alias"); alias != nil { + // Type switch `switch v := x.(type)`: the alias binds v as a + // definition, the switched value field is a read. (The grammar + // always exposes `value` for the subject, so the binding lives + // in the alias field, not value.) + st := b.recordNode(alias, "cond") + st.Defs, _ = extractDefUse(b.spec, b.src, alias, true) + if v := n.ChildByFieldName("value"); v != nil { + _, st.Uses = extractDefUse(b.spec, b.src, v, false) + } + } else if v := n.ChildByFieldName("value"); v != nil { + b.addStmt(v, "cond") + } + head := b.cur + after := b.newBlock("switch_end") + // Go `break` (labeled or bare) exits the switch/select, not the + // enclosing loop. Push a break-only frame (no continue target, so + // `continue` still skips past it to the loop) and consume any + // pending label from an enclosing labeled statement. + b.pushFrame(frame{label: b.takeLabel(), breakTo: after}) + defer b.popFrame() + hasDefault := false + var pendingFT *Block + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c == nil { + continue + } + switch c.Type() { + case "expression_case", "type_case", "communication_case", "default_case": + default: + continue + } + caseBlock := b.newBlock("case") + b.edge(head, caseBlock, LabelCase) + if pendingFT != nil { + b.edge(pendingFT, caseBlock, LabelSeq) + pendingFT = nil + } + b.cur = caseBlock + if c.Type() == "default_case" { + hasDefault = true + } else if c.Type() == "communication_case" { + if comm := c.ChildByFieldName("communication"); comm != nil { + b.addStmt(comm, "case") + } + } else if v := c.ChildByFieldName("value"); v != nil { + b.addStmt(v, "case") + } else if tn := c.ChildByFieldName("type"); tn != nil { + b.addStmt(tn, "case") + } + if sl := childOfType(c, "statement_list"); sl != nil { + b.buildSeq(sl) + } + if b.pendingFallthrough { + pendingFT = b.cur + b.pendingFallthrough = false + b.cur = nil + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + } + if !hasDefault { + b.edge(head, after, LabelFalse) + } + b.cur = after +} + +// --------------------------------------------------------------------------- +// Python +// --------------------------------------------------------------------------- + +var pySpec = &langSpec{ + name: "python", + grammar: pylang.GetLanguage, + dedent: true, + funcKinds: map[string]bool{ + "function_definition": true, + }, + identKinds: map[string]bool{"identifier": true}, + assigns: map[string]assignRule{ + "assignment": {lhsField: "left", mode: augNever}, + "augmented_assignment": {lhsField: "left", mode: augAlways}, + "named_expression": {lhsField: "name", mode: augNever}, + "as_pattern": {lhsField: "alias", mode: augNever}, + }, + updates: map[string]updateRule{}, + skipFields: map[string]map[string]bool{ + "attribute": {"attribute": true}, + "keyword_argument": {"name": true}, + }, + skipKinds: map[string]bool{}, + nestedFuncs: map[string]bool{ + "function_definition": true, "lambda": true, "class_definition": true, + }, + patternContainers: map[string]bool{ + "pattern_list": true, "tuple_pattern": true, "list_pattern": true, + "as_pattern_target": true, + }, + paramSkipFields: map[string]bool{"type": true, "value": true}, + paramSkipKinds: map[string]bool{"type": true}, + dispatch: pyDispatch, +} + +func pyDispatch(b *builder, n *sitter.Node) bool { + switch n.Type() { + case "block": + b.buildSeq(n) + case "if_statement": + b.buildIfChain(n.ChildByFieldName("condition"), n.ChildByFieldName("consequence"), + childrenByField(n, "alternative")) + case "elif_clause": + // Reached only via buildIfChain recursion fallback. + b.buildIfChain(n.ChildByFieldName("condition"), n.ChildByFieldName("consequence"), nil) + case "else_clause": + if body := n.ChildByFieldName("body"); body != nil { + b.buildStmt(body) + } else { + b.buildSeq(n) + } + case "while_statement": + b.buildLoop(loopParts{ + cond: n.ChildByFieldName("condition"), + body: n.ChildByFieldName("body"), + elseNode: n.ChildByFieldName("alternative"), + }) + case "for_statement": + b.buildLoop(loopParts{ + headerStmt: n, + headerStmtOnlyHeaderFields: true, + body: n.ChildByFieldName("body"), + elseNode: n.ChildByFieldName("alternative"), + }) + case "try_statement": + b.buildPyTry(n) + case "with_statement": + // `with` introduces bindings but no branching beyond the + // (ignored) exception path already modeled by try blocks. + if cl := childOfType(n, "with_clause"); cl != nil { + b.leaf(cl, "with") + } + if body := n.ChildByFieldName("body"); body != nil { + b.buildStmt(body) + } + case "break_statement": + b.buildBreak(n, "") + case "continue_statement": + b.buildContinue(n, "") + case "return_statement": + b.buildReturn(n, "return", LabelReturn) + case "raise_statement": + b.buildReturn(n, "throw", LabelException) + case "match_statement": + b.buildPyMatch(n) + default: + return false + } + return true +} + +// buildPyTry maps try/except/else/finally onto the generic try +// builder. +func (b *builder) buildPyTry(n *sitter.Node) { + p := tryParts{bodyNode: n.ChildByFieldName("body")} + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c == nil { + continue + } + switch c.Type() { + case "except_clause", "except_group_clause": + h := handlerPart{bodyNode: childOfType(c, "block")} + if v := c.ChildByFieldName("value"); v != nil { + h.headerNode = v + } + p.handlers = append(p.handlers, h) + case "else_clause": + p.elseNode = childOfType(c, "block") + case "finally_clause": + p.finallyNode = childOfType(c, "block") + } + } + b.buildTry(p) +} + +// buildPyMatch maps structural pattern matching onto the switch +// shape: every case is dispatched from the subject, no fallthrough. +func (b *builder) buildPyMatch(n *sitter.Node) { + b.ensureCur() + if subj := n.ChildByFieldName("subject"); subj != nil { + b.addStmt(subj, "cond") + } + head := b.cur + after := b.newBlock("match_end") + // The case clauses live inside the match's body block (stored + // under the `alternative` field), not as direct children of the + // match_statement, so descend into the body before scanning. + scope := n.ChildByFieldName("body") + if scope == nil { + scope = n + } + matchedAll := false + for i := 0; i < int(scope.NamedChildCount()); i++ { + c := scope.NamedChild(i) + if c == nil || c.Type() != "case_clause" { + continue + } + caseBlock := b.newBlock("case") + b.edge(head, caseBlock, LabelCase) + b.cur = caseBlock + // The case pattern binds names (capture patterns) — record it + // as a definition; the guard, when present, reads. + if pat := childOfType(c, "case_pattern"); pat != nil { + st := b.recordNode(pat, "case") + st.Defs, _ = extractDefUse(b.spec, b.src, pat, true) + if g := c.ChildByFieldName("guard"); g != nil { + _, st.Uses = extractDefUse(b.spec, b.src, g, false) + } + if isPyWildcardCase(pat) { + matchedAll = true + } + } + if body := c.ChildByFieldName("consequence"); body != nil { + b.buildStmt(body) + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + } + if !matchedAll { + b.edge(head, after, LabelFalse) + } + b.cur = after +} + +// isPyWildcardCase reports whether a case_pattern is the catch-all +// `case _`: the grammar emits an empty case_pattern (no children) for +// the wildcard, so the subsequent arms are unreachable. +func isPyWildcardCase(pat *sitter.Node) bool { + return pat.NamedChildCount() == 0 +} + +// --------------------------------------------------------------------------- +// JavaScript / TypeScript +// --------------------------------------------------------------------------- + +func jsLikeSpec(name string, grammar func() *sitter.Language) *langSpec { + return &langSpec{ + name: name, + grammar: grammar, + classWrap: true, + funcKinds: map[string]bool{ + "function_declaration": true, "function_expression": true, "function": true, + "generator_function_declaration": true, "generator_function": true, + "method_definition": true, "arrow_function": true, + }, + identKinds: map[string]bool{ + "identifier": true, "shorthand_property_identifier_pattern": true, + }, + assigns: map[string]assignRule{ + "variable_declarator": {lhsField: "name", mode: augNever}, + "assignment_expression": {lhsField: "left", mode: augNever}, + "augmented_assignment_expression": {lhsField: "left", mode: augAlways}, + "assignment_pattern": {lhsField: "left", mode: augNever}, + }, + updates: map[string]updateRule{ + "update_expression": {field: "argument"}, + }, + skipFields: map[string]map[string]bool{ + "member_expression": {"property": true}, + "pair": {"key": true}, + "pair_pattern": {"key": true}, + }, + skipKinds: map[string]bool{ + "property_identifier": true, "statement_identifier": true, + "type_annotation": true, "type_identifier": true, "predefined_type": true, + }, + nestedFuncs: map[string]bool{ + "function_declaration": true, "function_expression": true, "function": true, + "generator_function_declaration": true, "generator_function": true, + "method_definition": true, "arrow_function": true, "class_declaration": true, + "class": true, + }, + patternContainers: map[string]bool{ + "array_pattern": true, "object_pattern": true, "pair_pattern": true, + "rest_pattern": true, + }, + paramSkipFields: map[string]bool{"type": true, "value": true, "right": true}, + paramSkipKinds: map[string]bool{"type_annotation": true}, + dispatch: jsDispatch, + } +} + +var ( + jsSpec = jsLikeSpec("javascript", jslang.GetLanguage) + tsSpec = jsLikeSpec("typescript", tslang.GetLanguage) + tsxSpec = jsLikeSpec("tsx", tsxlang.GetLanguage) +) + +func jsDispatch(b *builder, n *sitter.Node) bool { + switch n.Type() { + case "statement_block": + b.buildSeq(n) + case "if_statement": + b.buildIf(nil, n.ChildByFieldName("condition"), + n.ChildByFieldName("consequence"), n.ChildByFieldName("alternative")) + case "else_clause": + for i := 0; i < int(n.NamedChildCount()); i++ { + if c := n.NamedChild(i); c != nil { + b.buildStmt(c) + } + } + case "while_statement": + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body")}) + case "do_statement": + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body"), postTest: true}) + case "for_statement": + b.buildLoop(loopParts{ + init: n.ChildByFieldName("initializer"), + cond: n.ChildByFieldName("condition"), + update: n.ChildByFieldName("increment"), + body: n.ChildByFieldName("body"), + }) + case "for_in_statement": + b.buildLoop(loopParts{headerStmt: n, headerStmtOnlyHeaderFields: true, body: n.ChildByFieldName("body")}) + case "switch_statement": + b.buildJsSwitch(n) + case "try_statement": + p := tryParts{bodyNode: n.ChildByFieldName("body")} + if h := n.ChildByFieldName("handler"); h != nil { + p.handlers = append(p.handlers, handlerPart{ + headerNode: h.ChildByFieldName("parameter"), + headerDefs: true, + bodyNode: h.ChildByFieldName("body"), + }) + } + if f := n.ChildByFieldName("finalizer"); f != nil { + p.finallyNode = f.ChildByFieldName("body") + } + b.buildTry(p) + case "labeled_statement": + if lbl := n.ChildByFieldName("label"); lbl != nil { + b.pendingLabel = lbl.Content(b.src) + } + if body := n.ChildByFieldName("body"); body != nil { + b.buildStmt(body) + } + b.pendingLabel = "" + case "break_statement": + b.buildBreak(n, fieldText(n, "label", b.src)) + case "continue_statement": + b.buildContinue(n, fieldText(n, "label", b.src)) + case "return_statement": + b.buildReturn(n, "return", LabelReturn) + case "throw_statement": + b.buildReturn(n, "throw", LabelException) + default: + return false + } + return true +} + +// buildJsSwitch models C-style fallthrough: consecutive cases chain +// unless a break/return terminated the previous one; break targets +// the switch end. Java arrow rules (`case 1 -> …`) never fall +// through, so they bypass the fallthrough chaining and edge straight +// to the switch end. +func (b *builder) buildJsSwitch(n *sitter.Node) { + b.ensureCur() + if v := n.ChildByFieldName("value"); v != nil { + b.addStmt(v, "cond") + } else if v := n.ChildByFieldName("condition"); v != nil { + b.addStmt(v, "cond") + } + head := b.cur + after := b.newBlock("switch_end") + b.pushFrame(frame{breakTo: after}) + hasDefault := false + var prevEnd *Block + body := n.ChildByFieldName("body") + if body == nil { + body = n + } + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + if c == nil { + continue + } + var isDefault, isArrowRule bool + switch c.Type() { + case "switch_case": + case "switch_default": + isDefault = true + case "switch_block_statement_group": + isDefault = javaGroupIsDefault(c) + case "switch_rule": + // Arrow form: `case 1 -> { … }`. No fallthrough. + isDefault = javaGroupIsDefault(c) + isArrowRule = true + default: + continue + } + caseBlock := b.newBlock("case") + b.edge(head, caseBlock, LabelCase) + if prevEnd != nil && !isArrowRule { + b.edge(prevEnd, caseBlock, LabelSeq) + } + b.cur = caseBlock + if isDefault { + hasDefault = true + } + b.buildCaseGroupBody(c) + if isArrowRule { + // Implicit break: the rule's end leaves the switch directly + // and does not chain into the next rule. + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + prevEnd = nil + continue + } + prevEnd = b.cur + } + b.popFrame() + if prevEnd != nil { + b.edge(prevEnd, after, LabelSeq) + } + if !hasDefault { + b.edge(head, after, LabelFalse) + } + b.cur = after +} + +// buildCaseGroupBody emits the case-label match values then the +// case's statements. Works for JS switch_case/switch_default (value +// field + repeated body fields) and Java switch_block_statement_group +// (switch_label children followed by statements). +func (b *builder) buildCaseGroupBody(c *sitter.Node) { + if v := c.ChildByFieldName("value"); v != nil { + b.addStmt(v, "case") + } + cnt := int(c.ChildCount()) + for i := 0; i < cnt; i++ { + ch := c.Child(i) + if ch == nil || !ch.IsNamed() { + continue + } + f := c.FieldNameForChild(i) + if f == "value" { + continue + } + if ch.Type() == "switch_label" { + if ch.NamedChildCount() > 0 { + b.addStmt(ch, "case") + } + continue + } + if f == "body" || f == "" { + b.buildStmt(ch) + } + } +} + +// javaGroupIsDefault reports whether a Java case group carries the +// `default` label (a switch_label with no children). +func javaGroupIsDefault(c *sitter.Node) bool { + for i := 0; i < int(c.NamedChildCount()); i++ { + ch := c.NamedChild(i) + if ch != nil && ch.Type() == "switch_label" && ch.NamedChildCount() == 0 { + return true + } + } + return false +} + +func fieldText(n *sitter.Node, field string, src []byte) string { + if c := n.ChildByFieldName(field); c != nil { + return c.Content(src) + } + return "" +} + +// --------------------------------------------------------------------------- +// Java +// --------------------------------------------------------------------------- + +var javaSpec = &langSpec{ + name: "java", + grammar: javalang.GetLanguage, + classWrap: true, + funcKinds: map[string]bool{ + "method_declaration": true, "constructor_declaration": true, + }, + identKinds: map[string]bool{"identifier": true}, + assigns: map[string]assignRule{ + "variable_declarator": {lhsField: "name", mode: augNever}, + "assignment_expression": {lhsField: "left", mode: augIfOp}, + "resource": {lhsField: "name", mode: augNever}, + }, + updates: map[string]updateRule{ + "update_expression": {}, + }, + skipFields: map[string]map[string]bool{ + "field_access": {"field": true}, + "method_invocation": {"name": true}, + "method_reference": {}, + }, + skipKinds: map[string]bool{ + "type_identifier": true, "integral_type": true, "floating_point_type": true, + "boolean_type": true, "void_type": true, "generic_type": true, + "annotation": true, "marker_annotation": true, "modifiers": true, + }, + nestedFuncs: map[string]bool{ + "lambda_expression": true, "class_declaration": true, "anonymous_class_body": true, + }, + patternContainers: map[string]bool{}, + paramSkipFields: map[string]bool{"type": true, "dimensions": true}, + paramSkipKinds: map[string]bool{ + "type_identifier": true, "integral_type": true, "floating_point_type": true, + "boolean_type": true, "generic_type": true, "annotation": true, + "marker_annotation": true, "modifiers": true, + }, + dispatch: javaDispatch, +} + +func javaDispatch(b *builder, n *sitter.Node) bool { + switch n.Type() { + case "block": + b.buildSeq(n) + case "if_statement": + b.buildIf(nil, n.ChildByFieldName("condition"), + n.ChildByFieldName("consequence"), n.ChildByFieldName("alternative")) + case "while_statement": + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body")}) + case "do_statement": + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body"), postTest: true}) + case "for_statement": + b.buildLoop(loopParts{ + init: n.ChildByFieldName("init"), + cond: n.ChildByFieldName("condition"), + update: n.ChildByFieldName("update"), + body: n.ChildByFieldName("body"), + }) + case "enhanced_for_statement": + b.buildLoop(loopParts{headerStmt: n, headerStmtOnlyHeaderFields: true, body: n.ChildByFieldName("body")}) + case "switch_expression", "switch_statement": + b.buildJsSwitch(n) + case "try_statement", "try_with_resources_statement": + p := tryParts{bodyNode: n.ChildByFieldName("body")} + if res := n.ChildByFieldName("resources"); res != nil { + b.ensureCur() + b.addStmt(res, "with") + } + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c == nil { + continue + } + switch c.Type() { + case "catch_clause": + h := handlerPart{bodyNode: c.ChildByFieldName("body"), headerDefs: true} + if fp := childOfType(c, "catch_formal_parameter"); fp != nil { + h.headerNode = fp + } + p.handlers = append(p.handlers, h) + case "finally_clause": + p.finallyNode = childOfType(c, "block") + } + } + b.buildTry(p) + case "labeled_statement": + // (identifier) ':' statement — no fields in this grammar. + first := n.NamedChild(0) + if first != nil && first.Type() == "identifier" { + b.pendingLabel = first.Content(b.src) + } + for i := 1; i < int(n.NamedChildCount()); i++ { + if c := n.NamedChild(i); c != nil { + b.buildStmt(c) + } + } + b.pendingLabel = "" + case "break_statement": + b.buildBreak(n, javaJumpLabel(n, b.src)) + case "continue_statement": + b.buildContinue(n, javaJumpLabel(n, b.src)) + case "return_statement": + b.buildReturn(n, "return", LabelReturn) + case "throw_statement": + b.buildReturn(n, "throw", LabelException) + case "synchronized_statement": + if l := n.ChildByFieldName("lock"); l != nil { + b.ensureCur() + b.addStmt(l, "") + } + if body := n.ChildByFieldName("body"); body != nil { + b.buildStmt(body) + } + default: + return false + } + return true +} + +func javaJumpLabel(n *sitter.Node, src []byte) string { + if c := n.NamedChild(0); c != nil && c.Type() == "identifier" { + return c.Content(src) + } + return "" +} + +// --------------------------------------------------------------------------- +// Rust +// --------------------------------------------------------------------------- + +var rustSpec = &langSpec{ + name: "rust", + grammar: rustlang.GetLanguage, + funcKinds: map[string]bool{ + "function_item": true, "closure_expression": true, + }, + identKinds: map[string]bool{"identifier": true}, + assigns: map[string]assignRule{ + "let_declaration": {lhsField: "pattern", mode: augNever}, + "assignment_expression": {lhsField: "left", mode: augNever}, + "compound_assignment_expr": {lhsField: "left", mode: augAlways}, + "let_condition": {lhsField: "pattern", mode: augNever}, + }, + updates: map[string]updateRule{}, + skipFields: map[string]map[string]bool{ + "field_expression": {"field": true}, + "tuple_struct_pattern": {"type": true}, + "struct_pattern": {"type": true}, + "field_initializer": {"field": true}, + // A match arm's guard hangs off the pattern node as the + // `condition` field; it reads variables, it doesn't bind them, + // so the binding walk must skip it. + "match_pattern": {"condition": true}, + }, + skipKinds: map[string]bool{ + "type_identifier": true, "primitive_type": true, "field_identifier": true, + "scoped_identifier": true, "scoped_type_identifier": true, "lifetime": true, + "type_arguments": true, "label": true, + }, + nestedFuncs: map[string]bool{ + "closure_expression": true, "function_item": true, + }, + patternContainers: map[string]bool{ + "tuple_pattern": true, "tuple_struct_pattern": true, "struct_pattern": true, + "slice_pattern": true, "reference_pattern": true, "mut_pattern": true, + "field_pattern": true, "or_pattern": true, "match_pattern": true, + }, + paramSkipFields: map[string]bool{"type": true}, + paramSkipKinds: map[string]bool{"type_identifier": true, "primitive_type": true, "reference_type": true}, + dispatch: rustDispatch, +} + +func rustDispatch(b *builder, n *sitter.Node) bool { + switch n.Type() { + case "block": + b.buildSeq(n) + case "expression_statement": + // Unwrap so control-flow expressions in statement position + // (if/while/match/loop) reach the cases below. + if c := n.NamedChild(0); c != nil { + b.buildStmt(c) + return true + } + return false + case "if_expression": + b.buildIf(nil, n.ChildByFieldName("condition"), + n.ChildByFieldName("consequence"), n.ChildByFieldName("alternative")) + case "else_clause": + for i := 0; i < int(n.NamedChildCount()); i++ { + if c := n.NamedChild(i); c != nil { + b.buildStmt(c) + } + } + case "while_expression": + b.pendingLabel = rustLoopOwnLabel(n, b.src) + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body")}) + case "loop_expression": + b.pendingLabel = rustLoopOwnLabel(n, b.src) + b.buildLoop(loopParts{body: n.ChildByFieldName("body"), infinite: true}) + case "for_expression": + b.pendingLabel = rustLoopOwnLabel(n, b.src) + b.buildLoop(loopParts{headerStmt: n, headerStmtOnlyHeaderFields: true, body: n.ChildByFieldName("body")}) + case "match_expression": + b.buildRustMatch(n) + case "break_expression": + b.buildBreak(n, rustLoopLabel(n, b.src)) + case "continue_expression": + b.buildContinue(n, rustLoopLabel(n, b.src)) + case "return_expression": + b.buildReturn(n, "return", LabelReturn) + default: + return false + } + return true +} + +// rustLoopLabel extracts the label off a break/continue expression +// (`break 'outer`). The vendored grammar exposes the label as a child +// of node type `label` whose first identifier is the lifetime name. +func rustLoopLabel(n *sitter.Node, src []byte) string { + return rustLabelName(childOfType(n, "label"), src) +} + +// rustLoopOwnLabel extracts the label a loop declares for itself +// (`'outer: loop { … }`). The label is a `label` child of the loop +// expression, ahead of the body. +func rustLoopOwnLabel(n *sitter.Node, src []byte) string { + return rustLabelName(childOfType(n, "label"), src) +} + +// rustLabelName normalises a `label` node to its bare name (no +// leading `'`). The label identifier is the node's first child; fall +// back to the node text for grammar shapes that inline it. +func rustLabelName(lbl *sitter.Node, src []byte) string { + if lbl == nil { + return "" + } + text := lbl.Content(src) + if id := childOfType(lbl, "identifier"); id != nil { + text = id.Content(src) + } + return strings.TrimPrefix(strings.TrimSpace(text), "'") +} + +// buildRustMatch dispatches each arm off the subject; arms never +// fall through and the match is exhaustive, so there is no +// unmatched-subject edge. +func (b *builder) buildRustMatch(n *sitter.Node) { + b.ensureCur() + if v := n.ChildByFieldName("value"); v != nil { + b.addStmt(v, "cond") + } + head := b.cur + after := b.newBlock("match_end") + body := n.ChildByFieldName("body") + if body == nil { + b.cur = after + return + } + for i := 0; i < int(body.NamedChildCount()); i++ { + arm := body.NamedChild(i) + if arm == nil || arm.Type() != "match_arm" { + continue + } + caseBlock := b.newBlock("case") + b.edge(head, caseBlock, LabelCase) + b.cur = caseBlock + if pat := arm.ChildByFieldName("pattern"); pat != nil { + // Arm patterns bind names; the guard (if any) reads. The + // guard lives on the pattern node as its `condition` field + // (skipped by the binding walk via skipFields), so the + // asDef pass over the pattern yields only the captures. + st := b.recordNode(pat, "case") + defs, _ := extractDefUse(b.spec, b.src, pat, true) + st.Defs = defs + if g := pat.ChildByFieldName("condition"); g != nil { + _, uses := extractDefUse(b.spec, b.src, g, false) + st.Uses = uses + } + } + if v := arm.ChildByFieldName("value"); v != nil { + b.buildStmt(v) + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + } + b.cur = after +} + +// --------------------------------------------------------------------------- +// Ruby +// --------------------------------------------------------------------------- + +var rubySpec = &langSpec{ + name: "ruby", + grammar: rubylang.GetLanguage, + funcKinds: map[string]bool{ + "method": true, "singleton_method": true, + }, + identKinds: map[string]bool{"identifier": true}, + assigns: map[string]assignRule{ + "assignment": {lhsField: "left", mode: augNever}, + "operator_assignment": {lhsField: "left", mode: augAlways}, + }, + updates: map[string]updateRule{}, + skipFields: map[string]map[string]bool{ + "call": {"method": true}, + }, + skipKinds: map[string]bool{ + "constant": true, "instance_variable": true, "class_variable": true, + "global_variable": true, "symbol": true, "hash_key_symbol": true, + }, + nestedFuncs: map[string]bool{ + "method": true, "singleton_method": true, "lambda": true, + "do_block": true, "block": true, "class": true, "module": true, + }, + patternContainers: map[string]bool{ + "left_assignment_list": true, "destructured_left_assignment": true, + "rest_assignment": true, + }, + paramSkipFields: map[string]bool{"value": true}, + paramSkipKinds: map[string]bool{}, + dispatch: rubyDispatch, +} + +func rubyDispatch(b *builder, n *sitter.Node) bool { + switch n.Type() { + case "body_statement", "begin": + b.buildRubyBody(n) + case "then", "do", "else", "ensure": + b.buildSeq(n) + case "if", "unless": + b.buildIf(nil, n.ChildByFieldName("condition"), + n.ChildByFieldName("consequence"), n.ChildByFieldName("alternative")) + case "elsif": + b.buildIf(nil, n.ChildByFieldName("condition"), + n.ChildByFieldName("consequence"), n.ChildByFieldName("alternative")) + case "if_modifier", "unless_modifier": + b.buildIf(nil, n.ChildByFieldName("condition"), n.ChildByFieldName("body"), nil) + case "while", "until": + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body")}) + case "while_modifier", "until_modifier": + b.buildLoop(loopParts{cond: n.ChildByFieldName("condition"), body: n.ChildByFieldName("body")}) + case "for": + b.buildLoop(loopParts{headerStmt: n, headerStmtOnlyHeaderFields: true, body: n.ChildByFieldName("body")}) + case "case": + b.buildRubyCase(n) + case "call": + blk := n.ChildByFieldName("block") + if blk == nil { + blk = childOfType(n, "do_block") + } + if blk == nil { + blk = childOfType(n, "block") + } + if blk == nil { + return false // plain call — leaf statement + } + b.buildRubyBlockCall(n, blk) + case "break": + b.buildBreak(n, "") + case "next", "redo": + b.buildContinue(n, "") + case "return": + b.buildReturn(n, "return", LabelReturn) + case "raise", "throw": + b.buildReturn(n, "throw", LabelException) + default: + return false + } + return true +} + +// buildRubyBody handles statement sequences that may carry method- +// level or begin-level rescue/ensure/else clauses. +func (b *builder) buildRubyBody(n *sitter.Node) { + var mainStmts []*sitter.Node + var p tryParts + hasClauses := false + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c == nil { + continue + } + switch c.Type() { + case "rescue": + hasClauses = true + h := handlerPart{} + if v := c.ChildByFieldName("variable"); v != nil { + h.headerNode = v + h.headerDefs = true + } else if ex := c.ChildByFieldName("exceptions"); ex != nil { + h.headerNode = ex + } + h.bodyNode = c.ChildByFieldName("body") + p.handlers = append(p.handlers, h) + case "ensure": + hasClauses = true + p.finallyNode = c + case "else": + hasClauses = true + p.elseNode = c + default: + mainStmts = append(mainStmts, c) + } + } + if !hasClauses { + for _, st := range mainStmts { + b.buildStmt(st) + } + return + } + p.bodyStmts = mainStmts + b.buildTry(p) +} + +// buildRubyCase dispatches each `when` off the subject; clauses +// never fall through. +func (b *builder) buildRubyCase(n *sitter.Node) { + b.ensureCur() + if v := n.ChildByFieldName("value"); v != nil { + b.addStmt(v, "cond") + } + head := b.cur + after := b.newBlock("case_end") + hasElse := false + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + if c == nil { + continue + } + switch c.Type() { + case "when", "in_clause": + caseBlock := b.newBlock("case") + b.edge(head, caseBlock, LabelCase) + b.cur = caseBlock + if pat := c.ChildByFieldName("pattern"); pat != nil { + b.addStmt(pat, "case") + } + if body := c.ChildByFieldName("body"); body != nil { + b.buildStmt(body) + } + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + case "else": + hasElse = true + elseBlock := b.newBlock("case") + b.edge(head, elseBlock, LabelCase) + b.cur = elseBlock + b.buildSeq(c) + if b.cur != nil { + b.edge(b.cur, after, LabelSeq) + } + } + } + if !hasElse { + b.edge(head, after, LabelFalse) + } + b.cur = after +} + +// buildRubyBlockCall models `receiver.each do |x| … end` as a loop: +// the call is the header (reads receiver + args), the block body may +// run zero or more times, and `next`/`break` behave like loop +// continue/break. +func (b *builder) buildRubyBlockCall(call, blk *sitter.Node) { + b.ensureCur() + header := b.newBlock("block_call") + b.moveTo(header) + // Uses come from the call minus the block body (the spec's + // nestedFuncs already exclude do_block/block subtrees). + b.addStmt(call, "loop") + after := b.newBlock("block_end") + bodyBlock := b.newBlock("block_body") + b.edge(header, bodyBlock, LabelTrue) + b.edge(header, after, LabelFalse) + b.pushFrame(frame{label: b.takeLabel(), continueTo: header, breakTo: after, isLoop: true}) + b.cur = bodyBlock + if params := blk.ChildByFieldName("parameters"); params != nil { + st := b.recordNode(params, "param") + defs, _ := extractDefUse(b.spec, b.src, params, true) + st.Defs = defs + } + if body := blk.ChildByFieldName("body"); body != nil { + // The block body is a body_statement; build it directly so + // rescue clauses inside the block still work. + b.buildStmt(body) + } + if b.cur != nil { + b.edge(b.cur, header, LabelLoopBack) + } + b.popFrame() + b.cur = after +} diff --git a/internal/cfg/lang_test.go b/internal/cfg/lang_test.go new file mode 100644 index 00000000..4b18d403 --- /dev/null +++ b/internal/cfg/lang_test.go @@ -0,0 +1,898 @@ +package cfg + +import ( + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// Python +// --------------------------------------------------------------------------- + +func TestPythonElifChainAndLoops(t *testing.T) { + c := mustBuild(t, `def f(a): + x = 0 + if a > 10: + x = 1 + elif a > 5: + x = 2 + else: + x = 3 + while x > 0: + x -= 1 + if x == 2: + break + else: + continue + return x +`, "python") + + for _, want := range []EdgeLabel{LabelTrue, LabelFalse, LabelLoopBack, LabelBreak, LabelContinue, LabelReturn} { + if !hasEdgeLabel(c, want) { + t.Errorf("missing %s edge", want) + } + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + // All three branch defs and the loop decrement can reach the + // return; the zeroth def is killed on every if path... but the + // elif chain has an else, so x = 0 cannot survive — yet the + // while may run zero times, so branch defs survive. + d1 := stmtByText(t, c, "x = 1") + d2 := stmtByText(t, c, "x = 2") + d3 := stmtByText(t, c, "x = 3") + dec := stmtByText(t, c, "x -= 1") + for _, d := range []*Statement{d1, d2, d3, dec} { + if !containsInt(ch.Defs, d.Index) { + t.Errorf("def %q (stmt %d) should reach return: %v", d.Text, d.Index, ch.Defs) + } + } + d0 := stmtByText(t, c, "x = 0") + if containsInt(ch.Defs, d0.Index) { + t.Errorf("x = 0 is killed on every if/elif/else path: %v", ch.Defs) + } + // Augmented assign reads its target. + if len(dec.Uses) == 0 || dec.Uses[0] != "x" { + t.Errorf("x -= 1 must use x: %v", dec.Uses) + } +} + +func TestPythonForAndTryExceptFinally(t *testing.T) { + c := mustBuild(t, `def f(items): + total = 0 + for i in items: + total += i + try: + total = parse(total) + except ValueError as e: + total = -1 + finally: + log(total) + return total +`, "python") + + loop := stmtByText(t, c, "for i in items") + if len(loop.Defs) != 1 || loop.Defs[0] != "i" { + t.Errorf("for header must define i: %v", loop.Defs) + } + if len(loop.Uses) != 1 || loop.Uses[0] != "items" { + t.Errorf("for header must use items: %v", loop.Uses) + } + if !hasEdgeLabel(c, LabelException) || !hasEdgeLabel(c, LabelFinally) { + t.Fatalf("try/except/finally must wire exception+finally edges: %+v", c.Edges) + } + // The except binding defines e. + catch := stmtByText(t, c, "ValueError as e") + foundE := false + for _, d := range catch.Defs { + if d == "e" { + foundE = true + } + } + if !foundE { + t.Errorf("except clause must define e: %v", catch.Defs) + } + // finally sees both the try def and the handler def. + r := c.ReachingDefinitions() + logStmt := stmtByText(t, c, "log(total)") + ch := chainFor(t, r, logStmt.Index, "total") + dTry := stmtByText(t, c, "total = parse(total)") + dExc := stmtByText(t, c, "total = -1") + if !containsInt(ch.Defs, dTry.Index) || !containsInt(ch.Defs, dExc.Index) { + t.Errorf("finally must merge try and handler defs: %v", ch.Defs) + } + // An exception before the protected assignment leaves the + // pre-try def live — it must reach the finally too. + dInit := stmtByText(t, c, "total = 0") + if !containsInt(ch.Defs, dInit.Index) { + t.Errorf("pre-try def must reach the finally via the exception path: %v", ch.Defs) + } +} + +func TestPythonIndentedMethodDedents(t *testing.T) { + // A method sliced out of a class body keeps its indentation — + // Build must dedent before parsing. + c := mustBuild(t, ` def m(self, a): + x = a + return x +`, "python") + if c.FuncName != "m" { + t.Errorf("FuncName = %q, want m", c.FuncName) + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + if !hasChain(r, ret.Index, "x") { + t.Errorf("dedented method must still produce chains") + } +} + +// Python match: arms live inside the match body block, not as direct +// children of match_statement. Each arm's pattern, body and chains +// must survive into the CFG, and capture patterns bind names. +func TestPythonMatchArmsAndCaptures(t *testing.T) { + c := mustBuild(t, `def f(x): + match x: + case [a, b]: + r = a + b + case Point(px): + r = px + case _: + r = 0 + return r +`, "python") + + if !hasEdgeLabel(c, LabelCase) { + t.Fatalf("match arms must produce case edges; edges: %+v", c.Edges) + } + // The bodies of every arm must be present. + for _, want := range []string{"r = a + b", "r = px", "r = 0"} { + stmtByText(t, c, want) + } + // Capture patterns bind a, b and px as definitions; their uses + // inside the arm bodies chain to the pattern. + r := c.ReachingDefinitions() + for use, vr := range map[string]string{"r = a + b": "a", "r = px": "px"} { + st := stmtByText(t, c, use) + if !hasChain(r, st.Index, vr) { + t.Errorf("capture %q must chain into %q: chains %+v", vr, use, r.Chains) + } + } + // All three arm defs of r reach the return. + ret := stmtByText(t, c, "return r") + ch := chainFor(t, r, ret.Index, "r") + for _, def := range []string{"r = a + b", "r = px", "r = 0"} { + d := stmtByText(t, c, def) + if !containsInt(ch.Defs, d.Index) { + t.Errorf("arm def %q must reach the return: %v", def, ch.Defs) + } + } +} + +// Python match guard: the guard reads variables, it must not turn +// them into phantom definitions. +func TestPythonMatchGuardReadsNotDefines(t *testing.T) { + c := mustBuild(t, `def f(x, lo): + match x: + case Box(v) if v > lo: + return v + case _: + return 0 +`, "python") + pat := stmtByText(t, c, "Box(v)") + if containsStr(pat.Defs, "lo") { + t.Errorf("guard variable lo must not be a definition: %v", pat.Defs) + } + foundLo := false + for _, u := range pat.Uses { + if u == "lo" { + foundLo = true + } + } + if !foundLo { + t.Errorf("guard must read lo: %v", pat.Uses) + } +} + +// Python for-else: a `break` skips the else clause, so the break-path +// definition is not killed by the else's reassignment. +func TestPythonForElseBreakSkipsElse(t *testing.T) { + c := mustBuild(t, `def f(items, target): + found = -1 + for i in items: + if i == target: + found = i + break + else: + found = 0 + return found +`, "python") + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return found") + ch := chainFor(t, r, ret.Index, "found") + dBreak := stmtByText(t, c, "found = i") + dElse := stmtByText(t, c, "found = 0") + // The break-path def must survive to the return — break jumps past + // the else, so the else cannot kill it. + if !containsInt(ch.Defs, dBreak.Index) { + t.Errorf("break-path def `found = i` must reach the return (break skips else): %v", ch.Defs) + } + // The else def reaches the return on the no-break path. + if !containsInt(ch.Defs, dElse.Index) { + t.Errorf("else def `found = 0` must reach the return on the normal exit: %v", ch.Defs) + } +} + +func containsStr(xs []string, x string) bool { + for _, v := range xs { + if v == x { + return true + } + } + return false +} + +// --------------------------------------------------------------------------- +// JavaScript / TypeScript +// --------------------------------------------------------------------------- + +func TestJavaScriptSwitchFallthroughAndTry(t *testing.T) { + c := mustBuild(t, `function f(a) { + let x = 0; + switch (a) { + case 1: + x = 1; + break; + case 2: + x = 2; + default: + x = 3; + } + try { + x = g(x); + } catch (e) { + x = -1; + } finally { + log(x); + } + return x; +} +`, "javascript") + + // case 2 falls through into default. + s2 := stmtByText(t, c, "x = 2") + s3 := stmtByText(t, c, "x = 3") + if !edgeBetween(c, s2.Block, s3.Block, LabelSeq) { + t.Errorf("case 2 must fall through to default: %+v", c.Edges) + } + // case 1 must NOT fall through (break). + s1 := stmtByText(t, c, "x = 1") + if edgeBetween(c, s1.Block, s2.Block, LabelSeq) { + t.Errorf("case 1 ends with break and must not fall through") + } + if !hasEdgeLabel(c, LabelException) || !hasEdgeLabel(c, LabelFinally) { + t.Fatalf("try/catch/finally edges missing") + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + dTry := stmtByText(t, c, "x = g(x)") + dCatch := stmtByText(t, c, "x = -1") + if !containsInt(ch.Defs, dTry.Index) || !containsInt(ch.Defs, dCatch.Index) { + t.Errorf("return must merge try and catch defs: %v", ch.Defs) + } + // catch parameter defines e. + for _, st := range c.Stmts { + if st.Kind == "catch" { + if len(st.Defs) != 1 || st.Defs[0] != "e" { + t.Errorf("catch must define e: %v", st.Defs) + } + } + } +} + +func TestJavaScriptForOfAndLabeledBreak(t *testing.T) { + c := mustBuild(t, `function f(arr) { + let s = 0; + outer: for (const v of arr) { + for (let j = 0; j < v; j++) { + if (j > 3) break outer; + s += j; + } + } + return s; +} +`, "javascript") + hdr := stmtByText(t, c, "for (const v of arr)") + if len(hdr.Defs) != 1 || hdr.Defs[0] != "v" { + t.Errorf("for-of header must define v: %v", hdr.Defs) + } + if len(hdr.Uses) != 1 || hdr.Uses[0] != "arr" { + t.Errorf("for-of header must use arr: %v", hdr.Uses) + } + br := stmtByText(t, c, "break outer") + ret := stmtByText(t, c, "return s") + var breakTo = -1 + for _, e := range c.Edges { + if e.From == br.Block && e.Label == LabelBreak { + breakTo = e.To + } + } + ok := breakTo == ret.Block || edgeBetween(c, breakTo, ret.Block, LabelSeq) + if !ok { + t.Errorf("labeled break must exit the outer loop (got block %d)", breakTo) + } +} + +func TestTypeScriptMethodClassWrap(t *testing.T) { + // A class method sliced out of its class doesn't parse + // standalone — Build retries inside a synthetic class wrapper. + c := mustBuild(t, `private compute(a: number): number { + let x: number = a * 2; + if (x > 10) { + x = 10; + } + return x; +} +`, "typescript") + if c.FuncName != "compute" { + t.Errorf("FuncName = %q, want compute", c.FuncName) + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + if len(ch.Defs) != 2 { + t.Errorf("both defs of x must reach the return: %v", ch.Defs) + } + // Line numbers must NOT be offset by the synthetic wrapper line. + first := stmtByText(t, c, "let x") + if first.StartLine != 2 { + t.Errorf("wrapped parse must keep snippet-relative lines: got %d, want 2", first.StartLine) + } +} + +func TestJavaScriptDoWhile(t *testing.T) { + c := mustBuild(t, `function f(n) { + let i = 0; + do { + i++; + } while (i < n); + return i; +} +`, "javascript") + if !hasEdgeLabel(c, LabelLoopBack) { + t.Fatalf("do-while needs a loop_back edge: %+v", c.Edges) + } + // Post-test: the body executes before the condition; i++ must + // reach the condition's use of i. + r := c.ReachingDefinitions() + cond := stmtByText(t, c, "i < n") + inc := stmtByText(t, c, "i++") + ch := chainFor(t, r, cond.Index, "i") + if !containsInt(ch.Defs, inc.Index) { + t.Errorf("do-while condition must see the body's def: %v", ch.Defs) + } +} + +// --------------------------------------------------------------------------- +// Java +// --------------------------------------------------------------------------- + +func TestJavaMethodConstructsAndChains(t *testing.T) { + c := mustBuild(t, `int f(int a) { + int x = a + 1; + for (int i = 0; i < a; i++) { + if (i == 2) continue; + if (i == 5) break; + x += i; + } + switch (x) { + case 1: + x = 10; + break; + default: + x = 20; + } + try { + x = parse(x); + } catch (Exception e) { + x = 0; + } finally { + log(x); + } + return x; +} +`, "java") + + for _, want := range []EdgeLabel{LabelTrue, LabelFalse, LabelLoopBack, LabelBreak, LabelContinue, LabelCase, LabelException, LabelFinally, LabelReturn} { + if !hasEdgeLabel(c, want) { + t.Errorf("missing %s edge", want) + } + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + dTry := stmtByText(t, c, "x = parse(x)") + dCatch := stmtByText(t, c, "x = 0") + if !containsInt(ch.Defs, dTry.Index) || !containsInt(ch.Defs, dCatch.Index) { + t.Errorf("return must merge try and catch defs: %v", ch.Defs) + } + // Augmented assignment x += i reads and writes x. + aug := stmtByText(t, c, "x += i") + if len(aug.Defs) != 1 || aug.Defs[0] != "x" { + t.Errorf("x += i defs: %v", aug.Defs) + } + wantUses := map[string]bool{"x": false, "i": false} + for _, u := range aug.Uses { + if _, ok := wantUses[u]; ok { + wantUses[u] = true + } + } + for v, seen := range wantUses { + if !seen { + t.Errorf("x += i must use %s: %v", v, aug.Uses) + } + } +} + +func TestJavaEnhancedForHeader(t *testing.T) { + c := mustBuild(t, `int sum(java.util.List items) { + int s = 0; + for (Integer v : items) { + s += v; + } + return s; +} +`, "java") + hdr := stmtByText(t, c, "for (Integer v : items)") + if len(hdr.Defs) != 1 || hdr.Defs[0] != "v" { + t.Errorf("enhanced-for must define v: %v", hdr.Defs) + } + if len(hdr.Uses) != 1 || hdr.Uses[0] != "items" { + t.Errorf("enhanced-for must use items: %v", hdr.Uses) + } +} + +// Java method-call names must not register as variable uses. +func TestJavaMethodNameNotAUse(t *testing.T) { + c := mustBuild(t, `int f(java.util.List list) { + int size = 99; + int n = list.size(); + return n + size; +} +`, "java") + call := stmtByText(t, c, "int n = list.size()") + for _, u := range call.Uses { + if u == "size" { + t.Errorf("the .size() method name must not count as a use of the local `size`: %v", call.Uses) + } + } +} + +// Java arrow-form switch: `case 1 -> { … }` never falls through. The +// arm bodies must not be chained to one another by a phantom seq edge; +// the post-switch use must see every arm's def but no fallthrough. +func TestJavaArrowSwitchNoFallthrough(t *testing.T) { + c := mustBuild(t, `int f(int x) { + int y = 0; + switch (x) { + case 1 -> { y = 1; } + case 2 -> { y = 2; } + default -> { y = 3; } + } + return y; +}`, "java") + s1 := stmtByText(t, c, "y = 1") + s2 := stmtByText(t, c, "y = 2") + s3 := stmtByText(t, c, "y = 3") + // No fallthrough seq edges between arrow rules. + if edgeBetween(c, s1.Block, s2.Block, LabelSeq) { + t.Errorf("arrow rule `case 1` must not fall through to `case 2`; edges: %+v", c.Edges) + } + if edgeBetween(c, s2.Block, s3.Block, LabelSeq) { + t.Errorf("arrow rule `case 2` must not fall through to default; edges: %+v", c.Edges) + } + // Every arm's def reaches the return (each via its own path). + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return y") + ch := chainFor(t, r, ret.Index, "y") + for _, d := range []*Statement{s1, s2, s3} { + if !containsInt(ch.Defs, d.Index) { + t.Errorf("arm def %q must reach the return: %v", d.Text, ch.Defs) + } + } +} + +// --------------------------------------------------------------------------- +// Rust +// --------------------------------------------------------------------------- + +func TestRustMatchLoopsAndChains(t *testing.T) { + c := mustBuild(t, `fn f(a: i32) -> i32 { + let mut x = a + 1; + while x > 0 { + x -= 1; + if x == 2 { break; } + } + for i in 0..3 { + x += i; + } + loop { + x += 1; + if x > 5 { break; } + } + match x { + 1 => { x = 10; } + n => { x = n + 1; } + } + return x; +} +`, "rust") + + for _, want := range []EdgeLabel{LabelTrue, LabelFalse, LabelLoopBack, LabelBreak, LabelCase, LabelReturn} { + if !hasEdgeLabel(c, want) { + t.Errorf("missing %s edge", want) + } + } + // The match binding pattern `n` defines n; the arm body uses it. + r := c.ReachingDefinitions() + armBody := stmtByText(t, c, "x = n + 1") + ch := chainFor(t, r, armBody.Index, "n") + pat := c.Stmts[ch.Defs[0]] + if pat.Kind != "case" { + t.Errorf("n's def should be the arm pattern statement, got kind %q (%q)", pat.Kind, pat.Text) + } + // for header defines i. + hdr := stmtByText(t, c, "for i in 0..3") + if len(hdr.Defs) != 1 || hdr.Defs[0] != "i" { + t.Errorf("for header must define i: %v", hdr.Defs) + } +} + +// Rust labeled break: `break 'outer` must exit the outer loop, so the +// outer loop_end (holding `return s`) stays reachable. Before the fix +// the label never resolved and the break leaked to the innermost +// loop. +func TestRustLabeledBreakExitsOuter(t *testing.T) { + c := mustBuild(t, `fn f() -> i32 { + let mut s = 0; + 'outer: loop { + loop { + s += 1; + if s > 3 { break 'outer; } + } + } + return s; +}`, "rust") + br := stmtByText(t, c, "break 'outer") + var breakTo = -1 + for _, e := range c.Edges { + if e.From == br.Block && e.Label == LabelBreak { + breakTo = e.To + } + } + if breakTo < 0 { + t.Fatalf("no break edge from block %d", br.Block) + } + ret := stmtByText(t, c, "return s") + ok := breakTo == ret.Block || edgeBetween(c, breakTo, ret.Block, LabelSeq) + if !ok { + t.Errorf("labeled break must exit the outer loop (block %d), reaching the return block %d; edges: %+v", breakTo, ret.Block, c.Edges) + } + // The label identifier must not leak in as a variable use. + if containsStr(br.Uses, "outer") { + t.Errorf("loop label `outer` must not register as a use: %v", br.Uses) + } +} + +// Rust match guard: the guard condition lives on the match_pattern as +// its `condition` field; its reads must not become phantom binding +// definitions that kill the real parameter def. +func TestRustMatchGuardReadsNotDefines(t *testing.T) { + c := mustBuild(t, `fn f(o: Option, z: i32) -> i32 { + match o { + Some(y) if z > 0 => y, + _ => z, + } +}`, "rust") + pat := stmtByText(t, c, "Some(y)") + // z is read by the guard, not bound by the pattern. + if containsStr(pat.Defs, "z") { + t.Errorf("guard variable z must not be a pattern definition: %v", pat.Defs) + } + foundZ := false + for _, u := range pat.Uses { + if u == "z" { + foundZ = true + } + } + if !foundZ { + t.Errorf("guard must read z: %v", pat.Uses) + } + // The capture y is still a definition. + if !containsStr(pat.Defs, "y") { + t.Errorf("capture y must be a definition: %v", pat.Defs) + } +} + +func TestRustIfLetBindsInHeader(t *testing.T) { + c := mustBuild(t, `fn f(opt: Option) -> i32 { + let mut x = 0; + if let Some(v) = opt { + x = v; + } + return x; +} +`, "rust") + cond := stmtByText(t, c, "let Some(v) = opt") + foundV := false + for _, d := range cond.Defs { + if d == "v" { + foundV = true + } + } + if !foundV { + t.Errorf("if-let must define v in the header: defs=%v", cond.Defs) + } + r := c.ReachingDefinitions() + use := stmtByText(t, c, "x = v") + if !hasChain(r, use.Index, "v") { + t.Errorf("v's use must chain to the if-let binding") + } +} + +// --------------------------------------------------------------------------- +// Ruby +// --------------------------------------------------------------------------- + +func TestRubyConstructsAndChains(t *testing.T) { + c := mustBuild(t, `def f(a) + x = a + 1 + if x > 10 + x = 10 + elsif x > 5 + x = 5 + else + x = 0 + end + while x > 0 + x -= 1 + break if x == 2 + end + case x + when 1 + x = 100 + else + x = 200 + end + return x +end +`, "ruby") + + for _, want := range []EdgeLabel{LabelTrue, LabelFalse, LabelLoopBack, LabelBreak, LabelCase, LabelReturn} { + if !hasEdgeLabel(c, want) { + t.Errorf("missing %s edge", want) + } + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return x") + ch := chainFor(t, r, ret.Index, "x") + d100 := stmtByText(t, c, "x = 100") + d200 := stmtByText(t, c, "x = 200") + if !containsInt(ch.Defs, d100.Index) || !containsInt(ch.Defs, d200.Index) { + t.Errorf("case-arm defs must union at the return: %v", ch.Defs) + } +} + +func TestRubyRescueEnsureAndBlocks(t *testing.T) { + c := mustBuild(t, `def f(items) + total = 0 + items.each do |it| + total += it + next if it < 0 + end + begin + total = parse(total) + rescue ArgumentError => e + total = -1 + ensure + log(total) + end + total +end +`, "ruby") + + if !hasEdgeLabel(c, LabelException) || !hasEdgeLabel(c, LabelFinally) { + t.Fatalf("begin/rescue/ensure must wire exception+finally edges") + } + // The block call models a loop: block param defined, loop_back + // present, `next` is a continue. + if !hasEdgeLabel(c, LabelLoopBack) || !hasEdgeLabel(c, LabelContinue) { + t.Errorf("each-block must model loop_back and continue: %+v", c.Edges) + } + var paramStmt *Statement + for _, st := range c.Stmts { + if st.Kind == "param" && len(st.Defs) == 1 && st.Defs[0] == "it" { + paramStmt = st + } + } + if paramStmt == nil { + t.Fatalf("block parameter |it| must be a definition") + } + r := c.ReachingDefinitions() + use := stmtByText(t, c, "total += it") + ch := chainFor(t, r, use.Index, "it") + if !containsInt(ch.Defs, paramStmt.Index) { + t.Errorf("it's use must chain to the block parameter: %v", ch.Defs) + } + // rescue binding defines e. + foundE := false + for _, st := range c.Stmts { + if st.Kind == "catch" { + for _, d := range st.Defs { + if d == "e" { + foundE = true + } + } + } + } + if !foundE { + t.Errorf("rescue => e must define e") + } + // ensure merges try + rescue defs. + logStmt := stmtByText(t, c, "log(total)") + chT := chainFor(t, r, logStmt.Index, "total") + dTry := stmtByText(t, c, "total = parse(total)") + dResc := stmtByText(t, c, "total = -1") + if !containsInt(chT.Defs, dTry.Index) || !containsInt(chT.Defs, dResc.Index) { + t.Errorf("ensure must merge try and rescue defs: %v", chT.Defs) + } +} + +func TestRubyUnlessAndModifiers(t *testing.T) { + c := mustBuild(t, `def f(x) + y = 0 + unless x > 0 + y = 1 + end + y += 2 if x == 5 + return y +end +`, "ruby") + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return y") + ch := chainFor(t, r, ret.Index, "y") + d0 := stmtByText(t, c, "y = 0") + d1 := stmtByText(t, c, "y = 1") + dMod := stmtByText(t, c, "y += 2") + for _, d := range []*Statement{d0, d1, dMod} { + if !containsInt(ch.Defs, d.Index) { + t.Errorf("def %q must reach return (conditional paths): %v", d.Text, ch.Defs) + } + } +} + +// --------------------------------------------------------------------------- +// nested construct stress: every language parses a nested +// if-in-loop-in-if shape and produces a connected graph. +// --------------------------------------------------------------------------- + +func TestNestedShapesAllLanguages(t *testing.T) { + cases := map[string]string{ + "go": `func f(a int) int { + r := 0 + if a > 0 { + for i := 0; i < a; i++ { + if i%2 == 0 { + r += i + } else { + r -= i + } + } + } + return r +}`, + "python": `def f(a): + r = 0 + if a > 0: + for i in range(a): + if i % 2 == 0: + r += i + else: + r -= i + return r +`, + "javascript": `function f(a) { + let r = 0; + if (a > 0) { + for (let i = 0; i < a; i++) { + if (i % 2 === 0) { r += i; } else { r -= i; } + } + } + return r; +}`, + "typescript": `function f(a: number): number { + let r = 0; + if (a > 0) { + for (let i = 0; i < a; i++) { + if (i % 2 === 0) { r += i; } else { r -= i; } + } + } + return r; +}`, + "java": `int f(int a) { + int r = 0; + if (a > 0) { + for (int i = 0; i < a; i++) { + if (i % 2 == 0) { r += i; } else { r -= i; } + } + } + return r; +}`, + "rust": `fn f(a: i32) -> i32 { + let mut r = 0; + if a > 0 { + for i in 0..a { + if i % 2 == 0 { r += i; } else { r -= i; } + } + } + return r; +}`, + "ruby": `def f(a) + r = 0 + if a > 0 + for i in 0..a + if i % 2 == 0 + r += i + else + r -= i + end + end + end + return r +end +`, + } + for lang, src := range cases { + t.Run(lang, func(t *testing.T) { + c := mustBuild(t, src, lang) + if len(c.Blocks) < 6 { + t.Fatalf("%s: nested shape should produce several blocks, got %d", lang, len(c.Blocks)) + } + if !hasEdgeLabel(c, LabelLoopBack) { + t.Errorf("%s: missing loop_back", lang) + } + r := c.ReachingDefinitions() + ret := stmtByText(t, c, "return r") + ch := chainFor(t, r, ret.Index, "r") + dPlus := stmtByText(t, c, "r += i") + dMinus := stmtByText(t, c, "r -= i") + dInit := c.Stmts[0] + for _, st := range c.Stmts { + if strings.Contains(st.Text, "r = 0") || strings.Contains(st.Text, "let r = 0") || strings.Contains(st.Text, "int r = 0") || strings.Contains(st.Text, "let mut r = 0") || strings.Contains(st.Text, "r := 0") { + dInit = st + break + } + } + for _, d := range []*Statement{dInit, dPlus, dMinus} { + if !containsInt(ch.Defs, d.Index) { + t.Errorf("%s: def %q must reach the return: %v", lang, d.Text, ch.Defs) + } + } + // Every non-entry block with statements must be reachable + // from somewhere except deliberately-unreachable ones. + seenTarget := map[int]bool{c.Entry: true} + for _, e := range c.Edges { + seenTarget[e.To] = true + } + for _, bl := range c.Blocks { + if bl.ID == c.Entry || len(bl.Stmts) == 0 { + continue + } + if !seenTarget[bl.ID] && bl.Label != "unreachable" { + t.Errorf("%s: block %d (%s) has statements but no incoming edge", lang, bl.ID, bl.Label) + } + } + }) + } +} diff --git a/internal/cfg/mermaid.go b/internal/cfg/mermaid.go new file mode 100644 index 00000000..e3efbd1f --- /dev/null +++ b/internal/cfg/mermaid.go @@ -0,0 +1,88 @@ +package cfg + +import ( + "fmt" + "strings" +) + +// mermaidMaxStmts caps the statement lines rendered per block so a +// long straight-line block doesn't dominate the diagram. +const mermaidMaxStmts = 8 + +// Mermaid renders the CFG as a Mermaid flowchart. Entry/exit get +// stadium shapes, every other block lists its statements. Sequential +// edges are unlabeled; every other label rides on the arrow. +func (c *CFG) Mermaid() string { + var b strings.Builder + b.WriteString("flowchart TD\n") + for _, bl := range c.Blocks { + if len(bl.Stmts) == 0 && !c.hasEdgeAt(bl.ID) { + continue // orphan empty block — noise in the diagram + } + switch bl.ID { + case c.Entry: + fmt.Fprintf(&b, " B%d([\"entry%s\"])\n", bl.ID, mermaidStmts(bl, true)) + case c.Exit: + fmt.Fprintf(&b, " B%d([\"exit\"])\n", bl.ID) + default: + body := mermaidStmts(bl, false) + if body == "" { + body = bl.Label + } + fmt.Fprintf(&b, " B%d[\"%s\"]\n", bl.ID, body) + } + } + for _, e := range c.Edges { + if e.Label == LabelSeq { + fmt.Fprintf(&b, " B%d --> B%d\n", e.From, e.To) + } else { + fmt.Fprintf(&b, " B%d -->|%s| B%d\n", e.From, string(e.Label), e.To) + } + } + return b.String() +} + +// hasEdgeAt reports whether any edge touches the block. +func (c *CFG) hasEdgeAt(id int) bool { + for _, e := range c.Edges { + if e.From == id || e.To == id { + return true + } + } + return false +} + +// mermaidStmts renders a block's statements as
-joined lines. +func mermaidStmts(bl *Block, contLine bool) string { + if len(bl.Stmts) == 0 { + return "" + } + parts := make([]string, 0, len(bl.Stmts)+1) + if contLine { + parts = append(parts, "") + } + for i, st := range bl.Stmts { + if i == mermaidMaxStmts { + parts = append(parts, fmt.Sprintf("… +%d more", len(bl.Stmts)-i)) + break + } + parts = append(parts, fmt.Sprintf("L%d: %s", st.StartLine, mermaidEscape(st.Text))) + } + return strings.Join(parts, "
") +} + +// mermaidEscape neutralizes the characters Mermaid treats as node +// syntax inside a quoted label. +func mermaidEscape(s string) string { + r := strings.NewReplacer( + "\"", "#quot;", + "<", "#lt;", + ">", "#gt;", + "{", "#123;", + "}", "#125;", + "[", "#91;", + "]", "#93;", + "|", "#124;", + ) + return r.Replace(s) +} diff --git a/internal/cfg/reaching.go b/internal/cfg/reaching.go new file mode 100644 index 00000000..454398fb --- /dev/null +++ b/internal/cfg/reaching.go @@ -0,0 +1,236 @@ +package cfg + +import ( + "math/bits" + "sort" +) + +// Definition is one (statement, variable) write site. ID is the +// definition's bit position in the analysis bitsets. +type Definition struct { + ID int `json:"id"` + Stmt int `json:"stmt"` + Var string `json:"var"` +} + +// UseChain links one variable read to every definition that can +// reach it along some control-flow path. Defs holds statement +// indices, ascending. +type UseChain struct { + Stmt int `json:"stmt"` + Var string `json:"var"` + Defs []int `json:"defs"` +} + +// ReachingResult is the fixpoint output: the definition table, the +// per-block IN/OUT sets (definition IDs), and the statement-granular +// def→use chains. +type ReachingResult struct { + Defs []Definition + Chains []UseChain + In [][]int + Out [][]int +} + +// ChainsFor returns the chains attached to one statement. +func (r *ReachingResult) ChainsFor(stmt int) []UseChain { + var out []UseChain + for _, c := range r.Chains { + if c.Stmt == stmt { + out = append(out, c) + } + } + return out +} + +// ReachingDefinitions runs the classic GEN/KILL monotone fixpoint +// over the CFG: +// +// IN[b] = ∪ OUT[p] for p ∈ preds(b) +// OUT[b] = GEN[b] ∪ (IN[b] − KILL[b]) +// +// then replays each block's statements against its IN set to link +// every use to the definitions reaching it. Bitsets keep the +// per-block transfer functions O(defs/64). +func (c *CFG) ReachingDefinitions() *ReachingResult { + res := &ReachingResult{} + + // 1. Number every definition and group them by variable. + defsByVar := map[string][]int{} + defID := map[[2]interface{}]int{} // (stmt, var) → def ID; stmts dedupe vars already + for _, st := range c.Stmts { + for _, v := range st.Defs { + id := len(res.Defs) + res.Defs = append(res.Defs, Definition{ID: id, Stmt: st.Index, Var: v}) + defsByVar[v] = append(defsByVar[v], id) + defID[[2]interface{}{st.Index, v}] = id + } + } + nDefs := len(res.Defs) + nBlocks := len(c.Blocks) + words := (nDefs + 63) / 64 + + newSet := func() bitset { return make(bitset, words) } + allDefsOf := func(v string) bitset { + s := newSet() + for _, id := range defsByVar[v] { + s.set(id) + } + return s + } + + // 2. Per-block GEN (downward-exposed defs) and KILL (every def of + // a variable the block writes). + gen := make([]bitset, nBlocks) + kill := make([]bitset, nBlocks) + for i, bl := range c.Blocks { + g, k := newSet(), newSet() + last := map[string]int{} + for _, st := range bl.Stmts { + for _, v := range st.Defs { + k.or(allDefsOf(v)) + last[v] = defID[[2]interface{}{st.Index, v}] + } + } + for _, id := range last { + g.set(id) + } + gen[i], kill[i] = g, k + } + + // 3. Predecessor / successor lists. + preds := make([][]int, nBlocks) + succs := make([][]int, nBlocks) + for _, e := range c.Edges { + preds[e.To] = append(preds[e.To], e.From) + succs[e.From] = append(succs[e.From], e.To) + } + + // 4. Worklist fixpoint. + in := make([]bitset, nBlocks) + out := make([]bitset, nBlocks) + for i := range c.Blocks { + in[i], out[i] = newSet(), newSet() + out[i].or(gen[i]) + } + work := make([]int, 0, nBlocks) + inWork := make([]bool, nBlocks) + for i := 0; i < nBlocks; i++ { + work = append(work, i) + inWork[i] = true + } + for len(work) > 0 { + b := work[0] + work = work[1:] + inWork[b] = false + + newIn := newSet() + for _, p := range preds[b] { + newIn.or(out[p]) + } + in[b] = newIn + newOut := newIn.clone() + newOut.andNot(kill[b]) + newOut.or(gen[b]) + if !newOut.equal(out[b]) { + out[b] = newOut + for _, s := range succs[b] { + if !inWork[s] { + work = append(work, s) + inWork[s] = true + } + } + } + } + + // 5. Statement-granular replay: thread the live set through each + // block, linking uses before applying the statement's defs. + for bi, bl := range c.Blocks { + live := in[bi].clone() + for _, st := range bl.Stmts { + for _, v := range st.Uses { + ids := defsByVar[v] + if len(ids) == 0 { + continue + } + var reach []int + for _, id := range ids { + if live.get(id) { + reach = append(reach, res.Defs[id].Stmt) + } + } + if len(reach) == 0 { + continue + } + sort.Ints(reach) + res.Chains = append(res.Chains, UseChain{Stmt: st.Index, Var: v, Defs: reach}) + } + for _, v := range st.Defs { + live.andNot(allDefsOf(v)) + live.set(defID[[2]interface{}{st.Index, v}]) + } + } + } + sort.SliceStable(res.Chains, func(i, j int) bool { + if res.Chains[i].Stmt != res.Chains[j].Stmt { + return res.Chains[i].Stmt < res.Chains[j].Stmt + } + return res.Chains[i].Var < res.Chains[j].Var + }) + + // 6. Export IN/OUT as sorted definition-ID lists. + res.In = make([][]int, nBlocks) + res.Out = make([][]int, nBlocks) + for i := 0; i < nBlocks; i++ { + res.In[i] = in[i].ids() + res.Out[i] = out[i].ids() + } + return res +} + +// bitset is a fixed-width bit vector over definition IDs. +type bitset []uint64 + +func (s bitset) set(i int) { s[i/64] |= 1 << (uint(i) % 64) } +func (s bitset) get(i int) bool { + return s[i/64]&(1<<(uint(i)%64)) != 0 +} + +func (s bitset) or(o bitset) { + for i := range s { + s[i] |= o[i] + } +} + +func (s bitset) andNot(o bitset) { + for i := range s { + s[i] &^= o[i] + } +} + +func (s bitset) clone() bitset { + c := make(bitset, len(s)) + copy(c, s) + return c +} + +func (s bitset) equal(o bitset) bool { + for i := range s { + if s[i] != o[i] { + return false + } + } + return true +} + +// ids enumerates the set bits ascending. +func (s bitset) ids() []int { + var out []int + for w, word := range s { + for word != 0 { + out = append(out, w*64+bits.TrailingZeros64(word)) + word &= word - 1 + } + } + return out +} diff --git a/internal/dataflow/dataflow.go b/internal/dataflow/dataflow.go index 761799f9..b6bf7719 100644 --- a/internal/dataflow/dataflow.go +++ b/internal/dataflow/dataflow.go @@ -47,20 +47,27 @@ const DefaultMaxDepth = 8 // DefaultMaxPaths bounds how many distinct paths flow_between // will return for a single (source, sink) pair. The handler ranks -// by length first, then by edge-confidence sum, so the user gets -// the most plausible paths first. +// refinement-confirmed paths ahead of disproved (pruned) ones, then +// by length, then by edge-confidence, so the user gets the most +// plausible paths first. const DefaultMaxPaths = 10 // EdgeStep is one hop along a flow path. It carries the edge kind, // origin tier, and coarse tier label so the caller can distinguish a // strong intra-procedural chain from a heuristic inter-procedural -// binding without recomputing the origin → tier mapping. +// binding without recomputing the origin → tier mapping. Refined is +// stamped by the CFG-backed reaching-definitions refinement when the +// hop's endpoints are bindings of the same function: +// confirmed_intraprocedural (a def→use chain verifies the hop) or +// pruned (the source's definition is killed before the target on +// every path). Empty when the hop is out of refinement scope. type EdgeStep struct { - From string `json:"from"` - To string `json:"to"` - Kind string `json:"kind"` - Origin string `json:"origin,omitempty"` - Tier string `json:"tier,omitempty"` + From string `json:"from"` + To string `json:"to"` + Kind string `json:"kind"` + Origin string `json:"origin,omitempty"` + Tier string `json:"tier,omitempty"` + Refined string `json:"refined,omitempty"` } // Path is an ordered sequence of edge hops from a source node to @@ -79,14 +86,27 @@ func (p Path) Length() int { return len(p.Edges) } // Engine is the dataflow query backend. It holds a reference to // the graph and exposes the two MCP-ready primitives. Concurrency- -// safe by virtue of relying only on graph.Store's read methods. +// safe by virtue of relying only on graph.Store's read methods — +// unless a Refiner is attached (refiners cache per-function CFGs and +// are meant for a single query). type Engine struct { - g graph.Store + g graph.Store + refiner *Refiner } // New returns an engine backed by the given graph. func New(g graph.Store) *Engine { return &Engine{g: g} } +// WithRefiner attaches a CFG-backed reaching-definitions refiner: +// FlowBetween / TaintPaths results get per-hop +// confirmed_intraprocedural / pruned markers where both hop +// endpoints are bindings of the same function. Returns the engine +// for chaining. +func (e *Engine) WithRefiner(r *Refiner) *Engine { + e.refiner = r + return e +} + // IsDataflowKind returns true for the three edge kinds the BFS // traverses. func IsDataflowKind(k graph.EdgeKind) bool { @@ -191,6 +211,13 @@ func (e *Engine) FlowBetweenWithTier(sourceID, sinkID string, maxDepth, maxPaths } dfs(sourceID, 0) + // CFG-backed refinement: judge same-function value_flow hops + // with reaching-definition chains before ranking, so pruned + // paths sink below confirmed ones. + if e.refiner != nil { + e.refiner.refinePaths(paths) + } + rankPaths(paths) if len(paths) > maxPaths { paths = paths[:maxPaths] @@ -210,12 +237,19 @@ func edgeOrigin(e *graph.Edge) string { return graph.DefaultOriginFor(e.Kind, e.Confidence, src) } -// rankPaths sorts in-place by length asc, then by confidence desc. -// Shorter, higher-confidence paths sort first so the agent always -// sees the most plausible explanation before the more speculative -// chains. +// rankPaths sorts in-place so the most plausible explanation comes +// first. A path the reaching-definitions refinement disproved (any +// pruned hop) always sinks below every non-pruned path, regardless of +// length — confidence demotion alone can't move a pruned hop past a +// shorter confirmed one when length is the primary key. Among paths +// of the same pruned status, shorter and then higher-confidence paths +// sort first. func rankPaths(paths []Path) { sort.SliceStable(paths, func(i, j int) bool { + pi, pj := hasPrunedHop(paths[i]), hasPrunedHop(paths[j]) + if pi != pj { + return !pi // non-pruned paths rank ahead of pruned ones + } if len(paths[i].Edges) != len(paths[j].Edges) { return len(paths[i].Edges) < len(paths[j].Edges) } @@ -223,6 +257,17 @@ func rankPaths(paths []Path) { }) } +// hasPrunedHop reports whether any hop on the path was disproved by +// the CFG-backed refinement. +func hasPrunedHop(p Path) bool { + for _, e := range p.Edges { + if e.Refined == RefinedPruned { + return true + } + } + return false +} + // confidenceFromEdges computes a normalised path confidence from // the per-edge origin tiers. Each edge contributes a 0-1 score // based on how well-grounded its kind / origin are; the path's diff --git a/internal/dataflow/refine.go b/internal/dataflow/refine.go new file mode 100644 index 00000000..2e1c30a9 --- /dev/null +++ b/internal/dataflow/refine.go @@ -0,0 +1,269 @@ +package dataflow + +import ( + "strings" + + "github.com/zzet/gortex/internal/cfg" + "github.com/zzet/gortex/internal/graph" +) + +// Refinement markers stamped on EdgeStep.Refined. A confirmed hop +// has a CFG-verified reaching-definition chain from the source +// binding to the statement that defines the target binding; a pruned +// hop is one the chain analysis disproves (the source's definition +// is killed on every path before the target's defining statement). +const ( + RefinedConfirmed = "confirmed_intraprocedural" + RefinedPruned = "pruned" +) + +// prunedPenalty scales a path's confidence for every hop the +// reaching-definitions analysis disproves, so stale edges sink in +// the ranking instead of silently disappearing. +const prunedPenalty = 0.25 + +// defaultRefinerCapacity bounds how many per-function CFGs one +// refiner will hold. Refiners are per-call; the cap keeps a +// pathological many-function path from accumulating parse trees. +// +// A taint sweep reuses one refiner across every (source, sink) pair, +// and refinement runs per pair before the findings are ranked and +// capped — so a function on several candidate paths must survive in +// the cache between pairs or it gets re-parsed. The working set of a +// single flow_between walk is bounded by DefaultMaxPaths distinct +// paths of at most DefaultMaxDepth hops; sizing the cache to cover +// that union keeps a broad pattern sweep from thrashing the FIFO, +// while still bounding transient memory (the refiner is discarded at +// the end of the call). +const defaultRefinerCapacity = DefaultMaxPaths * DefaultMaxDepth + +// FuncSource is one function's source text plus the file-absolute +// line its first byte sits on. +type FuncSource struct { + Src []byte + StartLine int +} + +// SourceResolver fetches the source of a function/method node. The +// MCP layer supplies an overlay-aware reader; tests can return +// source from memory. +type SourceResolver func(fn *graph.Node) (FuncSource, error) + +// Refiner upgrades value_flow hops whose endpoints are bindings of +// the same function using statement-granular reaching-definition +// chains. CFGs are built lazily — only for functions that actually +// appear on a candidate path — and cached with FIFO eviction. Not +// safe for concurrent use; construct one per query. +type Refiner struct { + g graph.Store + resolve SourceResolver + cap int + entries map[string]*refEntry + order []string +} + +// refEntry caches one function's analysis; a nil graph marks a +// negative entry (unsupported language, unreadable source, parse +// failure) so the failure isn't retried per hop. +type refEntry struct { + c *cfg.CFG + r *cfg.ReachingResult +} + +// NewRefiner builds a refiner over the graph with the given source +// resolver. capacity <= 0 selects the default. +func NewRefiner(g graph.Store, resolve SourceResolver, capacity int) *Refiner { + if capacity <= 0 { + capacity = defaultRefinerCapacity + } + return &Refiner{ + g: g, + resolve: resolve, + cap: capacity, + entries: make(map[string]*refEntry, capacity), + } +} + +// refinePaths stamps Refined markers on the value_flow hops it can +// judge and rescales path confidence for pruned hops. Returns true +// when any confidence changed (callers re-rank). +func (r *Refiner) refinePaths(paths []Path) bool { + if r == nil { + return false + } + changed := false + for pi := range paths { + for si := range paths[pi].Edges { + step := &paths[pi].Edges[si] + if graph.EdgeKind(step.Kind) != graph.EdgeValueFlow { + continue + } + switch r.refineStep(step) { + case RefinedConfirmed: + step.Refined = RefinedConfirmed + case RefinedPruned: + step.Refined = RefinedPruned + paths[pi].Confidence *= prunedPenalty + changed = true + } + } + } + return changed +} + +// refineStep judges one hop. Returns "" when the hop is out of scope +// (endpoints not bindings of the same function, unsupported +// language) or the CFG can't anchor both endpoints — unmarked hops +// keep their coarse-edge semantics. +func (r *Refiner) refineStep(step *EdgeStep) string { + fromOwner, fromName, ok := splitBindingID(step.From) + if !ok { + return "" + } + toOwner, toName, ok := splitBindingID(step.To) + if !ok || fromOwner != toOwner || fromName == "" || toName == "" { + return "" + } + ent := r.entryFor(fromOwner) + if ent == nil || ent.c == nil { + return "" + } + + defStmt := r.bindingDefStmt(ent, step.From, fromName) + if defStmt == nil { + return "" + } + toNode := r.g.GetNode(step.To) + if toNode == nil || toNode.StartLine == 0 { + return "" + } + useStmt := ent.c.StatementAt(toNode.StartLine, toName) + if useStmt == nil || useStmt.Index == defStmt.Index { + return "" + } + // The hop claims `toName`'s definition consumes `fromName`. If + // the CFG statement doesn't even read fromName the extraction + // disagrees with the graph edge — stay unmarked rather than + // judging on mismatched evidence. + reads := false + for _, u := range useStmt.Uses { + if u == fromName { + reads = true + break + } + } + if !reads { + return "" + } + for _, ch := range ent.r.ChainsFor(useStmt.Index) { + if ch.Var != fromName { + continue + } + for _, d := range ch.Defs { + if d == defStmt.Index { + return RefinedConfirmed + } + } + } + // fromName is read at the target statement, but the specific + // definition this hop starts from never reaches it — every path + // kills it first. + return RefinedPruned +} + +// bindingDefStmt anchors a binding node onto its defining CFG +// statement: params map to the synthetic entry-block param +// statements, locals to the statement covering their binding line. +func (r *Refiner) bindingDefStmt(ent *refEntry, id, name string) *cfg.Statement { + if strings.Contains(id, "#param:") { + for _, st := range ent.c.Stmts { + if st.Kind != "param" { + continue + } + for _, d := range st.Defs { + if d == name { + return st + } + } + } + return nil + } + node := r.g.GetNode(id) + if node == nil || node.StartLine == 0 { + return nil + } + return ent.c.StatementAt(node.StartLine, name) +} + +// entryFor returns the cached analysis for a function, building it +// on first sight. Failures cache negatively. +func (r *Refiner) entryFor(ownerID string) *refEntry { + if ent, ok := r.entries[ownerID]; ok { + return ent + } + ent := &refEntry{} + r.insert(ownerID, ent) + + fn := r.g.GetNode(ownerID) + if fn == nil || (fn.Kind != graph.KindFunction && fn.Kind != graph.KindMethod) { + return ent + } + if !cfg.SupportedLanguage(fn.Language) || fn.StartLine == 0 || fn.EndLine == 0 { + return ent + } + src, err := r.resolve(fn) + if err != nil || len(src.Src) == 0 { + return ent + } + c, err := cfg.Build(src.Src, fn.Language, cfg.Options{ + LineOffset: src.StartLine - 1, + FuncName: fn.Name, + }) + if err != nil { + return ent + } + ent.c = c + ent.r = c.ReachingDefinitions() + return ent +} + +func (r *Refiner) insert(key string, ent *refEntry) { + if len(r.order) >= r.cap { + oldest := r.order[0] + r.order = r.order[1:] + delete(r.entries, oldest) + } + r.entries[key] = ent + r.order = append(r.order, key) +} + +// splitBindingID decomposes the dataflow binding ID forms emitted at +// extraction time into the owning function ID and the variable name. +// Both binding forms may carry a position/offset suffix after `@`: +// +// - locals: `#local:@+` +// - params: `#param:` (Go) or +// `#param:@` (Python / TypeScript / Rust / +// Java / C#, which disambiguate duplicate names by position). +// +// The `@…` suffix is stripped from both so the bare variable name +// matches the CFG's def/use sets; without that the param branch +// silently left every non-Go param hop unrefined. +func splitBindingID(id string) (owner, name string, ok bool) { + if i := strings.Index(id, "#local:"); i > 0 { + return id[:i], trimBindingSuffix(id[i+len("#local:"):]), true + } + if i := strings.Index(id, "#param:"); i > 0 { + return id[:i], trimBindingSuffix(id[i+len("#param:"):]), true + } + return "", "", false +} + +// trimBindingSuffix drops the `@` / `@+` +// disambiguator a binding ID carries after its name. +func trimBindingSuffix(rest string) string { + if j := strings.IndexByte(rest, '@'); j >= 0 { + return rest[:j] + } + return rest +} diff --git a/internal/dataflow/refine_test.go b/internal/dataflow/refine_test.go new file mode 100644 index 00000000..31e77c77 --- /dev/null +++ b/internal/dataflow/refine_test.go @@ -0,0 +1,318 @@ +package dataflow + +import ( + "errors" + "testing" + + "github.com/zzet/gortex/internal/graph" +) + +// refineFixtureSrc is the function the refinement tests reason +// about. Binding sites (snippet lines): +// +// line 1: func F(a int) int — param a +// line 2: b := a — local b@+2 (consumes a) +// line 3: c := b — local c@+3 (consumes the live b) +// line 4: b = 99 — kills b@+2 +// line 5: d := b — local d@+5 (consumes the NEW b, not b@+2) +// line 6: return d +const refineFixtureSrc = `func F(a int) int { + b := a + c := b + b = 99 + d := b + return d +}` + +const refineOwner = "main.go::F" + +// refineGraph wires the binding nodes and value_flow edges. The +// b@+2 → d@+5 edge is deliberately stale: by line 5 the b binding +// from line 2 has been overwritten, so the reaching-definitions +// analysis must prune it. +func refineGraph(t *testing.T) *graph.Graph { + t.Helper() + g := graph.New() + g.AddNode(&graph.Node{ + ID: refineOwner, Kind: graph.KindFunction, Name: "F", + FilePath: "main.go", Language: "go", StartLine: 1, EndLine: 7, + }) + add := func(id string, kind graph.NodeKind, name string, line int) { + g.AddNode(&graph.Node{ + ID: id, Kind: kind, Name: name, + FilePath: "main.go", Language: "go", StartLine: line, EndLine: line, + }) + } + add(refineOwner+"#param:a", graph.KindParam, "a", 1) + add(refineOwner+"#local:b@+2", graph.KindLocal, "b", 2) + add(refineOwner+"#local:c@+3", graph.KindLocal, "c", 3) + add(refineOwner+"#local:d@+5", graph.KindLocal, "d", 5) + + flow := func(from, to string, line int) { + g.AddEdge(&graph.Edge{ + From: from, To: to, Kind: graph.EdgeValueFlow, + FilePath: "main.go", Line: line, Origin: graph.OriginASTResolved, + }) + } + // True flows. + flow(refineOwner+"#param:a", refineOwner+"#local:b@+2", 2) + flow(refineOwner+"#local:b@+2", refineOwner+"#local:c@+3", 3) + // Stale flow: b@+2 cannot reach d@+5 (b reassigned at line 4). + flow(refineOwner+"#local:b@+2", refineOwner+"#local:d@+5", 5) + return g +} + +func fixtureResolver(t *testing.T, calls *int) SourceResolver { + return func(fn *graph.Node) (FuncSource, error) { + if calls != nil { + *calls++ + } + if fn.ID != refineOwner { + return FuncSource{}, errors.New("unknown function") + } + return FuncSource{Src: []byte(refineFixtureSrc), StartLine: 1}, nil + } +} + +func TestRefinerConfirmsTrueFlow(t *testing.T) { + g := refineGraph(t) + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, nil), 0)) + + paths := e.FlowBetween(refineOwner+"#param:a", refineOwner+"#local:b@+2", 0, 0) + if len(paths) != 1 || paths[0].Length() != 1 { + t.Fatalf("expected one single-hop path, got %+v", paths) + } + step := paths[0].Edges[0] + if step.Refined != RefinedConfirmed { + t.Errorf("param→local flow should be confirmed, got %q", step.Refined) + } +} + +func TestRefinerConfirmsChainedFlow(t *testing.T) { + g := refineGraph(t) + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, nil), 0)) + + paths := e.FlowBetween(refineOwner+"#local:b@+2", refineOwner+"#local:c@+3", 0, 0) + if len(paths) != 1 { + t.Fatalf("expected one path, got %+v", paths) + } + if got := paths[0].Edges[0].Refined; got != RefinedConfirmed { + t.Errorf("live local→local flow should be confirmed, got %q", got) + } +} + +func TestRefinerPrunesStaleFlow(t *testing.T) { + g := refineGraph(t) + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, nil), 0)) + + paths := e.FlowBetween(refineOwner+"#local:b@+2", refineOwner+"#local:d@+5", 0, 0) + if len(paths) != 1 { + t.Fatalf("expected one path, got %+v", paths) + } + step := paths[0].Edges[0] + if step.Refined != RefinedPruned { + t.Errorf("stale flow (b reassigned before d := b) should be pruned, got %q", step.Refined) + } + + // Pruning must cost confidence relative to the unrefined run. + plain := New(g).FlowBetween(refineOwner+"#local:b@+2", refineOwner+"#local:d@+5", 0, 0) + if !(paths[0].Confidence < plain[0].Confidence) { + t.Errorf("pruned path confidence %v must drop below unrefined %v", + paths[0].Confidence, plain[0].Confidence) + } +} + +func TestRefinerRanksConfirmedAbovePruned(t *testing.T) { + // Two same-length paths from b@+2 to a common sink — one through + // the live binding c@+3, one through the stale hop to d@+5. After + // refinement the pruned path's confidence drops, so the confirmed + // route must rank first. + g := refineGraph(t) + g.AddNode(&graph.Node{ + ID: "main.go::Sink", Kind: graph.KindFunction, Name: "Sink", + FilePath: "main.go", Language: "go", StartLine: 20, EndLine: 22, + }) + for _, from := range []string{refineOwner + "#local:c@+3", refineOwner + "#local:d@+5"} { + g.AddEdge(&graph.Edge{ + From: from, To: "main.go::Sink", Kind: graph.EdgeValueFlow, + FilePath: "main.go", Origin: graph.OriginASTResolved, + }) + } + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, nil), 0)) + + paths := e.FlowBetween(refineOwner+"#local:b@+2", "main.go::Sink", 0, 0) + if len(paths) != 2 { + t.Fatalf("expected two competing paths, got %+v", paths) + } + first, second := paths[0], paths[1] + if first.Edges[0].Refined != RefinedConfirmed { + t.Errorf("top-ranked path must start with the confirmed hop, got %q (path %v)", first.Edges[0].Refined, first.IDs) + } + if second.Edges[0].Refined != RefinedPruned { + t.Errorf("second path must carry the pruned marker, got %q (path %v)", second.Edges[0].Refined, second.IDs) + } + if !(first.Confidence > second.Confidence) { + t.Errorf("confirmed path confidence %v must beat pruned %v", first.Confidence, second.Confidence) + } +} + +// A param whose binding ID carries a `@` suffix (the form +// every non-Go extractor mints) must still anchor onto the CFG's +// param statement so the hop is refined, not silently skipped. +func TestRefinerConfirmsPositionalParamFlow(t *testing.T) { + g := graph.New() + owner := "shape.go::F" + g.AddNode(&graph.Node{ + ID: owner, Kind: graph.KindFunction, Name: "F", + FilePath: "shape.go", Language: "go", StartLine: 1, EndLine: 4, + }) + // Param ID with a positional disambiguator, as Python/TS/etc emit. + paramID := owner + "#param:a@0" + localID := owner + "#local:b@+2" + g.AddNode(&graph.Node{ID: paramID, Kind: graph.KindParam, Name: "a", FilePath: "shape.go", Language: "go", StartLine: 1, EndLine: 1}) + g.AddNode(&graph.Node{ID: localID, Kind: graph.KindLocal, Name: "b", FilePath: "shape.go", Language: "go", StartLine: 2, EndLine: 2}) + g.AddEdge(&graph.Edge{From: paramID, To: localID, Kind: graph.EdgeValueFlow, FilePath: "shape.go", Line: 2, Origin: graph.OriginASTResolved}) + + resolve := func(fn *graph.Node) (FuncSource, error) { + return FuncSource{Src: []byte("func F(a int) int {\n\tb := a\n\treturn b\n}"), StartLine: 1}, nil + } + e := New(g).WithRefiner(NewRefiner(g, resolve, 0)) + paths := e.FlowBetween(paramID, localID, 0, 0) + if len(paths) != 1 { + t.Fatalf("expected one path, got %+v", paths) + } + if got := paths[0].Edges[0].Refined; got != RefinedConfirmed { + t.Errorf("positional-param flow must be confirmed, got %q (the @0 suffix must be stripped)", got) + } +} + +// A confirmed path that is LONGER than a competing pruned path must +// still rank ahead of it: confidence demotion alone can't move a +// pruned hop past a shorter confirmed one when length is the primary +// sort key, so pruned paths must sink categorically. +func TestRefinerSinksPrunedBelowLongerConfirmed(t *testing.T) { + // Reuse the fixture's stale b@+2 → d@+5 hop (pruned, 1 hop) and + // build a longer confirmed route a → b → c → sink (3 hops) to the + // same sink, so the pruned path is strictly shorter. + g := refineGraph(t) + g.AddNode(&graph.Node{ + ID: "main.go::Sink", Kind: graph.KindFunction, Name: "Sink", + FilePath: "main.go", Language: "go", StartLine: 20, EndLine: 22, + }) + // Shorter pruned route: b@+2 → d@+5 → Sink (2 hops, first hop pruned). + g.AddEdge(&graph.Edge{From: refineOwner + "#local:d@+5", To: "main.go::Sink", Kind: graph.EdgeValueFlow, FilePath: "main.go", Origin: graph.OriginASTResolved}) + // Longer confirmed route: b@+2 → c@+3 → Sink (2 hops) — but make it + // genuinely longer by routing through an extra confirmed local. + // c@+3 is live, so b→c is confirmed; c→Sink is out of scope + // (unmarked, not pruned). Add an intermediate to force length 3. + g.AddNode(&graph.Node{ID: "main.go::Mid", Kind: graph.KindFunction, Name: "Mid", FilePath: "main.go", Language: "go", StartLine: 15, EndLine: 16}) + g.AddEdge(&graph.Edge{From: refineOwner + "#local:c@+3", To: "main.go::Mid", Kind: graph.EdgeValueFlow, FilePath: "main.go", Origin: graph.OriginASTResolved}) + g.AddEdge(&graph.Edge{From: "main.go::Mid", To: "main.go::Sink", Kind: graph.EdgeValueFlow, FilePath: "main.go", Origin: graph.OriginASTResolved}) + + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, nil), 0)) + paths := e.FlowBetween(refineOwner+"#local:b@+2", "main.go::Sink", 0, 0) + if len(paths) < 2 { + t.Fatalf("expected at least two competing paths, got %+v", paths) + } + // The shortest path is the 2-hop pruned route; the confirmed route + // is 3 hops. Despite being longer, the confirmed path must rank + // first because the pruned path sinks. + top := paths[0] + if hasPrunedHop(top) { + t.Errorf("top-ranked path must not be the pruned one; got pruned path %v ranked first", top.IDs) + } + // And the pruned path must still be present, just demoted. + sawPruned := false + for _, p := range paths { + if hasPrunedHop(p) { + sawPruned = true + } + } + if !sawPruned { + t.Fatalf("the pruned path must remain in the result set (demoted, not dropped): %+v", paths) + } +} + +func TestRefinerCachesPerFunctionCFG(t *testing.T) { + g := refineGraph(t) + calls := 0 + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, &calls), 0)) + + // Two hops inside the same function: a→b→c. One CFG build. + paths := e.FlowBetween(refineOwner+"#param:a", refineOwner+"#local:c@+3", 0, 0) + if len(paths) != 1 || paths[0].Length() != 2 { + t.Fatalf("expected one two-hop path, got %+v", paths) + } + if calls != 1 { + t.Errorf("CFG source resolved %d times, want 1 (cached per function)", calls) + } + for _, step := range paths[0].Edges { + if step.Refined != RefinedConfirmed { + t.Errorf("hop %s→%s should be confirmed, got %q", step.From, step.To, step.Refined) + } + } +} + +func TestRefinerLeavesOutOfScopeHopsUnmarked(t *testing.T) { + // Function-level nodes (no #local:/#param: binding form) are out + // of refinement scope and stay unmarked. + g := graph.New() + for _, id := range []string{"A", "B"} { + g.AddNode(&graph.Node{ID: id, Kind: graph.KindFunction, Name: id, FilePath: "x.go", Language: "go", StartLine: 1, EndLine: 2}) + } + g.AddEdge(&graph.Edge{From: "A", To: "B", Kind: graph.EdgeValueFlow, FilePath: "x.go", Origin: graph.OriginASTResolved}) + + resolverCalls := 0 + e := New(g).WithRefiner(NewRefiner(g, fixtureResolver(t, &resolverCalls), 0)) + paths := e.FlowBetween("A", "B", 0, 0) + if len(paths) != 1 { + t.Fatalf("expected one path, got %+v", paths) + } + if got := paths[0].Edges[0].Refined; got != "" { + t.Errorf("cross-symbol hop must stay unmarked, got %q", got) + } + if resolverCalls != 0 { + t.Errorf("out-of-scope hops must not trigger CFG builds (%d calls)", resolverCalls) + } +} + +func TestRefinerSurvivesResolverFailure(t *testing.T) { + g := refineGraph(t) + failing := func(fn *graph.Node) (FuncSource, error) { + return FuncSource{}, errors.New("disk gone") + } + e := New(g).WithRefiner(NewRefiner(g, failing, 0)) + paths := e.FlowBetween(refineOwner+"#param:a", refineOwner+"#local:b@+2", 0, 0) + if len(paths) != 1 { + t.Fatalf("paths must survive resolver failure, got %+v", paths) + } + if got := paths[0].Edges[0].Refined; got != "" { + t.Errorf("unresolvable source must leave hops unmarked, got %q", got) + } +} + +func TestSplitBindingID(t *testing.T) { + cases := []struct { + id string + owner, name string + ok bool + }{ + {"f.go::F#local:x@+3", "f.go::F", "x", true}, + {"f.go::F#param:in", "f.go::F", "in", true}, + // Non-Go extractors (Python / TS / Rust / Java / C#) append a + // position disambiguator to param IDs; the suffix must be + // stripped so the bare name matches the CFG's def/use sets. + {"a.py::f#param:limit@0", "a.py::f", "limit", true}, + {"a.ts::f#param:limit@2", "a.ts::f", "limit", true}, + {"f.go::F", "", "", false}, + {"unresolved::X", "", "", false}, + {"", "", "", false}, + } + for _, c := range cases { + owner, name, ok := splitBindingID(c.id) + if owner != c.owner || name != c.name || ok != c.ok { + t.Errorf("splitBindingID(%q) = (%q,%q,%v), want (%q,%q,%v)", + c.id, owner, name, ok, c.owner, c.name, c.ok) + } + } +} diff --git a/internal/mcp/gcx.go b/internal/mcp/gcx.go index 92961b2b..e43fc86f 100644 --- a/internal/mcp/gcx.go +++ b/internal/mcp/gcx.go @@ -593,6 +593,18 @@ func encodeAnalyze(kind string, payload any) ([]byte, error) { } } return buf.Bytes(), enc.Close() + case "def_use": + items, _ := payload.([]defUseItem) + enc := newGCX(&buf, "analyze.def_use", + []string{"symbol", "var", "use_line", "use_text", "def_lines"}, + "count", fmt.Sprintf("%d", len(items)), + ) + for _, it := range items { + if err := enc.WriteRow(it.Symbol, it.Var, it.UseLine, it.UseText, it.DefLines); err != nil { + return nil, err + } + } + return buf.Bytes(), enc.Close() case "channel_ops": items, _ := payload.([]channelOpItem) enc := newGCX(&buf, "analyze.channel_ops", diff --git a/internal/mcp/scope_init.go b/internal/mcp/scope_init.go index 9db56106..7d02a8bb 100644 --- a/internal/mcp/scope_init.go +++ b/internal/mcp/scope_init.go @@ -120,6 +120,10 @@ var defaultToolScopes = map[string]ToolScope{ "flow_between": ScopeRepo, "taint_paths": ScopeRepo, "trace_path": ScopeRepo, + + // Per-function control-flow graphs: built from one symbol's + // source in the active repo. + "get_cfg": ScopeRepo, } // applyDefaultToolScopes registers the canonical scope for every diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 7a63727c..ebab3990 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -1025,6 +1025,7 @@ func NewServer(engine *query.Engine, g graph.Store, idx *indexer.Indexer, watche s.registerGraphInvalidatedTools() s.registerToolProfileTool() s.registerDataflowTools() + s.registerCFGTools() s.registerASTTools() s.registerCloneTools() s.registerSimulationTools() diff --git a/internal/mcp/tools_cfg.go b/internal/mcp/tools_cfg.go new file mode 100644 index 00000000..ceb5b6ef --- /dev/null +++ b/internal/mcp/tools_cfg.go @@ -0,0 +1,388 @@ +package mcp + +import ( + "bytes" + "context" + "fmt" + "sort" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + + wire "github.com/gortexhq/gcx-go" + "github.com/zzet/gortex/internal/cfg" + "github.com/zzet/gortex/internal/graph" +) + +// registerCFGTools wires the control-flow surface: get_cfg returns a +// function's basic blocks, labeled edges, per-statement def/use sets +// and reaching-definition chains — built on demand from the symbol's +// source, never at index time. +func (s *Server) registerCFGTools() { + s.addTool( + mcp.NewTool("get_cfg", + mcp.WithDescription("Builds the intra-procedural control-flow graph for one function/method: basic blocks (ordered statements with line spans and per-statement def/use variable sets), labeled edges (seq / true / false / loop_back / break / continue / return / case / exception / finally), and statement-granular def→use chains from a GEN/KILL reaching-definitions fixpoint. Supports Go, Python, JavaScript, TypeScript, Java, Rust, Ruby. Pass mermaid:true for a Mermaid flowchart rendering. Built on demand from the symbol's source — pairs with flow_between (where a value flows across symbols) by answering how values move inside one function."), + mcp.WithString("id", mcp.Required(), mcp.Description("Function or method symbol node ID")), + mcp.WithBoolean("mermaid", mcp.Description("Include a Mermaid flowchart rendering of the block graph (default: false)")), + mcp.WithString("format", mcp.Description("Output format: json (default), gcx (GCX1 compact wire format), or toon")), + mcp.WithNumber("max_bytes", mcp.Description("Cap the marshaled response at this many bytes. The longest list is trimmed; truncation metadata rides on the response. Omit for no cap.")), + ), + s.handleGetCFG, + ) +} + +// symbolCFG is the resolved bundle both get_cfg and analyze def_use +// consume: the graph node plus its freshly built CFG and reaching- +// definitions result. +type symbolCFG struct { + node *graph.Node + graph *cfg.CFG + reach *cfg.ReachingResult +} + +// buildSymbolCFG fetches a function/method node, slices its source +// out of the owning repo's file (overlay-aware), and builds the CFG. +// Errors are caller-facing strings, suitable for tool results. +func (s *Server) buildSymbolCFG(ctx context.Context, id string) (*symbolCFG, error) { + node := s.engineFor(ctx).GetSymbol(id) + if node == nil { + return nil, fmt.Errorf("symbol not found: %s", id) + } + if !s.nodeInSessionScope(ctx, node) { + return nil, fmt.Errorf("symbol not found: %s", id) + } + if node.Kind != graph.KindFunction && node.Kind != graph.KindMethod { + return nil, fmt.Errorf("symbol %s is a %s — get_cfg needs a function or method", id, node.Kind) + } + if !cfg.SupportedLanguage(node.Language) { + return nil, fmt.Errorf("control-flow graphs are not supported for language %q (supported: go, python, javascript, typescript, java, rust, ruby)", node.Language) + } + if node.StartLine == 0 || node.EndLine == 0 { + return nil, fmt.Errorf("symbol has no line range: %s", id) + } + absPath, err := s.resolveNodePath(node) + if err != nil { + return nil, err + } + source, fromLine, _, err := s.readLinesForCtx(ctx, absPath, node.StartLine, node.EndLine, 0) + if err != nil { + return nil, fmt.Errorf("could not read source: %v", err) + } + c, err := cfg.Build([]byte(source), node.Language, cfg.Options{ + LineOffset: fromLine - 1, + FuncName: node.Name, + }) + if err != nil { + return nil, fmt.Errorf("cfg build failed for %s: %v", id, err) + } + return &symbolCFG{node: node, graph: c, reach: c.ReachingDefinitions()}, nil +} + +func (s *Server) handleGetCFG(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id, err := s.symbolIDArg(ctx, req) + if err != nil { + return mcp.NewToolResultError("id is required"), nil + } + + // Auto re-index stale file before querying. + if parts := strings.SplitN(id, "::", 2); len(parts) == 2 { + s.ensureFresh([]string{parts[0]}) + } + + sc, buildErr := s.buildSymbolCFG(ctx, id) + if buildErr != nil { + return mcp.NewToolResultError(buildErr.Error()), nil + } + sess := s.sessionFor(ctx) + sess.recordSymbol(id) + sess.recordFile(sc.node.FilePath) + + wantMermaid := req.GetBool("mermaid", false) + + if s.isGCX(ctx, req) { + payload, encErr := encodeGetCFG(sc, wantMermaid) + return s.gcxResponseWithBudget(req)(payload, encErr) + } + + blocks := make([]map[string]any, 0, len(sc.graph.Blocks)) + for _, bl := range sc.graph.Blocks { + stmts := bl.Stmts + if stmts == nil { + stmts = []*cfg.Statement{} + } + blocks = append(blocks, map[string]any{ + "id": bl.ID, + "label": bl.Label, + "statements": stmts, + }) + } + result := map[string]any{ + "id": sc.node.ID, + "name": sc.graph.FuncName, + "kind": string(sc.node.Kind), + "language": sc.graph.Language, + "file_path": sc.node.FilePath, + "start_line": sc.node.StartLine, + "end_line": sc.node.EndLine, + "entry": sc.graph.Entry, + "exit": sc.graph.Exit, + "blocks": blocks, + "edges": sc.graph.Edges, + "def_use": sc.reach.Chains, + "total_blocks": len(sc.graph.Blocks), + "total_edges": len(sc.graph.Edges), + } + if wantMermaid { + result["mermaid"] = sc.graph.Mermaid() + } + if s.isTOON(ctx, req) { + return returnTOON(result) + } + return s.respondJSONOrTOON(ctx, req, result) +} + +// encodeGetCFG emits a GCX1 envelope with four sections — +// get_cfg.summary (one row), get_cfg.stmts (one row per statement +// with its block, span, kind and def/use sets), get_cfg.edges and +// get_cfg.chains — plus an optional get_cfg.mermaid section. +func encodeGetCFG(sc *symbolCFG, wantMermaid bool) ([]byte, error) { + var buf bytes.Buffer + sumEnc := wire.NewEncoder(&buf, wire.Header{ + Tool: "get_cfg.summary", + Fields: []string{"id", "name", "language", "file", "entry", "exit", "blocks", "edges", "stmts", "chains"}, + }) + if err := sumEnc.WriteRow(sc.node.ID, sc.graph.FuncName, sc.graph.Language, sc.node.FilePath, + sc.graph.Entry, sc.graph.Exit, len(sc.graph.Blocks), len(sc.graph.Edges), + len(sc.graph.Stmts), len(sc.reach.Chains)); err != nil { + return nil, err + } + if err := sumEnc.Close(); err != nil { + return nil, err + } + + stmtEnc := wire.NewEncoder(&buf, wire.Header{ + Tool: "get_cfg.stmts", + Fields: []string{"index", "block", "start_line", "end_line", "kind", "defs", "uses", "text"}, + Meta: map[string]string{"count": fmt.Sprintf("%d", len(sc.graph.Stmts))}, + }) + for _, st := range sc.graph.Stmts { + if err := stmtEnc.WriteRow(st.Index, st.Block, st.StartLine, st.EndLine, st.Kind, + strings.Join(st.Defs, ","), strings.Join(st.Uses, ","), st.Text); err != nil { + return nil, err + } + } + if err := stmtEnc.Close(); err != nil { + return nil, err + } + + edgeEnc := wire.NewEncoder(&buf, wire.Header{ + Tool: "get_cfg.edges", + Fields: []string{"from", "to", "label"}, + Meta: map[string]string{"count": fmt.Sprintf("%d", len(sc.graph.Edges))}, + }) + for _, e := range sc.graph.Edges { + if err := edgeEnc.WriteRow(e.From, e.To, string(e.Label)); err != nil { + return nil, err + } + } + if err := edgeEnc.Close(); err != nil { + return nil, err + } + + chainEnc := wire.NewEncoder(&buf, wire.Header{ + Tool: "get_cfg.chains", + Fields: []string{"stmt", "var", "defs"}, + Meta: map[string]string{"count": fmt.Sprintf("%d", len(sc.reach.Chains))}, + }) + for _, ch := range sc.reach.Chains { + if err := chainEnc.WriteRow(ch.Stmt, ch.Var, joinInts(ch.Defs)); err != nil { + return nil, err + } + } + if err := chainEnc.Close(); err != nil { + return nil, err + } + + if wantMermaid { + mEnc := wire.NewEncoder(&buf, wire.Header{ + Tool: "get_cfg.mermaid", + Fields: []string{"diagram"}, + }) + if err := mEnc.WriteRow(sc.graph.Mermaid()); err != nil { + return nil, err + } + if err := mEnc.Close(); err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} + +// --------------------------------------------------------------------------- +// analyze kind=def_use +// --------------------------------------------------------------------------- + +// defUseItem is the GCX row shape for analyze def_use: one row per +// def→use chain, flattened to lines so consumers don't need the +// block table. +type defUseItem struct { + Symbol string + Var string + UseLine int + UseText string + DefLines string +} + +// handleAnalyzeDefUse computes statement-granular def→use chains and +// a per-variable reaching-definition summary for the requested +// function/method symbols. `ids` is a comma-separated ID list (or a +// JSON array); `id` works for a single symbol. Symbols that can't be +// analyzed (wrong kind, unsupported language, unreadable source) +// degrade to a per-symbol error instead of failing the whole call. +func (s *Server) handleAnalyzeDefUse(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ids := symbolIDList(req.GetArguments()) + if len(ids) == 0 { + return mcp.NewToolResultError("def_use requires `ids` (comma-separated symbol IDs) or `id`"), nil + } + + type chainRow struct { + Stmt int `json:"stmt"` + StmtLine int `json:"stmt_line"` + StmtText string `json:"stmt_text"` + Var string `json:"var"` + Defs []int `json:"defs"` + DefLines []int `json:"def_lines"` + } + type varRow struct { + Var string `json:"var"` + Defs int `json:"defs"` + Uses int `json:"uses"` + DefLines []int `json:"def_lines"` + } + type symbolRow struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + File string `json:"file,omitempty"` + Language string `json:"language,omitempty"` + Chains []chainRow `json:"chains,omitempty"` + Variables []varRow `json:"variables,omitempty"` + Error string `json:"error,omitempty"` + } + + rows := make([]symbolRow, 0, len(ids)) + var gcxItems []defUseItem + for _, id := range ids { + sc, err := s.buildSymbolCFG(ctx, id) + if err != nil { + rows = append(rows, symbolRow{ID: id, Error: err.Error()}) + continue + } + row := symbolRow{ + ID: sc.node.ID, + Name: sc.graph.FuncName, + File: sc.node.FilePath, + Language: sc.graph.Language, + } + lineOf := func(stmt int) int { return sc.graph.Stmts[stmt].StartLine } + for _, ch := range sc.reach.Chains { + defLines := make([]int, len(ch.Defs)) + for i, d := range ch.Defs { + defLines[i] = lineOf(d) + } + row.Chains = append(row.Chains, chainRow{ + Stmt: ch.Stmt, + StmtLine: lineOf(ch.Stmt), + StmtText: sc.graph.Stmts[ch.Stmt].Text, + Var: ch.Var, + Defs: ch.Defs, + DefLines: defLines, + }) + gcxItems = append(gcxItems, defUseItem{ + Symbol: sc.node.ID, + Var: ch.Var, + UseLine: lineOf(ch.Stmt), + UseText: sc.graph.Stmts[ch.Stmt].Text, + DefLines: joinInts(defLines), + }) + } + // Per-variable rollup: definition sites and read counts. + defLines := map[string][]int{} + useCount := map[string]int{} + for _, d := range sc.reach.Defs { + defLines[d.Var] = append(defLines[d.Var], lineOf(d.Stmt)) + } + for _, st := range sc.graph.Stmts { + for _, u := range st.Uses { + useCount[u]++ + } + } + vars := make([]string, 0, len(defLines)) + for v := range defLines { + vars = append(vars, v) + } + sort.Strings(vars) + for _, v := range vars { + row.Variables = append(row.Variables, varRow{ + Var: v, + Defs: len(defLines[v]), + Uses: useCount[v], + DefLines: defLines[v], + }) + } + rows = append(rows, row) + } + + if s.isGCX(ctx, req) { + return s.gcxResponseWithBudget(req)(encodeAnalyze("def_use", gcxItems)) + } + if isCompact(req) { + var b strings.Builder + for _, r := range rows { + if r.Error != "" { + fmt.Fprintf(&b, "%s ERROR %s\n", r.ID, r.Error) + continue + } + for _, ch := range r.Chains { + fmt.Fprintf(&b, "%s:%d %s <- defs at %s\n", r.File, ch.StmtLine, ch.Var, joinInts(ch.DefLines)) + } + } + if b.Len() == 0 { + b.WriteString("no def->use chains\n") + } + return mcp.NewToolResultText(b.String()), nil + } + return s.respondJSONOrTOON(ctx, req, map[string]any{ + "symbols": rows, + "total": len(rows), + }) +} + +// symbolIDList parses the def_use id arguments: `ids` as a comma- +// separated string or JSON array, falling back to a single `id`. +func symbolIDList(args map[string]any) []string { + var out []string + add := func(s string) { + s = strings.TrimSpace(s) + if s != "" { + out = append(out, s) + } + } + switch v := args["ids"].(type) { + case string: + for _, part := range strings.Split(v, ",") { + add(part) + } + case []any: + for _, item := range v { + if s, ok := item.(string); ok { + add(s) + } + } + } + if len(out) == 0 { + if id, ok := args["id"].(string); ok { + add(id) + } + } + return out +} diff --git a/internal/mcp/tools_cfg_test.go b/internal/mcp/tools_cfg_test.go new file mode 100644 index 00000000..9d944d0a --- /dev/null +++ b/internal/mcp/tools_cfg_test.go @@ -0,0 +1,330 @@ +package mcp + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/config" + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/indexer" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" + "github.com/zzet/gortex/internal/query" +) + +// cfgTestServer indexes one Go file with a branchy function so the +// CFG tools can be exercised end-to-end against a real graph + +// on-disk source. +func cfgTestServer(t *testing.T) *Server { + t.Helper() + dir := t.TempDir() + src := `package main + +func Classify(score int) string { + label := "low" + if score > 90 { + label = "high" + } else if score > 50 { + label = "mid" + } + for i := 0; i < score; i++ { + if i == 3 { + break + } + } + return label +} + +var topLevel = 1 +` + require.NoError(t, os.WriteFile(filepath.Join(dir, "main.go"), []byte(src), 0o644)) + g := graph.New() + reg := parser.NewRegistry() + languages.RegisterAll(reg) + cfgConf := config.Default() + idx := indexer.New(g, reg, cfgConf.Index, zap.NewNop()) + _, err := idx.Index(dir) + require.NoError(t, err) + idx.ResolveAll() + eng := query.NewEngine(g) + return NewServer(eng, g, idx, nil, zap.NewNop(), nil) +} + +func cfgFindSymbol(t *testing.T, srv *Server, name string, kinds ...graph.NodeKind) string { + t.Helper() + kindOK := func(k graph.NodeKind) bool { + if len(kinds) == 0 { + return true + } + for _, want := range kinds { + if k == want { + return true + } + } + return false + } + for _, n := range srv.graph.AllNodes() { + if n.Name == name && kindOK(n.Kind) { + return n.ID + } + } + t.Fatalf("symbol %q not found", name) + return "" +} + +func TestHandleGetCFG_BlocksEdgesChains(t *testing.T) { + srv := cfgTestServer(t) + id := cfgFindSymbol(t, srv, "Classify", graph.KindFunction) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"id": id} + res, err := srv.handleGetCFG(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError, "tool errored: %v", res) + + var payload struct { + Name string `json:"name"` + Blocks []struct { + ID int `json:"id"` + Statements []struct { + Text string `json:"text"` + Defs []string `json:"defs"` + Uses []string `json:"uses"` + } `json:"statements"` + } `json:"blocks"` + Edges []struct { + From int `json:"from"` + To int `json:"to"` + Label string `json:"label"` + } `json:"edges"` + DefUse []struct { + Stmt int `json:"stmt"` + Var string `json:"var"` + Defs []int `json:"defs"` + } `json:"def_use"` + TotalBlocks int `json:"total_blocks"` + } + text := res.Content[0].(mcplib.TextContent).Text + require.NoError(t, json.Unmarshal([]byte(text), &payload)) + require.Equal(t, "Classify", payload.Name) + require.Greater(t, payload.TotalBlocks, 4, "branchy function needs several blocks") + + labels := map[string]bool{} + for _, e := range payload.Edges { + labels[e.Label] = true + } + for _, want := range []string{"true", "false", "loop_back", "break", "return"} { + require.True(t, labels[want], "missing %s edge in %s", want, text) + } + + // The return's use of label must chain to all three defs. + var labelChain []int + for _, ch := range payload.DefUse { + if ch.Var == "label" { + labelChain = ch.Defs + } + } + require.Len(t, labelChain, 3, "label has three reaching defs at the return: %s", text) +} + +func TestHandleGetCFG_Mermaid(t *testing.T) { + srv := cfgTestServer(t) + id := cfgFindSymbol(t, srv, "Classify", graph.KindFunction) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"id": id, "mermaid": true} + res, err := srv.handleGetCFG(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError) + text := res.Content[0].(mcplib.TextContent).Text + var payload struct { + Mermaid string `json:"mermaid"` + } + require.NoError(t, json.Unmarshal([]byte(text), &payload)) + require.Contains(t, payload.Mermaid, "flowchart TD") + require.Contains(t, payload.Mermaid, "-->|true|") +} + +func TestHandleGetCFG_GCXFormat(t *testing.T) { + srv := cfgTestServer(t) + id := cfgFindSymbol(t, srv, "Classify", graph.KindFunction) + + req := mcplib.CallToolRequest{} + req.Params.Name = "get_cfg" + req.Params.Arguments = map[string]any{"id": id, "format": "gcx", "mermaid": true} + res, err := srv.handleGetCFG(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError) + text := res.Content[0].(mcplib.TextContent).Text + for _, section := range []string{"get_cfg.summary", "get_cfg.stmts", "get_cfg.edges", "get_cfg.chains", "get_cfg.mermaid"} { + require.Contains(t, text, section) + } +} + +func TestHandleGetCFG_Errors(t *testing.T) { + srv := cfgTestServer(t) + + // Missing id. + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{} + res, err := srv.handleGetCFG(t.Context(), req) + require.NoError(t, err) + require.True(t, res.IsError) + + // Unknown symbol. + req.Params.Arguments = map[string]any{"id": "nope.go::Missing"} + res, err = srv.handleGetCFG(t.Context(), req) + require.NoError(t, err) + require.True(t, res.IsError) + + // Non-function symbol. + varID := cfgFindSymbol(t, srv, "topLevel") + req.Params.Arguments = map[string]any{"id": varID} + res, err = srv.handleGetCFG(t.Context(), req) + require.NoError(t, err) + require.True(t, res.IsError, "get_cfg on a variable must error") +} + +func TestHandleAnalyzeDefUse_ThroughDispatcher(t *testing.T) { + srv := cfgTestServer(t) + id := cfgFindSymbol(t, srv, "Classify", graph.KindFunction) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"kind": "def_use", "ids": id} + res, err := srv.handleAnalyze(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError, "analyze def_use errored: %v", res) + + var payload struct { + Total int `json:"total"` + Symbols []struct { + ID string `json:"id"` + Error string `json:"error"` + Chains []struct { + Var string `json:"var"` + StmtLine int `json:"stmt_line"` + DefLines []int `json:"def_lines"` + } `json:"chains"` + Variables []struct { + Var string `json:"var"` + Defs int `json:"defs"` + Uses int `json:"uses"` + } `json:"variables"` + } `json:"symbols"` + } + text := res.Content[0].(mcplib.TextContent).Text + require.NoError(t, json.Unmarshal([]byte(text), &payload)) + require.Equal(t, 1, payload.Total) + sym := payload.Symbols[0] + require.Empty(t, sym.Error) + require.NotEmpty(t, sym.Chains) + + // label's per-variable summary: 3 defs (init + two arms). + foundLabel := false + for _, v := range sym.Variables { + if v.Var == "label" { + foundLabel = true + require.Equal(t, 3, v.Defs, "label is defined three times") + require.GreaterOrEqual(t, v.Uses, 1) + } + } + require.True(t, foundLabel, "variables rollup must include label: %s", text) +} + +func TestHandleAnalyzeDefUse_DegradesPerSymbol(t *testing.T) { + srv := cfgTestServer(t) + good := cfgFindSymbol(t, srv, "Classify", graph.KindFunction) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"kind": "def_use", "ids": good + ",missing.go::Nope"} + res, err := srv.handleAnalyze(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError) + + var payload struct { + Symbols []struct { + ID string `json:"id"` + Error string `json:"error"` + } `json:"symbols"` + } + text := res.Content[0].(mcplib.TextContent).Text + require.NoError(t, json.Unmarshal([]byte(text), &payload)) + require.Len(t, payload.Symbols, 2) + require.Empty(t, payload.Symbols[0].Error) + require.NotEmpty(t, payload.Symbols[1].Error, "missing symbol must degrade to a per-symbol error") +} + +func TestHandleAnalyzeDefUse_GCX(t *testing.T) { + srv := cfgTestServer(t) + id := cfgFindSymbol(t, srv, "Classify", graph.KindFunction) + + req := mcplib.CallToolRequest{} + req.Params.Name = "analyze" + req.Params.Arguments = map[string]any{"kind": "def_use", "ids": id, "format": "gcx"} + res, err := srv.handleAnalyze(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError) + text := res.Content[0].(mcplib.TextContent).Text + require.Contains(t, text, "analyze.def_use") +} + +func TestHandleAnalyzeDefUse_RequiresIDs(t *testing.T) { + srv := cfgTestServer(t) + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"kind": "def_use"} + res, err := srv.handleAnalyze(t.Context(), req) + require.NoError(t, err) + require.True(t, res.IsError) +} + +// TestHandleFlowBetween_RefinementLive proves the reaching- +// definitions refinement runs on the real flow_between path: the +// indexed fixture's Mid function binds `v := s`, so the +// param-s → local-v hop must come back stamped +// confirmed_intraprocedural. +func TestHandleFlowBetween_RefinementLive(t *testing.T) { + srv := dataflowTestServer(t) + driverID := findFunctionID(t, srv, "Driver") + sinkID := findFunctionID(t, srv, "Sink") + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "source_id": driverID + "#param:input", + "sink_id": sinkID + "#param:payload", + "max_depth": float64(10), + } + res, err := srv.handleFlowBetween(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError, "tool errored: %v", res) + + var payload struct { + Paths []struct { + Edges []struct { + From string `json:"from"` + To string `json:"to"` + Kind string `json:"kind"` + Refined string `json:"refined"` + } `json:"edges"` + } `json:"paths"` + } + text := res.Content[0].(mcplib.TextContent).Text + require.NoError(t, json.Unmarshal([]byte(text), &payload)) + require.NotEmpty(t, payload.Paths) + + confirmed := 0 + for _, p := range payload.Paths { + for _, e := range p.Edges { + require.NotEqual(t, "pruned", e.Refined, "fixture has no stale flows: %s -> %s", e.From, e.To) + if e.Refined == "confirmed_intraprocedural" { + confirmed++ + } + } + } + require.Greater(t, confirmed, 0, "at least one same-function value_flow hop must be confirmed: %s", text) +} diff --git a/internal/mcp/tools_dataflow.go b/internal/mcp/tools_dataflow.go index 24417376..6c4f9bcb 100644 --- a/internal/mcp/tools_dataflow.go +++ b/internal/mcp/tools_dataflow.go @@ -84,7 +84,7 @@ func (s *Server) handleFlowBetween(ctx context.Context, req mcp.CallToolRequest) maxPaths := req.GetInt("max_paths", dataflow.DefaultMaxPaths) minTier := req.GetString("min_tier", "") - engine := dataflow.New(s.graph) + engine := dataflow.New(s.graph).WithRefiner(s.dataflowRefiner(ctx)) paths := engine.FlowBetweenWithTier(source, sink, maxDepth, maxPaths, minTier) if s.isGCX(ctx, req) { @@ -154,7 +154,7 @@ func (s *Server) handleTaintPaths(ctx context.Context, req mcp.CallToolRequest) return mcp.NewToolResultError("sink_pattern matched no clauses"), nil } - engine := dataflow.New(s.graph) + engine := dataflow.New(s.graph).WithRefiner(s.dataflowRefiner(ctx)) findings := engine.TaintPathsWithTier(src, sink, maxDepth, limit, minTier) if s.isGCX(ctx, req) { @@ -182,6 +182,29 @@ func (s *Server) handleTaintPaths(ctx context.Context, req mcp.CallToolRequest) return s.respondJSONOrTOON(ctx, req, result) } +// dataflowRefiner builds the per-call CFG-backed refiner that +// confirms or prunes same-function value_flow hops on +// flow_between / taint_paths paths. The source resolver reads +// through the session's overlay so unsaved buffers refine +// consistently with every other tool. +func (s *Server) dataflowRefiner(ctx context.Context) *dataflow.Refiner { + resolve := func(fn *graph.Node) (dataflow.FuncSource, error) { + if fn.StartLine == 0 || fn.EndLine == 0 { + return dataflow.FuncSource{}, fmt.Errorf("symbol has no line range: %s", fn.ID) + } + absPath, err := s.resolveNodePath(fn) + if err != nil { + return dataflow.FuncSource{}, err + } + src, fromLine, _, err := s.readLinesForCtx(ctx, absPath, fn.StartLine, fn.EndLine, 0) + if err != nil { + return dataflow.FuncSource{}, err + } + return dataflow.FuncSource{Src: []byte(src), StartLine: fromLine}, nil + } + return dataflow.NewRefiner(s.graph, resolve, 0) +} + // describeNode returns a JSON-shaped summary of a graph node for // taint findings. func describeNode(n *graph.Node) map[string]any { @@ -225,7 +248,7 @@ func encodeFlowBetween(source, sink string, paths []dataflow.Path) ([]byte, erro } pathEnc := wire.NewEncoder(&buf, wire.Header{ Tool: "flow_between.paths", - Fields: []string{"length", "confidence", "worst_tier", "ids", "kinds", "origins", "tiers"}, + Fields: []string{"length", "confidence", "worst_tier", "ids", "kinds", "origins", "tiers", "refined"}, Meta: map[string]string{ "count": fmt.Sprintf("%d", len(paths)), }, @@ -235,7 +258,8 @@ func encodeFlowBetween(source, sink string, paths []dataflow.Path) ([]byte, erro kinds := joinEdgeKinds(p.Edges) origins := joinEdgeOrigins(p.Edges) tiers := joinEdgeTiers(p.Edges) - if err := pathEnc.WriteRow(p.Length(), p.Confidence, worstTierOnPath(p.Edges), ids, kinds, origins, tiers); err != nil { + refined := joinEdgeRefined(p.Edges) + if err := pathEnc.WriteRow(p.Length(), p.Confidence, worstTierOnPath(p.Edges), ids, kinds, origins, tiers, refined); err != nil { return nil, err } } @@ -370,7 +394,7 @@ func encodeTaintPaths(srcPattern, sinkPattern string, findings []dataflow.TaintF Fields: []string{ "source_id", "source_name", "sink_id", "sink_name", "best_length", "best_confidence", "best_worst_tier", - "paths", "best_ids", "best_kinds", "best_origins", "best_tiers", + "paths", "best_ids", "best_kinds", "best_origins", "best_tiers", "best_refined", }, }) for _, f := range findings { @@ -391,6 +415,7 @@ func encodeTaintPaths(srcPattern, sinkPattern string, findings []dataflow.TaintF joinEdgeKinds(best.Edges), joinEdgeOrigins(best.Edges), joinEdgeTiers(best.Edges), + joinEdgeRefined(best.Edges), } if err := findEnc.WriteRow(row...); err != nil { return nil, err @@ -455,6 +480,26 @@ func joinEdgeOrigins(edges []dataflow.EdgeStep) string { return b.String() } +// joinEdgeRefined flattens the per-step refinement markers; empty +// markers stay empty so positions align with the kinds/tiers fields. +func joinEdgeRefined(edges []dataflow.EdgeStep) string { + if len(edges) == 0 { + return "" + } + parts := make([]string, len(edges)) + any := false + for i, e := range edges { + parts[i] = e.Refined + if e.Refined != "" { + any = true + } + } + if !any { + return "" + } + return strings.Join(parts, ",") +} + func joinEdgeTiers(edges []dataflow.EdgeStep) string { if len(edges) == 0 { return "" diff --git a/internal/mcp/tools_enhancements.go b/internal/mcp/tools_enhancements.go index 8b07c480..f2d25875 100644 --- a/internal/mcp/tools_enhancements.go +++ b/internal/mcp/tools_enhancements.go @@ -732,7 +732,7 @@ func (s *Server) handlePrefetchContext(ctx context.Context, req mcp.CallToolRequ func (s *Server) handleAnalyze(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { kind, err := req.RequireString("kind") if err != nil { - return mcp.NewToolResultError("kind is required (one of: dead_code, hotspots, cycles, would_create_cycle, todos, blame, coverage, stale_code, ownership, coverage_gaps, stale_flags, releases, cgo_users, wasm_users, orphan_tables, unreferenced_tables, coverage_summary, channel_ops, goroutine_spawns, field_writers, race_writes, unclosed_channels, unsafe_patterns, health_score, annotation_users, config_readers, event_emitters, pubsub, string_emitters, error_surface, log_events, sql_rebuild, external_calls, synthesizers, resolution_outcomes, retrieval_log, routes, models, components, k8s_resources, images, kustomize, cross_repo, impact, named, tests_as_edges, connectivity_health, pagerank, louvain, wcc, scc, kcore)"), nil + return mcp.NewToolResultError("kind is required (one of: dead_code, hotspots, cycles, would_create_cycle, todos, blame, coverage, stale_code, ownership, coverage_gaps, stale_flags, releases, cgo_users, wasm_users, orphan_tables, unreferenced_tables, coverage_summary, channel_ops, def_use, goroutine_spawns, field_writers, race_writes, unclosed_channels, unsafe_patterns, health_score, annotation_users, config_readers, event_emitters, pubsub, string_emitters, error_surface, log_events, sql_rebuild, external_calls, synthesizers, resolution_outcomes, retrieval_log, routes, models, components, k8s_resources, images, kustomize, cross_repo, impact, named, tests_as_edges, connectivity_health, pagerank, louvain, wcc, scc, kcore)"), nil } switch kind { case "dead_code": @@ -771,6 +771,8 @@ func (s *Server) handleAnalyze(ctx context.Context, req mcp.CallToolRequest) (*m return s.handleAnalyzeCoverageSummary(ctx, req) case "channel_ops": return s.handleAnalyzeChannelOps(ctx, req) + case "def_use": + return s.handleAnalyzeDefUse(ctx, req) case "goroutine_spawns": return s.handleAnalyzeGoroutineSpawns(ctx, req) case "field_writers": @@ -872,7 +874,7 @@ func (s *Server) handleAnalyze(ctx context.Context, req mcp.CallToolRequest) (*m case "kcore": return s.handleAnalyzeKCore(ctx, req) default: - return mcp.NewToolResultError("unknown analyze kind: " + kind + " (expected: dead_code, hotspots, cycles, would_create_cycle, todos, blame, coverage, stale_code, ownership, coverage_gaps, stale_flags, releases, cgo_users, wasm_users, orphan_tables, unreferenced_tables, coverage_summary, channel_ops, goroutine_spawns, field_writers, race_writes, unclosed_channels, unsafe_patterns, sast, hygiene, review, health_score, annotation_users, config_readers, env_var_users, sql_call_sites, fixes_history, edge_audit, domain, event_emitters, pubsub, string_emitters, error_surface, log_events, sql_rebuild, external_calls, resolution_outcomes, retrieval_log, routes, models, components, k8s_resources, images, kustomize, cross_repo, dbt_models, impact, bottlenecks, named, tests_as_edges, connectivity_health, pagerank, louvain, wcc, scc, kcore)"), nil + return mcp.NewToolResultError("unknown analyze kind: " + kind + " (expected: dead_code, hotspots, cycles, would_create_cycle, todos, blame, coverage, stale_code, ownership, coverage_gaps, stale_flags, releases, cgo_users, wasm_users, orphan_tables, unreferenced_tables, coverage_summary, channel_ops, def_use, goroutine_spawns, field_writers, race_writes, unclosed_channels, unsafe_patterns, sast, hygiene, review, health_score, annotation_users, config_readers, env_var_users, sql_call_sites, fixes_history, edge_audit, domain, event_emitters, pubsub, string_emitters, error_surface, log_events, sql_rebuild, external_calls, resolution_outcomes, retrieval_log, routes, models, components, k8s_resources, images, kustomize, cross_repo, dbt_models, impact, bottlenecks, named, tests_as_edges, connectivity_health, pagerank, louvain, wcc, scc, kcore)"), nil } } From d14ebaf3fe64d0af0ce9046a2d645ea86412acd6 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Sat, 13 Jun 2026 08:06:34 +0200 Subject: [PATCH 3/5] feat(indexer): affected-by re-resolution on incremental sync When an incremental sync changes a file's symbol signatures (or removes symbols / changes their kind), the files that reference those symbols are re-resolved synchronously in the same pipeline; a body-only edit produces no delta and fans out to nothing. The delta is computed on a line-insensitive, graph-derived symbol shape so it is meaningful across languages and not defeated by line-embedded node IDs, and parse failures are not mistaken for symbol removal. Affected files come from a reverse reference-facts lookup (RefFactsReader.LoadRefFactsByTargets, backed by a new by-target index) unioned with a pre-evict in-edge snapshot, capped with truncation accounting. The no-delta path stays cheap. Also fixes the cold-index resolver shadow-swap and a stale ref-facts row left when a reference disappears. --- internal/config/config.go | 11 + internal/graph/store.go | 8 + internal/graph/store_sqlite/schema.go | 5 + internal/graph/store_sqlite/store_reffacts.go | 54 ++ .../graph/store_sqlite/store_reffacts_test.go | 72 +++ internal/indexer/affected_by.go | 487 ++++++++++++++++++ .../indexer/affected_by_crosslang_test.go | 197 +++++++ internal/indexer/affected_by_e2e_test.go | 263 ++++++++++ internal/indexer/indexer.go | 52 +- internal/indexer/ref_facts.go | 29 +- internal/resolver/resolver.go | 53 +- 11 files changed, 1215 insertions(+), 16 deletions(-) create mode 100644 internal/indexer/affected_by.go create mode 100644 internal/indexer/affected_by_crosslang_test.go create mode 100644 internal/indexer/affected_by_e2e_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 5765daa0..28c12fdf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -606,6 +606,17 @@ type IndexConfig struct { // ON. Set false (or GORTEX_INDEX_SCOPED_GLOBAL_PASSES=0) to restore the // whole-graph behaviour. ScopedGlobalPasses *bool `mapstructure:"scoped_global_passes" yaml:"scoped_global_passes,omitempty"` + // AffectedByReresolveMax caps how many referencing files a single + // save's affected-by re-resolution pass will re-resolve. When a + // file's symbol signatures change (or symbols are removed / change + // kind), the files that referenced those symbols are re-resolved so + // their edges and persisted reference facts track the new shape; a + // symbol with thousands of callers must not turn one save into a + // near-whole-graph resolve, so the affected set is truncated at this + // bound (truncation is logged with the dropped count). Zero (the + // default) uses the built-in cap of 200 files. Configured under + // `index.affected_by_reresolve_max` in .gortex.yaml. + AffectedByReresolveMax int `mapstructure:"affected_by_reresolve_max" yaml:"affected_by_reresolve_max,omitempty"` // Transforms are pluggable pre-ingestion content processors. Each // rewrites a matching file's bytes before the parser sees them — // expanding minified bundles, normalising SVG/TOON, converting a diff --git a/internal/graph/store.go b/internal/graph/store.go index 3f6d3d78..e8bd259e 100644 --- a/internal/graph/store.go +++ b/internal/graph/store.go @@ -1052,6 +1052,14 @@ type RefFactsWriter interface { // set of source files (all files when files is empty), as the audit/diff seed. type RefFactsReader interface { LoadRefFactsByFiles(repoPrefix string, files []string) ([]RefFact, error) + // LoadRefFactsByTargets is the reverse lookup: the persisted facts that + // resolve TO any of the given node IDs, grouped by source file path. It + // answers "which files reference these symbols" durably — live in-edges + // are dropped when their target file is re-indexed, but the sidecar row + // keyed by to_id survives, so incremental re-resolution can find the + // referencing files after the eviction. Empty input yields an empty, + // non-nil map. + LoadRefFactsByTargets(repoPrefix string, targetIDs []string) (map[string][]RefFact, error) } // ChurnEnrichment is one node's git-churn enrichment, moved out of diff --git a/internal/graph/store_sqlite/schema.go b/internal/graph/store_sqlite/schema.go index 9a816d34..8c7a794d 100644 --- a/internal/graph/store_sqlite/schema.go +++ b/internal/graph/store_sqlite/schema.go @@ -140,6 +140,11 @@ CREATE TABLE IF NOT EXISTS ref_facts ( PRIMARY KEY (repo_prefix, from_id, to_id, kind, line) ) WITHOUT ROWID; CREATE INDEX IF NOT EXISTS ref_facts_by_file ON ref_facts(repo_prefix, file_path); +-- ref_facts_by_target backs the reverse lookup ("which files hold a fact +-- resolving TO these symbols") that affected-by re-resolution runs when a +-- file's symbol signatures change. Without it that query is a full +-- ref_facts scan — the PK leads with from_id, not to_id. +CREATE INDEX IF NOT EXISTS ref_facts_by_target ON ref_facts(repo_prefix, to_id); CREATE TABLE IF NOT EXISTS vectors ( node_id TEXT PRIMARY KEY, diff --git a/internal/graph/store_sqlite/store_reffacts.go b/internal/graph/store_sqlite/store_reffacts.go index fd3fea5f..f024dd45 100644 --- a/internal/graph/store_sqlite/store_reffacts.go +++ b/internal/graph/store_sqlite/store_reffacts.go @@ -163,3 +163,57 @@ func (s *Store) LoadRefFactsByFiles(repoPrefix string, files []string) ([]graph. } return out, nil } + +// LoadRefFactsByTargets returns the persisted facts that resolve TO any of +// the given node IDs for one repo prefix, grouped by source file path — the +// reverse lookup incremental re-resolution uses to find the files that +// referenced a changed symbol after its live in-edges were evicted. Served by +// the ref_facts_by_target index, chunked under the host-parameter limit. +// Always non-nil; empty input is a no-op. +func (s *Store) LoadRefFactsByTargets(repoPrefix string, targetIDs []string) (map[string][]graph.RefFact, error) { + out := map[string][]graph.RefFact{} + if len(targetIDs) == 0 { + return out, nil + } + const cols = `from_id, to_id, kind, ref_name, line, origin, tier, candidates, file_path, lang` + for start := 0; start < len(targetIDs); start += refFactChunk { + end := start + refFactChunk + if end > len(targetIDs) { + end = len(targetIDs) + } + chunk := targetIDs[start:end] + args := make([]any, 0, len(chunk)+1) + args = append(args, repoPrefix) + stmt := make([]byte, 0, 96+len(chunk)*2) + stmt = append(stmt, "SELECT "+cols+" FROM ref_facts WHERE repo_prefix = ? AND to_id IN ("...) + for i, id := range chunk { + if i > 0 { + stmt = append(stmt, ',') + } + stmt = append(stmt, '?') + args = append(args, id) + } + stmt = append(stmt, ')') + rows, err := s.db.Query(string(stmt), args...) + if err != nil { + return nil, err + } + for rows.Next() { + var f graph.RefFact + var cand string + if err := rows.Scan(&f.FromID, &f.ToID, &f.Kind, &f.RefName, &f.Line, &f.Origin, &f.Tier, &cand, &f.FilePath, &f.Lang); err != nil { + _ = rows.Close() + return nil, err + } + f.RepoPrefix = repoPrefix + f.Candidates = decodeCandidates(cand) + out[f.FilePath] = append(out[f.FilePath], f) + } + err = rows.Err() + _ = rows.Close() + if err != nil { + return nil, err + } + } + return out, nil +} diff --git a/internal/graph/store_sqlite/store_reffacts_test.go b/internal/graph/store_sqlite/store_reffacts_test.go index 9b861cc6..e7d6c3ed 100644 --- a/internal/graph/store_sqlite/store_reffacts_test.go +++ b/internal/graph/store_sqlite/store_reffacts_test.go @@ -98,3 +98,75 @@ func TestRefFacts_EmptyNoop(t *testing.T) { require.NoError(t, err) require.Empty(t, got) } + +func TestRefFacts_LoadByTargets(t *testing.T) { + s := openRefFactStore(t) + require.NoError(t, s.BulkSetRefFacts("", []graph.RefFact{ + {FromID: "b.go::Caller", ToID: "a.go::F", Kind: "calls", RefName: "F", Line: 3, Origin: "ast_resolved", Tier: "ast", FilePath: "b.go", Lang: "go"}, + {FromID: "c.go::Other", ToID: "a.go::F", Kind: "references", RefName: "F", FilePath: "c.go"}, + {FromID: "c.go::Other", ToID: "a.go::G", Kind: "calls", RefName: "G", FilePath: "c.go"}, + {FromID: "d.go::X", ToID: "z.go::Z", Kind: "calls", RefName: "Z", FilePath: "d.go"}, + })) + + byFile, err := s.LoadRefFactsByTargets("", []string{"a.go::F", "a.go::G"}) + require.NoError(t, err) + require.Len(t, byFile, 2, "facts must be grouped by source file: %v", byFile) + require.Len(t, byFile["b.go"], 1) + require.Equal(t, "a.go::F", byFile["b.go"][0].ToID) + require.Equal(t, "F", byFile["b.go"][0].RefName) + require.Equal(t, "ast_resolved", byFile["b.go"][0].Origin) + require.Len(t, byFile["c.go"], 2, "both of c.go's facts target the queried symbols") + require.NotContains(t, byFile, "d.go", "a fact targeting an unqueried symbol must not match") +} + +func TestRefFacts_LoadByTargets_EmptyAndMissing(t *testing.T) { + s := openRefFactStore(t) + require.NoError(t, s.BulkSetRefFacts("", []graph.RefFact{ + {FromID: "a.go::A", ToID: "b.go::B", Kind: "calls", FilePath: "a.go"}, + })) + + // Empty input: empty, non-nil map. + empty, err := s.LoadRefFactsByTargets("", nil) + require.NoError(t, err) + require.NotNil(t, empty) + require.Empty(t, empty) + + // A target nothing references: no rows, no error. + miss, err := s.LoadRefFactsByTargets("", []string{"nope::Missing"}) + require.NoError(t, err) + require.Empty(t, miss) +} + +func TestRefFacts_LoadByTargets_RepoScoping(t *testing.T) { + s := openRefFactStore(t) + require.NoError(t, s.BulkSetRefFacts("repoA", []graph.RefFact{{FromID: "fa::A", ToID: "shared::T", Kind: "calls", FilePath: "fa.go"}})) + require.NoError(t, s.BulkSetRefFacts("repoB", []graph.RefFact{{FromID: "fb::B", ToID: "shared::T", Kind: "calls", FilePath: "fb.go"}})) + + a, err := s.LoadRefFactsByTargets("repoA", []string{"shared::T"}) + require.NoError(t, err) + require.Len(t, a, 1) + require.Len(t, a["fa.go"], 1) + require.Equal(t, "repoA", a["fa.go"][0].RepoPrefix, "loaded facts must carry the queried repo prefix") + require.NotContains(t, a, "fb.go", "another repo's facts must not leak into the result") +} + +func TestRefFacts_LoadByTargets_Chunking(t *testing.T) { + s := openRefFactStore(t) + const n = 500 // > refFactChunk (80) + facts := make([]graph.RefFact, n) + targets := make([]string, n) + for i := range facts { + facts[i] = graph.RefFact{ + FromID: fmt.Sprintf("src%d.go::f", i), + ToID: fmt.Sprintf("dst.go::t%d", i), + Kind: "calls", + FilePath: fmt.Sprintf("src%d.go", i), + } + targets[i] = fmt.Sprintf("dst.go::t%d", i) + } + require.NoError(t, s.BulkSetRefFacts("", facts)) + + byFile, err := s.LoadRefFactsByTargets("", targets) + require.NoError(t, err) + require.Len(t, byFile, n, "every chunked target must come back grouped under its source file") +} diff --git a/internal/indexer/affected_by.go b/internal/indexer/affected_by.go new file mode 100644 index 00000000..85e00a57 --- /dev/null +++ b/internal/indexer/affected_by.go @@ -0,0 +1,487 @@ +package indexer + +import ( + "sort" + "strconv" + "strings" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/resolver" +) + +// Affected-by re-resolution. When a save changes a file's symbol +// SIGNATURES — or removes symbols, or changes their kind — the files +// that referenced those symbols hold edges and persisted reference +// facts derived against the old shape. This pass re-resolves exactly +// those files, synchronously inside the incremental pipeline, bounded +// by a configurable cap: no goroutines, no whole-graph resolves. A +// body-only edit produces no signature delta and fans out to nothing — +// the delta gate is the point. +// +// The delta is computed on a LINE-INSENSITIVE identity (kind + name, +// the name already carrying any container qualifier), not the raw node +// ID: several languages embed a definition's start line in the node ID +// (TS/JS object members `name@`, and the `_L` disambiguator +// C++/Java/C#/Kotlin/Dart/Scala/PHP append to an overloaded or +// same-named member). A body-only edit ABOVE such a symbol shifts its +// line and rewrites its ID; keying the delta on the raw ID would read +// that as a remove + add and fan out on every line-shifting edit, +// defeating the gate. The stable key collapses the old and new IDs to +// the same slot so only a genuine shape change counts. +// +// The comparable shape is derived from more than Meta["signature"]: +// only Go and C stamp a parameter-bearing signature string there, so a +// signature-only delta is blind to a real parameter/return change in +// every other language. symbolShapeFor folds in the language-agnostic +// structure the extractors emit around a definition — its parameter +// nodes (kind, position, type) reached through EdgeParamOf, its return +// types through EdgeReturns, and the C++ parameter-shape Meta keys — so +// the delta fires on a Java/Python/TS/… parameter change too. +// +// The referencing files come from two sources, unioned: +// +// - the persisted ref-facts sidecar's reverse lookup (to_id → source +// file), which survives graph eviction and daemon restarts; and +// - a live in-edge snapshot taken BEFORE the changed file is evicted +// (Graph.EvictFile drops edges where the evicted node is EITHER +// endpoint, so in-edges from unchanged files are gone afterwards). +// +// The snapshot is the only source on the in-memory backend (whose live +// edges ARE the facts); on a durable backend it also covers references +// the sidecar can't see, e.g. edges currently parked on an unresolved +// stub and cross-repo sources persisted under a sibling repo prefix. + +// defaultAffectedByMax bounds the re-resolve fan-out when the config +// carries no explicit cap. See IndexConfig.AffectedByReresolveMax. +const defaultAffectedByMax = 200 + +// affectedByMaxFiles returns the effective fan-out cap. +func (idx *Indexer) affectedByMaxFiles() int { + if n := idx.config.AffectedByReresolveMax; n > 0 { + return n + } + return defaultAffectedByMax +} + +// symbolShape is the per-symbol contract the delta compares under the +// stable (line-insensitive) key: kind plus a derived shape string that +// changes when, and only when, a referrer-visible aspect of the symbol +// changes (parameters, return types, signature). Body edits change none +// of these. +type symbolShape struct { + kind graph.NodeKind + shape string +} + +// stableSymbolKey is the line-insensitive identity the delta is keyed +// on: kind plus name. The name already carries any container qualifier +// the extractor minted (e.g. `Owner.member` for a JS object member), +// while the start line lives only in the raw node ID — so this key is +// stable across a body-only edit that shifts the definition's line and +// rewrites its `name@` / `..._L` ID. +func stableSymbolKey(n *graph.Node) string { + return string(n.Kind) + "\x00" + n.Name +} + +// symbolShapeFor derives the comparable shape string for a referenceable +// symbol: the stamped signature (Go/C carry a parameter-bearing one), the +// C++ parameter-shape Meta keys, and — language-agnostically — the +// parameter and return structure the function-shape extractors emit as +// graph nodes/edges around the definition. Folding the structure in is +// what lets the delta detect a parameter or return-type change in a +// language whose Meta["signature"] is absent or only a name-bearing +// constant (Java, C#, Kotlin, Python, Rust, TS/JS, …). +func symbolShapeFor(g graph.Store, n *graph.Node) string { + var b strings.Builder + if sig, _ := n.Meta["signature"].(string); sig != "" { + b.WriteString(sig) + } + // C++ stamps its parameter shape under dedicated Meta keys rather + // than a "signature" string; fold those in so an overload's argument + // change registers. + if v, ok := n.Meta["cpp_param_types"].(string); ok && v != "" { + b.WriteString("|cppt:") + b.WriteString(v) + } + if v, ok := n.Meta["cpp_param_shapes"].(string); ok && v != "" { + b.WriteString("|cpps:") + b.WriteString(v) + } + if v, ok := n.Meta["cpp_req_params"]; ok { + b.WriteString("|cppr:") + b.WriteString(metaToString(v)) + } + if _, ok := n.Meta["cpp_variadic"]; ok { + b.WriteString("|cppv") + } + // Language-agnostic parameter shape: the function-shape extractors + // emit one KindParam node per parameter, linked to the owner by an + // inbound EdgeParamOf, carrying position + type Meta. Sort by + // position so the shape is order-stable regardless of edge insertion + // order, and include the type so a same-arity type change still + // registers. + type paramShape struct { + pos int + typ string + variadic bool + } + var params []paramShape + for _, e := range g.GetInEdges(n.ID) { + if e == nil || e.Kind != graph.EdgeParamOf { + continue + } + p := g.GetNode(e.From) + if p == nil || p.Kind != graph.KindParam { + continue + } + ps := paramShape{} + if v, ok := p.Meta["position"]; ok { + ps.pos = metaToInt(v) + } + if t, ok := p.Meta["type"].(string); ok { + ps.typ = t + } + if _, ok := p.Meta["variadic"]; ok { + ps.variadic = true + } + params = append(params, ps) + } + if len(params) > 0 { + sort.Slice(params, func(i, j int) bool { + if params[i].pos != params[j].pos { + return params[i].pos < params[j].pos + } + return params[i].typ < params[j].typ + }) + b.WriteString("|p:") + for _, p := range params { + b.WriteString(strconv.Itoa(p.pos)) + b.WriteByte(':') + b.WriteString(p.typ) + if p.variadic { + b.WriteByte('*') + } + b.WriteByte(';') + } + } + // Return shape: one EdgeReturns per declared return type, owner → + // type. Collect the target names (a return-type change re-points the + // edge) ordered by the position Meta the extractors stamp. + type retShape struct { + pos int + target string + } + var rets []retShape + for _, e := range g.GetOutEdges(n.ID) { + if e == nil || e.Kind != graph.EdgeReturns { + continue + } + // Reduce the return target to its bare type name so the shape is + // resolution-insensitive: the snapshot reads the pre-resolve edge + // (still an `unresolved::T` stub) while the delta reads it after the + // changed file's reverse resolve may have rebound it to a concrete + // `pkg/x.go::T` node. Both must hash to the same `T`, or a body-only + // edit whose return type happened to (re)bind would fan out. + rs := retShape{target: bareTypeName(e.To), pos: metaToInt(e.Meta["position"])} + rets = append(rets, rs) + } + if len(rets) > 0 { + sort.Slice(rets, func(i, j int) bool { + if rets[i].pos != rets[j].pos { + return rets[i].pos < rets[j].pos + } + return rets[i].target < rets[j].target + }) + b.WriteString("|r:") + for _, r := range rets { + b.WriteString(strconv.Itoa(r.pos)) + b.WriteByte(':') + b.WriteString(r.target) + b.WriteByte(';') + } + } + return b.String() +} + +// bareTypeName reduces a type-reference edge target to a bare, line- and +// resolution-insensitive name. It strips an `unresolved::` stub prefix, +// then keeps only the trailing component after the last `::` (the +// file/repo scope) and the last `.` (an owner qualifier) — so a stub +// `unresolved::T`, a resolved `pkg/x.go::T`, and a resolved member +// `pkg/x.go::Owner.T` all reduce to `T`. +func bareTypeName(target string) string { + if target == "" { + return "" + } + if n := graph.UnresolvedName(target); n != "" { + target = n + } + if i := strings.LastIndex(target, "::"); i >= 0 { + target = target[i+2:] + } + if i := strings.LastIndex(target, "."); i >= 0 { + target = target[i+1:] + } + return target +} + +// metaToString renders an int/int64/string Meta value as a string for +// shape composition. The function-shape extractors stamp counts as int. +func metaToString(v any) string { + switch t := v.(type) { + case string: + return t + case int: + return strconv.Itoa(t) + case int64: + return strconv.FormatInt(t, 10) + default: + return "" + } +} + +// metaToInt reads an int-ish Meta value, tolerating the int / int64 / +// float64 (JSON round-trip) forms a persisted node can carry. +func metaToInt(v any) int { + switch t := v.(type) { + case int: + return t + case int64: + return int(t) + case float64: + return int(t) + default: + return 0 + } +} + +// affectedBySnapshot captures, before a changed file's nodes are +// evicted, the two things eviction destroys: the file's referenceable +// symbol shapes (for the signature delta) and the source node IDs of +// live reference edges into those symbols (the reverse-lookup fallback). +// +// refSources holds the FROM node IDs of the in-edges, not their file +// paths: resolving each source's file is deferred to affectedFilesFor, +// which runs only when a delta exists. A body-only edit therefore never +// pays a single GetNode for a referrer — its delta is empty and the +// fallback is never consulted. +// +// idsByKey records the concrete (old) target node IDs that hashed into +// each stable key. The persisted ref-facts sidecar indexes facts by the +// exact target ID — which still embeds the pre-edit line for a +// line-suffixed language — so the durable reverse lookup must be issued +// against these old IDs, not the changed file's fresh post-edit nodes. +type affectedBySnapshot struct { + symbols map[string]symbolShape // stable key → pre-edit shape + refSources map[string]map[string]struct{} // stable key → referencing source node IDs + idsByKey map[string][]string // stable key → pre-edit target node IDs +} + +// snapshotAffectedBy builds the pre-evict snapshot for graphPath. Must +// run before restubIncomingRefs / EvictFile — afterwards the in-edges +// point at unresolved stubs (or are gone) and the old signatures are +// unreadable. Returns nil when the file defines no referenceable +// symbols, which callers treat as "no pass". +// +// The in-edge fan-in is read in one batched GetInEdgesByNodeIDs call so +// the durable backend pays one query for the whole file rather than one +// per symbol; the per-edge source-node lookup is deferred to the delta +// path entirely. +func (idx *Indexer) snapshotAffectedBy(graphPath string) *affectedBySnapshot { + nodes := idx.graph.GetFileNodes(graphPath) + if len(nodes) == 0 { + return nil + } + snap := &affectedBySnapshot{ + symbols: make(map[string]symbolShape), + refSources: make(map[string]map[string]struct{}), + idsByKey: make(map[string][]string), + } + refIDs := make([]string, 0, len(nodes)) + keyByID := make(map[string]string, len(nodes)) + ownIDs := make(map[string]struct{}, len(nodes)) + for _, n := range nodes { + ownIDs[n.ID] = struct{}{} + } + for _, n := range nodes { + if n == nil || n.Name == "" || !graph.IsReferenceableSymbol(n.Kind) { + continue + } + key := stableSymbolKey(n) + // A file can carry two same-name same-kind definitions (an + // overload); they share a stable key. Their shapes compose so the + // key's shape changes if either overload's shape does — coarse but + // correct for a fan-out gate (over-firing is bounded and safe, + // under-firing leaves a stale edge). + cur := snap.symbols[key] + cur.kind = n.Kind + cur.shape += symbolShapeFor(idx.graph, n) + "\n" + snap.symbols[key] = cur + snap.idsByKey[key] = append(snap.idsByKey[key], n.ID) + refIDs = append(refIDs, n.ID) + keyByID[n.ID] = key + } + if len(snap.symbols) == 0 { + return nil + } + inEdges := idx.graph.GetInEdgesByNodeIDs(refIDs) + for id, edges := range inEdges { + key := keyByID[id] + if key == "" { + continue + } + for _, e := range edges { + if e == nil || !graph.IsResolvableRefEdge(e.Kind) { + continue + } + if _, ours := ownIDs[e.From]; ours { + continue // intra-file reference: re-resolved with the file itself + } + set := snap.refSources[key] + if set == nil { + set = make(map[string]struct{}) + snap.refSources[key] = set + } + set[e.From] = struct{}{} + } + } + return snap +} + +// affectedByDelta returns the stable keys of snapshot symbols whose +// contract changed against the freshly indexed node set: shape changed, +// kind changed, or the symbol is gone. A rename is a remove of the old +// key plus an add of the new one, so it lands here through the removed +// side. Newly added symbols are never part of the delta — nothing can +// hold a stale reference to a symbol that did not exist. +func affectedByDelta(g graph.Store, snap *affectedBySnapshot, newNodes []*graph.Node) []string { + current := make(map[string]symbolShape, len(newNodes)) + for _, n := range newNodes { + if n == nil || n.Name == "" || !graph.IsReferenceableSymbol(n.Kind) { + continue + } + key := stableSymbolKey(n) + cur := current[key] + cur.kind = n.Kind + cur.shape += symbolShapeFor(g, n) + "\n" + current[key] = cur + } + var delta []string + for key, old := range snap.symbols { + now, exists := current[key] + if !exists || now.kind != old.kind || now.shape != old.shape { + delta = append(delta, key) + } + } + sort.Strings(delta) + return delta +} + +// affectedFilesFor unions the persisted reverse lookup with the +// pre-evict in-edge snapshot for the delta symbols, excluding the +// changed file itself. The snapshot stores referrer NODE IDs; their +// files are resolved here in one batched GetNodesByIDs — work that runs +// only for a real delta, never on the body-only path. Sorted for +// deterministic truncation and tests. +func (idx *Indexer) affectedFilesFor(changedPath string, deltaKeys []string, snap *affectedBySnapshot) []string { + fileSet := make(map[string]struct{}) + // Durable reverse lookup: the sidecar answers by concrete target ID, + // which still embeds the pre-edit line for a line-suffixed language. + // Issue it against the OLD target IDs the snapshot recorded for each + // delta key — the changed file's fresh nodes carry new IDs that the + // seeded sidecar has never seen. + if r, ok := idx.graph.(graph.RefFactsReader); ok { + var targetIDs []string + for _, key := range deltaKeys { + targetIDs = append(targetIDs, snap.idsByKey[key]...) + } + if len(targetIDs) > 0 { + byFile, err := r.LoadRefFactsByTargets(idx.repoPrefix, targetIDs) + if err != nil { + idx.logger.Debug("affected-by: ref-facts reverse lookup failed", zap.Error(err)) + } + for file := range byFile { + if file != "" && file != changedPath { + fileSet[file] = struct{}{} + } + } + } + } + // In-edge snapshot fallback: the snapshot stored referrer NODE IDs; + // resolve them to their files in one batch — work that runs only for a + // real delta, never on the body-only path. + srcIDSet := make(map[string]struct{}) + for _, key := range deltaKeys { + for from := range snap.refSources[key] { + srcIDSet[from] = struct{}{} + } + } + if len(srcIDSet) > 0 { + ids := make([]string, 0, len(srcIDSet)) + for id := range srcIDSet { + ids = append(ids, id) + } + byID := idx.graph.GetNodesByIDs(ids) + for _, n := range byID { + if n == nil || n.FilePath == "" || n.FilePath == changedPath { + continue + } + fileSet[n.FilePath] = struct{}{} + } + } + files := make([]string, 0, len(fileSet)) + for f := range fileSet { + files = append(files, f) + } + sort.Strings(files) + return files +} + +// reresolveAffectedBy is the pass entry point, called from the +// incremental per-file index path after the changed file itself has +// been re-indexed, re-resolved, and had its own facts re-persisted. +// snap is the pre-evict snapshot (nil ⇒ no-op); newNodes are the +// changed file's freshly added nodes (already repo-prefixed). For each +// affected file it re-runs the per-save resolve pair (forward + reverse, +// batched so the resolver's pass indexes are built once), re-materialises +// external-call placeholders, and re-persists that file's reference +// facts — so an edge that degraded to an unresolved stub also drops its +// stale persisted fact, and a rebound edge records its new resolution. +func (idx *Indexer) reresolveAffectedBy(changedPath string, snap *affectedBySnapshot, newNodes []*graph.Node) { + if snap == nil { + return + } + delta := affectedByDelta(idx.graph, snap, newNodes) + if len(delta) == 0 { + return // body-only edit: no contract change, no fan-out + } + files := idx.affectedFilesFor(changedPath, delta, snap) + if len(files) == 0 { + return + } + if maxFiles := idx.affectedByMaxFiles(); len(files) > maxFiles { + idx.logger.Debug("affected-by: re-resolve set truncated", + zap.String("file", changedPath), + zap.Int("affected", len(files)), + zap.Int("cap", maxFiles), + zap.Int("dropped", len(files)-maxFiles)) + idx.affectedByDropped.Add(int64(len(files) - maxFiles)) + files = files[:maxFiles] + } + idx.affectedByPasses.Add(1) + idx.affectedByFilesResolved.Add(int64(len(files))) + + idx.resolver.ResolveFilesAndIncoming(files) + resolver.SynthesizeExternalCallsForFiles(idx.graph, idx.externalCallSynthesisEnabled(), files) + idx.persistRefFactsForFiles(files) +} + +// AffectedByCounts reports the affected-by pass activity for this +// indexer: passes run, referencing files re-resolved, and files dropped +// by the fan-out cap. Diagnostic/test hook — the body-only-edit gate is +// observable as an unchanged pass count. +func (idx *Indexer) AffectedByCounts() (passes, files, dropped int64) { + return idx.affectedByPasses.Load(), idx.affectedByFilesResolved.Load(), idx.affectedByDropped.Load() +} diff --git a/internal/indexer/affected_by_crosslang_test.go b/internal/indexer/affected_by_crosslang_test.go new file mode 100644 index 00000000..a1d69e60 --- /dev/null +++ b/internal/indexer/affected_by_crosslang_test.go @@ -0,0 +1,197 @@ +package indexer + +import ( + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zzet/gortex/internal/graph" +) + +// TestAffectedBy_TypeScriptParamChange_ReresolvesCaller is the +// cross-language proof for the signature delta. TypeScript stamps only a +// name-bearing Meta["signature"] ("function F()") that does not move when +// a parameter is added — so a signature-only delta is blind to it and the +// caller is never re-resolved. The delta now folds in the parameter +// structure the extractor emits (KindParam nodes via EdgeParamOf), so +// adding a parameter to F re-resolves the file that calls it. +func TestAffectedBy_TypeScriptParamChange_ReresolvesCaller(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.ts") + bPath := filepath.Join(dir, "b.ts") + writeFile(t, aPath, "export function F(x: number): number {\n return x\n}\n") + writeFile(t, bPath, "import { F } from './a'\n\nexport function Caller(): number {\n return F(1)\n}\n") + + idx, _ := newSQLiteIndexer(t) + _, err := idx.Index(dir) + require.NoError(t, err) + idx.ResolveAll() + + g := idx.Graph() + callerID := fnNodeID(t, g, "b.ts", "Caller") + require.Equal(t, "a.ts::F", callTargetFrom(t, g, callerID), + "baseline: Caller must bind to the local F") + + // Add a parameter — the name-only "function F()" signature is + // unchanged; only the parameter structure differs. + bumpMtime(t, aPath, "export function F(x: number, y: number): number {\n return x + y\n}\n") + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + passes, files, _ := idx.AffectedByCounts() + assert.Equal(t, int64(1), passes, + "a TypeScript parameter change must trigger the pass even though Meta[signature] is name-only") + assert.Equal(t, int64(1), files) + assert.Equal(t, "a.ts::F", callTargetFrom(t, g, callerID), + "the caller must be re-resolved to the fresh F") +} + +// TestAffectedBy_TypeScriptBodyOnly_NoFanout is the gate counterpart for +// the structural shape: a TypeScript body edit that leaves the parameter +// list untouched must not fan out — proving the structural shape is body- +// insensitive, not just a coarse "anything changed" trigger. +func TestAffectedBy_TypeScriptBodyOnly_NoFanout(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.ts") + bPath := filepath.Join(dir, "b.ts") + writeFile(t, aPath, "export function F(x: number): number {\n return x\n}\n") + writeFile(t, bPath, "import { F } from './a'\n\nexport function Caller(): number {\n return F(1)\n}\n") + + idx, _ := newSQLiteIndexer(t) + _, err := idx.Index(dir) + require.NoError(t, err) + idx.ResolveAll() + + bumpMtime(t, aPath, "export function F(x: number): number {\n return x + 1 + 2 + 3\n}\n") + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + passes, _, _ := idx.AffectedByCounts() + assert.Equal(t, int64(0), passes, + "a TypeScript body-only edit must not fan out") +} + +// TestAffectedByDelta_LineSuffixedID_NoSpuriousDelta proves the line- +// insensitive identity. Several languages embed a definition's start line +// in its node ID (TS/JS `name@`, the `_L` overload +// disambiguator). A body-only edit ABOVE such a symbol shifts its line and +// rewrites its ID; keying the delta on the raw ID would read that as a +// remove + add and fire on every line-shifting edit. The delta is keyed on +// kind + name instead, so an ID that differs only in its line-suffix — +// with an identical shape — yields no delta. +func TestAffectedByDelta_LineSuffixedID_NoSpuriousDelta(t *testing.T) { + g := graph.New() + // Snapshot: an object member whose ID embeds its line. + old := &graph.Node{ + ID: "a.ts::api.health@4", Kind: graph.KindFunction, Name: "api.health", + FilePath: "a.ts", Meta: map[string]any{"signature": "api.health()"}, + } + g.AddNode(old) + key := stableSymbolKey(old) + snap := &affectedBySnapshot{ + symbols: map[string]symbolShape{key: {kind: old.Kind, shape: symbolShapeFor(g, old) + "\n"}}, + refSources: map[string]map[string]struct{}{}, + idsByKey: map[string][]string{key: {old.ID}}, + } + + // After a body-only edit above it, the same member is re-minted at a + // new line — new ID, identical name/kind/shape. + shifted := &graph.Node{ + ID: "a.ts::api.health@9", Kind: graph.KindFunction, Name: "api.health", + FilePath: "a.ts", Meta: map[string]any{"signature": "api.health()"}, + } + g2 := graph.New() + g2.AddNode(shifted) + + delta := affectedByDelta(g2, snap, []*graph.Node{shifted}) + assert.Empty(t, delta, + "a line-only ID shift with an unchanged shape must not be a delta") + + // Control: a genuine kind change at the same name IS a delta. + retyped := &graph.Node{ + ID: "a.ts::api.health@9", Kind: graph.KindVariable, Name: "api.health", + FilePath: "a.ts", Meta: map[string]any{"signature": "api.health()"}, + } + g3 := graph.New() + g3.AddNode(retyped) + delta = affectedByDelta(g3, snap, []*graph.Node{retyped}) + assert.Equal(t, []string{key}, delta, + "a kind change under the same name must still be a delta") +} + +// TestAffectedBy_MinifiedSkip_PreservesFactsNoFanout is the parse-failure +// guard. When a file that previously parsed cleanly is overwritten with a +// minified bundle, the incremental path yields a synthetic skip node with +// zero symbols. That must NOT be read as "every symbol was removed": the +// affected-by pass must not fan out, and the caller's persisted reference +// fact must survive — a transient un-parseable save must not durably +// delete reverse-lookup rows that won't be rebuilt until a clean reparse. +func TestAffectedBy_MinifiedSkip_PreservesFactsNoFanout(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.js") + bPath := filepath.Join(dir, "b.js") + writeFile(t, aPath, "export function F(x) {\n return x\n}\n") + writeFile(t, bPath, "import { F } from './a.js'\n\nexport function Caller() {\n return F(1)\n}\n") + + idx, store := newSQLiteIndexer(t) + _, err := idx.Index(dir) + require.NoError(t, err) + idx.ResolveAll() + + g := idx.Graph() + fID := fnNodeID(t, g, "a.js", "F") + before, err := store.LoadRefFactsByTargets("", []string{fID}) + require.NoError(t, err) + require.Contains(t, before, "b.js", + "baseline: the sidecar must record b.js referencing F") + + // Overwrite a.js with a minified bundle: one long line over the + // minified-detection floor. extractFile classifies it as a build + // artifact and returns a synthetic skip node — zero symbols. + blob := "var x=" + strings.Repeat("1+", 1500) + "1;\n" + bumpMtime(t, aPath, blob) + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + passes, files, _ := idx.AffectedByCounts() + assert.Equal(t, int64(0), passes, + "a minified-skip (zero symbols) must not fan out as if every symbol was removed") + assert.Equal(t, int64(0), files) + + after, err := store.LoadRefFactsByTargets("", []string{fID}) + require.NoError(t, err) + assert.Contains(t, after, "b.js", + "a transient parse-skip must not delete the caller's persisted reference fact") +} + +// TestAffectedBy_NoDeltaPath_DefersReferrerLookup proves the cheap no- +// delta path: the pre-evict snapshot must record referrer NODE IDs rather +// than eagerly resolving each to its file, so a body-only edit pays no +// per-referrer node lookup. We assert the snapshot's structure directly — +// refSources holds source node IDs, and resolving them to files is the +// delta path's job, reached only when a delta exists. +func TestAffectedBy_NoDeltaPath_DefersReferrerLookup(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + bPath := filepath.Join(dir, "b.go") + writeFile(t, aPath, "package p\n\nfunc F(x int) int { return x }\n") + writeFile(t, bPath, "package p\n\nfunc Caller() int { return F(1) }\n") + + g := graph.New() + idx := newTestIndexer(g) + _, err := idx.Index(dir) + require.NoError(t, err) + + snap := idx.snapshotAffectedBy("a.go") + require.NotNil(t, snap) + key := "function\x00F" + require.Contains(t, snap.refSources, key, "F's referrers must be captured") + // The recorded source is the caller's NODE ID, not its file path — + // the file is resolved lazily only on the delta path. + callerID := fnNodeID(t, g, "b.go", "Caller") + _, ok := snap.refSources[key][callerID] + assert.True(t, ok, "refSources must hold the referrer node ID, deferring its file lookup") +} diff --git a/internal/indexer/affected_by_e2e_test.go b/internal/indexer/affected_by_e2e_test.go new file mode 100644 index 00000000..374ef668 --- /dev/null +++ b/internal/indexer/affected_by_e2e_test.go @@ -0,0 +1,263 @@ +package indexer + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/config" + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/graph/store_sqlite" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" +) + +// newSQLiteIndexer builds an indexer over a sqlite-backed graph, the +// configuration the affected-by pass's persisted reverse lookup runs on. +func newSQLiteIndexer(t *testing.T) (*Indexer, *store_sqlite.Store) { + t.Helper() + store, err := store_sqlite.Open(filepath.Join(t.TempDir(), "g.sqlite")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + reg := parser.NewRegistry() + languages.RegisterAll(reg) + cfg := config.Default().Index + cfg.Workers = 1 + return New(store, reg, cfg, zap.NewNop()), store +} + +// TestAffectedBy_SignatureChange_ReresolvesCaller is the headline case +// on the in-memory backend (whose reverse lookup is the pre-evict +// in-edge snapshot): b.go calls F defined in a.go; changing F's +// SIGNATURE re-resolves b.go — its call edge lands on the fresh F node +// and exactly one bounded affected-by pass ran over exactly one file. +func TestAffectedBy_SignatureChange_ReresolvesCaller(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + bPath := filepath.Join(dir, "b.go") + writeFile(t, aPath, "package p\n\nfunc F(x int) int { return x }\n") + writeFile(t, bPath, "package p\n\nfunc Caller() int { return F(1) }\n") + + g := graph.New() + idx := newTestIndexer(g) + _, err := idx.Index(dir) + require.NoError(t, err) + + fID := fnNodeID(t, g, "a.go", "F") + callerID := fnNodeID(t, g, "b.go", "Caller") + require.Equal(t, fID, callTargetFrom(t, g, callerID), + "baseline: Caller's call must resolve to F") + + bumpMtime(t, aPath, "package p\n\nfunc F(x int, y int) int { return x + y }\n") + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + newFID := fnNodeID(t, g, "a.go", "F") + assert.Equal(t, newFID, callTargetFrom(t, g, callerID), + "after F's signature changed, Caller's edge must be re-resolved to the fresh F") + + passes, files, dropped := idx.AffectedByCounts() + assert.Equal(t, int64(1), passes, "a signature change must trigger exactly one affected-by pass") + assert.Equal(t, int64(1), files, "the pass must re-resolve exactly the one referencing file") + assert.Equal(t, int64(0), dropped) +} + +// TestAffectedBy_BodyOnlyEdit_NoFanout proves the gate: an edit that +// changes only a function BODY produces no signature delta and must not +// fan out — the whole point of delta detection is that the common case +// (a body edit) costs nothing beyond the changed file itself. Driven +// through whole-root IncrementalReindex to cover that sync route too. +func TestAffectedBy_BodyOnlyEdit_NoFanout(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + writeFile(t, aPath, "package p\n\nfunc F(x int) int { return x }\n") + writeFile(t, filepath.Join(dir, "b.go"), "package p\n\nfunc Caller() int { return F(1) }\n") + + g := graph.New() + idx := newTestIndexer(g) + _, err := idx.Index(dir) + require.NoError(t, err) + callerID := fnNodeID(t, g, "b.go", "Caller") + + bumpMtime(t, aPath, "package p\n\nfunc F(x int) int { return x + 1 }\n") + res, err := idx.IncrementalReindex(dir) + require.NoError(t, err) + require.Equal(t, 1, res.StaleFileCount) + + passes, files, _ := idx.AffectedByCounts() + assert.Equal(t, int64(0), passes, "a body-only edit must not trigger an affected-by pass") + assert.Equal(t, int64(0), files) + + // The caller's edge still survives the definition re-index via the + // existing restub + reverse-resolve pair. + assert.Equal(t, fnNodeID(t, g, "a.go", "F"), callTargetFrom(t, g, callerID), + "the caller edge must still be bound after a body-only re-index") +} + +// TestAffectedBy_PerSaveIndexFile_ReresolvesCaller drives the same +// signature change through IndexFile directly — the watcher's per-save +// patch path — proving every sync route shares the hook. +func TestAffectedBy_PerSaveIndexFile_ReresolvesCaller(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + writeFile(t, aPath, "package p\n\nfunc F() int { return 0 }\n") + writeFile(t, filepath.Join(dir, "b.go"), "package p\n\nfunc Caller() int { return F() }\n") + + g := graph.New() + idx := newTestIndexer(g) + _, err := idx.Index(dir) + require.NoError(t, err) + callerID := fnNodeID(t, g, "b.go", "Caller") + + writeFile(t, aPath, "package p\n\nfunc F(n int) int { return n }\n") + require.NoError(t, idx.IndexFile(aPath)) + + assert.Equal(t, fnNodeID(t, g, "a.go", "F"), callTargetFrom(t, g, callerID)) + passes, files, _ := idx.AffectedByCounts() + assert.Equal(t, int64(1), passes, "the per-save IndexFile route must run the pass") + assert.Equal(t, int64(1), files) +} + +// TestAffectedBy_RemovedSymbol_SQLite removes a called symbol from its +// definition file on a sqlite-backed graph: the caller's edge must +// degrade to the resolver's normal unresolved stub (no dangling old-ID +// edge), and the caller's now-stale persisted reference fact must be +// dropped by the pass's re-persist. +func TestAffectedBy_RemovedSymbol_SQLite(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + writeFile(t, aPath, "package p\n\nfunc F() {}\n\nfunc Keep() {}\n") + writeFile(t, filepath.Join(dir, "b.go"), "package p\n\nfunc Caller() { F() }\n") + + idx, store := newSQLiteIndexer(t) + _, err := idx.Index(dir) + require.NoError(t, err) + idx.ResolveAll() // seeds the persisted ref-facts sidecar + + g := idx.Graph() + fID := fnNodeID(t, g, "a.go", "F") + callerID := fnNodeID(t, g, "b.go", "Caller") + require.Equal(t, fID, callTargetFrom(t, g, callerID)) + + // The reverse lookup must already know b.go references F. + byFile, err := store.LoadRefFactsByTargets("", []string{fID}) + require.NoError(t, err) + require.Contains(t, byFile, "b.go", + "the seeded sidecar must answer the by-target reverse lookup") + + bumpMtime(t, aPath, "package p\n\nfunc Keep() {}\n") + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + target := callTargetFrom(t, g, callerID) + assert.True(t, graph.IsUnresolvedTarget(target), + "the caller's edge must degrade to an unresolved stub, got %q", target) + assert.Equal(t, "F", graph.UnresolvedName(target)) + + facts, err := store.LoadRefFactsByFiles("", []string{"b.go"}) + require.NoError(t, err) + for _, f := range facts { + assert.NotEqual(t, fID, f.ToID, + "the stale fact pointing at the removed symbol must be re-persisted away") + } + passes, _, _ := idx.AffectedByCounts() + assert.Equal(t, int64(1), passes) +} + +// TestAffectedBy_SidecarDiscovery_SQLite proves the persisted reverse +// lookup is a real discovery source, not just a mirror of the live +// graph: the caller's in-edge is parked on an unresolved stub BEFORE +// the change (so the pre-evict snapshot sees no in-edge for F), and the +// affected file is still found — via LoadRefFactsByTargets — and +// re-resolved. +func TestAffectedBy_SidecarDiscovery_SQLite(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + writeFile(t, aPath, "package p\n\nfunc F() int { return 0 }\n") + writeFile(t, filepath.Join(dir, "b.go"), "package p\n\nfunc Caller() int { return F() }\n") + + idx, _ := newSQLiteIndexer(t) + _, err := idx.Index(dir) + require.NoError(t, err) + idx.ResolveAll() + + g := idx.Graph() + callerID := fnNodeID(t, g, "b.go", "Caller") + require.Equal(t, fnNodeID(t, g, "a.go", "F"), callTargetFrom(t, g, callerID)) + + // Park the caller's live edge on a stub — the state a prior evict + // leaves behind — so only the sidecar can name b.go as affected. + idx.restubIncomingRefs("a.go") + require.True(t, graph.IsUnresolvedTarget(callTargetFrom(t, g, callerID)), + "precondition: the live in-edge must be parked on a stub") + + bumpMtime(t, aPath, "package p\n\nfunc F(n int) int { return n }\n") + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + passes, files, _ := idx.AffectedByCounts() + assert.Equal(t, int64(1), passes, + "the pass must run even with no live in-edges — discovery comes from the sidecar") + assert.Equal(t, int64(1), files) + assert.Equal(t, fnNodeID(t, g, "a.go", "F"), callTargetFrom(t, g, callerID), + "the sidecar-discovered caller must be re-resolved to the fresh F") +} + +// TestAffectedBy_CapBoundsFanout configures a fan-out cap of 1 with +// three referencing files: the pass must re-resolve exactly one file +// and account for the two it dropped — the cap is loud, not silent. +func TestAffectedBy_CapBoundsFanout(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + writeFile(t, aPath, "package p\n\nfunc F(x int) int { return x }\n") + writeFile(t, filepath.Join(dir, "b.go"), "package p\n\nfunc CallerB() int { return F(1) }\n") + writeFile(t, filepath.Join(dir, "c.go"), "package p\n\nfunc CallerC() int { return F(2) }\n") + writeFile(t, filepath.Join(dir, "d.go"), "package p\n\nfunc CallerD() int { return F(3) }\n") + + g := graph.New() + reg := parser.NewRegistry() + reg.Register(languages.NewGoExtractor()) + cfg := config.Default().Index + cfg.Workers = 1 + cfg.AffectedByReresolveMax = 1 + idx := New(g, reg, cfg, zap.NewNop()) + _, err := idx.Index(dir) + require.NoError(t, err) + + bumpMtime(t, aPath, "package p\n\nfunc F(x int, y int) int { return x + y }\n") + _, err = idx.IncrementalReindexPaths(dir, []string{aPath}) + require.NoError(t, err) + + passes, files, dropped := idx.AffectedByCounts() + assert.Equal(t, int64(1), passes) + assert.Equal(t, int64(1), files, "the cap must bound the re-resolve set") + assert.Equal(t, int64(2), dropped, "the truncated files must be accounted, not silently lost") +} + +// TestAffectedBy_DeferredBatchPath_NoFanout proves the batch guard: a +// caller that defers global passes (warmup, ReconcileAll) runs one +// resolve at the end of the batch, so the per-file affected-by pass +// must stay off even for a genuine signature change. +func TestAffectedBy_DeferredBatchPath_NoFanout(t *testing.T) { + dir := t.TempDir() + aPath := filepath.Join(dir, "a.go") + writeFile(t, aPath, "package p\n\nfunc F(x int) int { return x }\n") + writeFile(t, filepath.Join(dir, "b.go"), "package p\n\nfunc Caller() int { return F(1) }\n") + + g := graph.New() + idx := newTestIndexer(g) + _, err := idx.Index(dir) + require.NoError(t, err) + + idx.SetDeferGlobalPasses(true) + bumpMtime(t, aPath, "package p\n\nfunc F(x int, y int) int { return x + y }\n") + require.NoError(t, idx.IndexFile(aPath)) + + passes, _, _ := idx.AffectedByCounts() + assert.Equal(t, int64(0), passes, + "deferred-batch indexing must not fan out per file — the batch caller resolves once at the end") +} diff --git a/internal/indexer/indexer.go b/internal/indexer/indexer.go index f99f274b..a53cbe10 100644 --- a/internal/indexer/indexer.go +++ b/internal/indexer/indexer.go @@ -306,6 +306,16 @@ type Indexer struct { // after which indexFile drives EvictFuncs/UpdateFuncs. While un-built // indexFile falls back to the whole-graph pass. cloneIndex *incrementalCloneIndex + + // affectedByPasses / affectedByFilesResolved / affectedByDropped + // count the affected-by re-resolution activity (see affected_by.go): + // passes that found a signature delta and ran, referencing files + // re-resolved by them, and files dropped by the fan-out cap. + // Exposed via AffectedByCounts so tests and diagnostics can observe + // that a body-only edit triggered no fan-out. + affectedByPasses atomic.Int64 + affectedByFilesResolved atomic.Int64 + affectedByDropped atomic.Int64 } // contractCacheEntry is a cached contract-extraction result for one file. @@ -2050,6 +2060,14 @@ func (idx *Indexer) IndexCtx(ctx context.Context, root string) (result *IndexRes } reporter.Report("persisting bulk graph", 1, 1) idx.graph = diskTarget + // Mirror of the SetGraph(inMemShadow) above: the resolver + // must follow the graph pointer back to the disk store, or + // every post-index per-file resolve (the watcher save path, + // incremental reindex) reads the drained — now empty — + // shadow and silently resolves nothing. + if idx.resolver != nil { + idx.resolver.SetGraph(diskTarget) + } }() } else if diskTarget == nil && idx.graph.NodeCount() == 0 && idx.graph.EdgeCount() == 0 { if _, isBulk := idx.graph.(graph.BulkLoader); isBulk && len(files) > shadowMaxFileCount() { @@ -2733,6 +2751,24 @@ func (idx *Indexer) indexFile(filePath string, resolve bool) error { return err } + // Affected-by snapshot: the symbol shapes and reverse-reference + // sources the post-resolve signature-delta pass compares against, + // captured BEFORE eviction — EvictFile drops in-edges from + // unchanged files and replaces this file's nodes, so neither is + // recoverable afterwards. Skipped on the no-resolve and + // deferred-batch paths, whose callers run a full resolve (and + // persistAllRefFacts) once at the end of the batch. + // + // Also skipped for a quarantined / timed-out / minified-skipped file: + // its synthetic result carries zero symbols, so the delta would read + // every prior symbol as removed and fan out to re-resolve the whole + // reverse graph on a transient parse failure. A failure that yields no + // symbols is not the same as a symbol genuinely deleted from source. + var abSnap *affectedBySnapshot + if resolve && !idx.deferGlobalPasses && !skipped { + abSnap = idx.snapshotAffectedBy(graphPath) + } + // We hold a usable result: evict the old state now, then add the // new — the window where the file has no nodes is just this gap. evictExisting() @@ -2807,8 +2843,20 @@ func (idx *Indexer) indexFile(filePath string, resolve bool) error { } // Persist this file's resolved-reference facts to the durable sidecar // (delete-then-set so removed references don't linger). No-op on the - // in-memory backend. - idx.persistRefFactsForFiles([]string{graphPath}) + // in-memory backend. Skipped for a quarantined / timed-out / + // minified file: its synthetic result yields no facts, so a + // delete-then-set would durably drop the file's real facts on a + // transient parse failure and leave them gone until a clean + // reparse — abSnap is nil here too, so the affected-by pass that + // would also fan out is already a no-op. + if !skipped { + idx.persistRefFactsForFiles([]string{graphPath}) + // Affected-by re-resolution: if this save changed a symbol's + // signature or kind, or removed a symbol, re-resolve the files + // that referenced it — bounded, synchronous, and gated on the + // signature delta so a body-only edit fans out to nothing. + idx.reresolveAffectedBy(graphPath, abSnap, result.Nodes) + } } // Update mtime for this file. relPath is already the canonical diff --git a/internal/indexer/ref_facts.go b/internal/indexer/ref_facts.go index da3d2353..a4310f4e 100644 --- a/internal/indexer/ref_facts.go +++ b/internal/indexer/ref_facts.go @@ -64,8 +64,11 @@ func collectRefFacts(g graph.Store, nodes []*graph.Node) []graph.RefFact { // persistRefFactsForFiles re-derives and persists the resolved-reference facts // for the given graph file paths (delete-then-set per file so stale facts from -// removed references don't linger). No-op when the backend has no durable layer -// or the file list is empty. +// removed references don't linger). Every requested file is deleted even when +// it yields no fresh facts — a file whose last resolvable reference just +// degraded to an unresolved stub must drop its stale rows, not keep them +// because there is nothing new to write. No-op when the backend has no +// durable layer or the file list is empty. func (idx *Indexer) persistRefFactsForFiles(graphPaths []string) { w, ok := idx.refFactsWriter() if !ok || len(graphPaths) == 0 { @@ -73,14 +76,28 @@ func (idx *Indexer) persistRefFactsForFiles(graphPaths []string) { } byRepo := map[string][]graph.RefFact{} filesByRepo := map[string]map[string]struct{}{} + addFile := func(repo, file string) { + if filesByRepo[repo] == nil { + filesByRepo[repo] = map[string]struct{}{} + } + filesByRepo[repo][file] = struct{}{} + } for _, p := range graphPaths { nodes := idx.graph.GetFileNodes(p) + // Register the file for deletion under its nodes' repo prefix + // (falling back to the indexer's own) regardless of whether any + // fresh facts come out of it below. + repo := idx.repoPrefix + for _, n := range nodes { + if n != nil && n.RepoPrefix != "" { + repo = n.RepoPrefix + break + } + } + addFile(repo, p) for _, f := range collectRefFacts(idx.graph, nodes) { byRepo[f.RepoPrefix] = append(byRepo[f.RepoPrefix], f) - if filesByRepo[f.RepoPrefix] == nil { - filesByRepo[f.RepoPrefix] = map[string]struct{}{} - } - filesByRepo[f.RepoPrefix][f.FilePath] = struct{}{} + addFile(f.RepoPrefix, f.FilePath) } } for repo, fileSet := range filesByRepo { diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index f356405e..8ee2c693 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -813,9 +813,43 @@ func (r *Resolver) ResolveFileAndIncoming(filePath string) *ResolveStats { return stats } +// ResolveFilesAndIncoming runs the forward and reverse passes for a +// batch of files under one lock, one build of the per-pass indexes, and +// one run of the attribution passes. The affected-by re-resolution path +// uses this: calling ResolveFileAndIncoming per file would rebuild the +// four pass indexes and re-run the whole-graph attribution sweeps once +// per file, turning a bounded fan-out into N whole-graph passes. +func (r *Resolver) ResolveFilesAndIncoming(filePaths []string) *ResolveStats { + stats := &ResolveStats{} + if len(filePaths) == 0 { + return stats + } + r.mu.Lock() + defer r.mu.Unlock() + + clear := r.buildPassIndexes() + defer clear() + + for _, p := range filePaths { + r.resolveFileEdgesLocked(p, stats) + r.resolveIncomingLocked(p, stats) + } + r.runFileAttributionPassesLocked() + return stats +} + // resolveFileLocked is the forward-pass core. Caller holds r.mu and // has built the per-pass indexes. func (r *Resolver) resolveFileLocked(filePath string, stats *ResolveStats) { + r.resolveFileEdgesLocked(filePath, stats) + r.runFileAttributionPassesLocked() +} + +// resolveFileEdgesLocked walks one file's outgoing unresolved edges and +// binds them, without the attribution tail — batch callers run the +// attribution passes once after the whole batch instead of once per +// file. Caller holds r.mu and has built the per-pass indexes. +func (r *Resolver) resolveFileEdgesLocked(filePath string, stats *ResolveStats) { // Get all nodes in the file, then check their outgoing edges. // Single-threaded path — collect mutations into a batch and flush // in one ReindexEdges call after the file's edges are walked, so a @@ -864,14 +898,17 @@ func (r *Resolver) resolveFileLocked(filePath string, stats *ResolveStats) { } } - // Re-run the attribution passes that ResolveAll runs. ResolveFile - // handles incremental updates — a re-parse of one file emits - // fresh `unresolved::` edges that haven't been seen by these - // passes yet, so without re-running them the incremental graph - // diverges from a cold re-index (caught by - // TestIncrementalReindex_ConvergesToFullIndex). Each pass is - // idempotent on already-rewritten edges (the `unresolved::` - // prefix check makes a second sweep a no-op). +} + +// runFileAttributionPassesLocked re-runs the attribution passes that +// ResolveAll runs. The per-file resolve paths handle incremental +// updates — a re-parse of one file emits fresh `unresolved::` +// edges that haven't been seen by these passes yet, so without +// re-running them the incremental graph diverges from a cold re-index +// (caught by TestIncrementalReindex_ConvergesToFullIndex). Each pass is +// idempotent on already-rewritten edges (the `unresolved::` prefix +// check makes a second sweep a no-op). Caller holds r.mu. +func (r *Resolver) runFileAttributionPassesLocked() { r.rebindGoMethodReceivers() r.bindBareNameScopeRefs() r.bindGenericParamRefs() From 87329ec29e50c2e8bea9d0b2900b6a6c96413d89 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Sat, 13 Jun 2026 08:08:07 +0200 Subject: [PATCH 4/5] feat(contracts): persisted cross-service contract-bridge subgraph IDL-aware contract extraction (.proto package/service/method canonical identities with brace-bounded service blocks; a Thrift extractor) plus a matcher join that pairs gRPC/Thrift providers and consumers across casing and package-qualification, gated on real gRPC evidence so plain Register*Server function definitions do not mint phantom providers. Each matched provider-consumer group materializes one persisted contract-bridge node, scoped to the (workspace, project) match boundary so unrelated services never merge, with deterministic node fields and reconcile serialization. The contracts tool gained action=bridge: a reciprocal-rank-fusion group query and a cross-service impact mode. --- internal/agents/claudecode/content.go | 2 +- internal/contracts/contract.go | 13 +- internal/contracts/grpc.go | 240 ++++++- internal/contracts/grpc_test.go | 134 ++++ internal/contracts/matcher.go | 264 ++++++++ internal/contracts/matcher_canonical_test.go | 293 ++++++++ internal/contracts/proto_idl_test.go | 182 +++++ internal/contracts/thrift.go | 222 ++++++ internal/contracts/thrift_test.go | 157 +++++ internal/graph/edge.go | 14 +- internal/graph/node.go | 19 +- internal/graph/storetest/storetest.go | 122 ++++ internal/indexer/contract_bridge.go | 267 ++++++++ internal/indexer/contract_bridge_test.go | 454 +++++++++++++ internal/indexer/indexer.go | 2 + internal/indexer/multi.go | 26 + internal/mcp/resources.go | 12 + internal/mcp/tools_contract_bridge.go | 676 +++++++++++++++++++ internal/mcp/tools_contract_bridge_test.go | 358 ++++++++++ internal/mcp/tools_enhancements.go | 13 +- 20 files changed, 3446 insertions(+), 24 deletions(-) create mode 100644 internal/contracts/matcher_canonical_test.go create mode 100644 internal/contracts/proto_idl_test.go create mode 100644 internal/contracts/thrift.go create mode 100644 internal/contracts/thrift_test.go create mode 100644 internal/indexer/contract_bridge.go create mode 100644 internal/indexer/contract_bridge_test.go create mode 100644 internal/mcp/tools_contract_bridge.go create mode 100644 internal/mcp/tools_contract_bridge_test.go diff --git a/internal/agents/claudecode/content.go b/internal/agents/claudecode/content.go index 75885f57..ee63905e 100644 --- a/internal/agents/claudecode/content.go +++ b/internal/agents/claudecode/content.go @@ -461,7 +461,7 @@ These wrap the discovery + impact + memory surfaces into ordered playbooks so po ### API Contracts | Tool | What it gives you | |------|-------------------| -| contracts | API contracts: action=list (default) lists detected contracts; action=check matches providers/consumers and reports orphans across repos. Scope either action with ` + "`repo`" + `, ` + "`project`" + `, or ` + "`ref`" + ` | +| contracts | API contracts: action=list (default) lists detected contracts; action=check matches providers/consumers and reports orphans across repos; action=bridge ranks matched provider↔consumer groups (RRF over text / path-repo / adjacency / degree signals; mode=impact for a symbol's cross-service blast radius). Scope any action with ` + "`repo`" + `, ` + "`project`" + `, or ` + "`ref`" + ` | ### Config Hygiene | Tool | What it gives you | diff --git a/internal/contracts/contract.go b/internal/contracts/contract.go index f8fc1777..f7b33adb 100644 --- a/internal/contracts/contract.go +++ b/internal/contracts/contract.go @@ -10,9 +10,16 @@ import ( type ContractType string const ( - ContractHTTP ContractType = "http" - ContractGRPC ContractType = "grpc" - ContractGraphQL ContractType = "graphql" + ContractHTTP ContractType = "http" + ContractGRPC ContractType = "grpc" + // ContractThrift covers Apache Thrift IDL services. Provider + // contracts come from `service { ... }` blocks in .thrift files; + // the consumer side is usually detected through the generated-stub + // patterns the gRPC extractor recognises (NewClient), so + // the matcher's canonical-name join treats grpc and thrift as one + // RPC family when pairing. + ContractThrift ContractType = "thrift" + ContractGraphQL ContractType = "graphql" ContractTopic ContractType = "topic" ContractWS ContractType = "ws" ContractEnv ContractType = "env" diff --git a/internal/contracts/grpc.go b/internal/contracts/grpc.go index 38611345..17d45593 100644 --- a/internal/contracts/grpc.go +++ b/internal/contracts/grpc.go @@ -17,6 +17,10 @@ var ( // Proto service definitions: service Foo { rpc Bar(...) returns (...) } protoServiceRe = regexp.MustCompile(`(?m)service\s+(\w+)\s*\{`) protoRPCRe = regexp.MustCompile(`(?m)rpc\s+(\w+)\s*\(`) + // Proto package declaration: `package billing.v1;` — the namespace + // half of the canonical gRPC method name + // `./`. + protoPackageRe = regexp.MustCompile(`(?m)^\s*package\s+([\w.]+)\s*;`) // Richer RPC pattern that captures the request / response message // types along with optional `stream` modifiers on either side: @@ -44,6 +48,11 @@ var ( // the service as a consumer contract with SymbolID on the // enclosing function, even when we can't resolve method calls. goGRPCNewClientRe = regexp.MustCompile(`(?:[\w.]+\.)?New(\w+)Client\s*\(`) + // Go server registration: pb.RegisterUserServiceServer(s, impl). + // The code-side provider anchor — a service-level provider + // contract joins the IDL definition to the implementing repo even + // when per-method handler binding can't resolve. + goGRPCRegisterServerRe = regexp.MustCompile(`(?:[\w.]+\.)?Register(\w+)Server\s*\(`) // Inline chained call: pb.NewServiceClient(conn).Method(...). The // constructor's argument list is balance-scanned at match time, so // the regex only needs to anchor the `NewClient(` head; @@ -73,13 +82,15 @@ func (e *GRPCExtractor) SupportedLanguages() []string { // Cheap substring markers that act as a pre-filter before the regex // scans. Every gRPC consumer pattern in extractConsumers hinges on a -// `Client(` construction (Go, TS) or `Stub(` (Python) — so if the -// file has neither, none of the 9 regexes can match. bytes.Contains -// is ~100× cheaper than a regex walk and short-circuits 99% of files -// in gRPC-free repositories. +// `Client(` construction (Go, TS) or `Stub(` (Python), and every Go +// server registration on `Server(` — so if the file has none, none of +// the regexes can match. bytes.Contains is ~100× cheaper than a regex +// walk and short-circuits 99% of files in gRPC-free repositories. var ( - grpcClientMarker = []byte("Client(") - grpcStubMarker = []byte("Stub(") + grpcClientMarker = []byte("Client(") + grpcStubMarker = []byte("Stub(") + grpcRegisterMarker = []byte("Register") + grpcServerMarker = []byte("Server(") ) func (e *GRPCExtractor) Extract(filePath string, src []byte, nodes []*graph.Node, edges []*graph.Edge) []Contract { @@ -89,23 +100,187 @@ func (e *GRPCExtractor) Extract(filePath string, src []byte, nodes []*graph.Node contracts = append(contracts, e.extractProtoProviders(filePath, src)...) return contracts } - if !bytes.Contains(src, grpcClientMarker) && !bytes.Contains(src, grpcStubMarker) { + hasClient := bytes.Contains(src, grpcClientMarker) || bytes.Contains(src, grpcStubMarker) + hasRegistration := strings.HasSuffix(filePath, ".go") && + bytes.Contains(src, grpcRegisterMarker) && bytes.Contains(src, grpcServerMarker) + if !hasClient && !hasRegistration { return nil } fileNodes := filterFileNodes(filePath, nodes) sort.Slice(fileNodes, func(i, j int) bool { return fileNodes[i].StartLine < fileNodes[j].StartLine }) - contracts = append(contracts, e.extractConsumers(filePath, src, fileNodes)...) + if hasClient { + contracts = append(contracts, e.extractConsumers(filePath, src, fileNodes)...) + } + if hasRegistration { + contracts = append(contracts, e.extractServerRegistrations(filePath, src, fileNodes)...) + } return contracts } +// extractServerRegistrations detects Go gRPC server registration +// sites — `pb.RegisterUserServiceServer(s, impl)` — and emits one +// service-level provider contract per registered service. Generated +// stubs name the registration after the service, so the contract ID +// `grpc::` anchors the implementing repo to the IDL +// definition: the matcher's canonical-name join pairs it with +// method-level consumers/providers of the same service. +// +// The `RegisterServer(` shape also matches plain function +// definitions (`func RegisterHTTPServer(mux *http.ServeMux)`) and +// helpers with no gRPC involvement. Those are not registration call +// sites, so we reject any match that is a `func` declaration head and +// require real gRPC evidence in the file before treating a match as a +// provider — without the gate, latent `NewClient` consumer +// detections gain a spurious exact-ID partner and a false bridge forms. +func (e *GRPCExtractor) extractServerRegistrations(filePath string, src []byte, fileNodes []*graph.Node) []Contract { + text := string(src) + if !fileHasGRPCEvidence(text) { + return nil + } + var out []Contract + lines := strings.Split(text, "\n") + seen := make(map[string]struct{}) + for _, m := range goGRPCRegisterServerRe.FindAllStringSubmatchIndex(text, -1) { + svc := text[m[2]:m[3]] + if svc == "" { + continue + } + // Skip function definitions — `func RegisterServer(...)` is a + // declaration, not a registration call against a *grpc.Server. + if precededByFuncKeyword(text, m[0]) { + continue + } + if _, dup := seen[svc]; dup { + continue + } + // A bare `RegisterServer(arg)` with no package selector is + // almost always a local helper call or an unqualified definition + // reference, not a generated-stub registration. Generated gRPC + // registration funcs live in the protobuf package and are always + // invoked through it (`pb.RegisterUsersServer(...)`). Accept the + // unqualified form only when the file independently shows gRPC + // involvement (a grpc import), so a same-package registration + // still records while a plain `RegisterHTTPServer(mux)` helper + // call in a grpc-free file does not. + if !matchHasPackageSelector(text, m[0], m[1]) && !fileImportsGRPC(text) { + continue + } + seen[svc] = struct{}{} + ln := lineNumber(lines, m[0]) + out = append(out, Contract{ + ID: fmt.Sprintf("grpc::%s", svc), + Type: ContractGRPC, + Role: RoleProvider, + SymbolID: findEnclosingSymbol(fileNodes, ln), + FilePath: filePath, + Line: ln, + Meta: map[string]any{ + "service": svc, + "lang": "go", + "registration": true, + }, + Confidence: 0.85, + }) + } + return out +} + +// fileHasGRPCEvidence reports whether a Go source file shows any sign of +// being a gRPC server-registration site rather than coincidentally +// containing a `RegisterServer(` token. Evidence is either a grpc +// import or a package-qualified registration call (`pb.Register…`). +// Files with neither cannot host a real registration, so the registration +// scan is skipped wholesale — preventing a latent `NewClient` consumer +// from gaining a spurious exact-ID provider partner. +func fileHasGRPCEvidence(text string) bool { + if fileImportsGRPC(text) { + return true + } + for _, m := range goGRPCRegisterServerRe.FindAllStringSubmatchIndex(text, -1) { + if precededByFuncKeyword(text, m[0]) { + continue + } + if matchHasPackageSelector(text, m[0], m[1]) { + return true + } + } + return false +} + +// fileImportsGRPC reports whether the source imports a gRPC package. +// The substring is specific enough that a false hit is implausible. +func fileImportsGRPC(text string) bool { + return strings.Contains(text, "google.golang.org/grpc") +} + +// precededByFuncKeyword reports whether the token starting at off is a +// function declaration head — i.e. immediately preceded by the `func ` +// keyword (allowing for whitespace). `func RegisterHTTPServer(...)` is a +// definition, not a registration call. +func precededByFuncKeyword(text string, off int) bool { + i := off + // Skip whitespace immediately before the identifier. + for i > 0 && (text[i-1] == ' ' || text[i-1] == '\t') { + i-- + } + const kw = "func" + if i < len(kw) { + return false + } + if text[i-len(kw):i] != kw { + return false + } + // Ensure "func" is a standalone keyword, not the tail of an + // identifier like "myfunc". + if i-len(kw) > 0 { + prev := text[i-len(kw)-1] + if prev == '_' || isAlphaNum(prev) { + return false + } + } + return true +} + +// matchHasPackageSelector reports whether the registration token whose +// full match spans [start,end) is reached through a package/receiver +// selector (`pb.RegisterUsersServer(`) rather than bare +// (`RegisterHTTPServer(`). The regex's optional `[\w.]+\.` selector +// prefix is part of the match, so a selector is present exactly when the +// matched text holds a `.` ahead of `Register`. +func matchHasPackageSelector(text string, start, end int) bool { + if start < 0 || end > len(text) || start >= end { + return false + } + span := text[start:end] + dot := strings.IndexByte(span, '.') + reg := strings.Index(span, "Register") + return dot >= 0 && reg >= 0 && dot < reg +} + +// isAlphaNum reports whether b is an ASCII letter or digit. +func isAlphaNum(b byte) bool { + return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') +} + func (e *GRPCExtractor) extractProtoProviders(filePath string, src []byte) []Contract { var contracts []Contract text := string(src) lines := strings.Split(text, "\n") + // The proto package declaration provides the namespace half of the + // canonical gRPC method name `./` — the + // identity a client uses on the wire. Contract IDs stay + // package-free (`grpc::::`) so cross-repo exact-ID + // pairing keeps working against generated-stub consumers that only + // know the bare service name; the canonical name rides in Meta. + protoPackage := "" + if m := protoPackageRe.FindStringSubmatch(text); m != nil { + protoPackage = m[1] + } + // Build a fast lookup from (methodName → shape) so we can attach // request/response types to each RPC contract below. We run the // shape regex across the whole file once; services don't overlap @@ -126,15 +301,28 @@ func (e *GRPCExtractor) extractProtoProviders(filePath string, src []byte) []Con } } - // Find service blocks and their RPC methods. + // Find service blocks and their RPC methods. Each block is + // brace-bounded so a file declaring multiple services doesn't + // attribute a later service's RPCs to an earlier one (the open-ended + // "scan to EOF" form double-counted every method after the first + // service header). serviceMatches := protoServiceRe.FindAllStringSubmatchIndex(text, -1) for _, sMatch := range serviceMatches { serviceName := text[sMatch[2]:sMatch[3]] - // Find RPCs within the remainder of this service block. serviceStart := sMatch[0] - rest := text[serviceStart:] - rpcMatches := protoRPCRe.FindAllStringSubmatch(rest, -1) - rpcLocs := protoRPCRe.FindAllStringIndex(rest, -1) + // sMatch[1] points just past the `{`; balance-scan to the + // closing brace so the RPC scan stays inside this service. + blockEnd := matchCloseBrace(text, sMatch[1]) + if blockEnd < 0 { + blockEnd = len(text) + } + block := text[serviceStart:blockEnd] + rpcMatches := protoRPCRe.FindAllStringSubmatch(block, -1) + rpcLocs := protoRPCRe.FindAllStringIndex(block, -1) + qualService := serviceName + if protoPackage != "" { + qualService = protoPackage + "." + serviceName + } for i, rpc := range rpcMatches { methodName := rpc[1] absOffset := serviceStart + rpcLocs[i][0] @@ -143,8 +331,12 @@ func (e *GRPCExtractor) extractProtoProviders(filePath string, src []byte) []Con meta := map[string]any{ "service": serviceName, "method": methodName, + "canonical": qualService + "/" + methodName, "schema_source": "none", } + if protoPackage != "" { + meta["package"] = protoPackage + } if s, ok := shapes[methodName]; ok { meta["request_type"] = s.requestType meta["response_type"] = s.responseType @@ -172,6 +364,28 @@ func (e *GRPCExtractor) extractProtoProviders(filePath string, src []byte) []Con return contracts } +// matchCloseBrace returns the byte offset of the `}` that closes the +// `{` whose position is just before openEnd (i.e. openEnd points one +// past the opening brace). Returns -1 when the braces are unbalanced. +func matchCloseBrace(text string, openEnd int) int { + if openEnd <= 0 || openEnd > len(text) { + return -1 + } + depth := 1 + for i := openEnd; i < len(text); i++ { + switch text[i] { + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + return i + } + } + } + return -1 +} + // extractConsumers detects gRPC client usage in a non-proto source file. // For Go it uses a two-pass scan so per-method RPC calls can emit // specific "grpc::Service::Method" contracts that match the provider ID diff --git a/internal/contracts/grpc_test.go b/internal/contracts/grpc_test.go index 94422dd7..40f57ce5 100644 --- a/internal/contracts/grpc_test.go +++ b/internal/contracts/grpc_test.go @@ -182,6 +182,140 @@ func main() { } } +// TestGRPCExtractor_RegisterServerDefinitionIsNotProvider guards the +// registration-site false positive: `RegisterServer(` also matches a +// plain function definition (`func RegisterHTTPServer(mux ...)`) and +// non-gRPC helper calls. Minting a provider from those activates a +// latent `NewClient` consumer into a false exact-ID match. None of +// the inputs below name google.golang.org/grpc, so they must produce no +// provider contract. +func TestGRPCExtractor_RegisterServerDefinitionIsNotProvider(t *testing.T) { + ext := &GRPCExtractor{} + + cases := map[string]string{ + "func definition": `package httpx + +import "net/http" + +// A plain registration helper — not a gRPC server registration. +func RegisterHTTPServer(mux *http.ServeMux) { + mux.Handle("/", nil) +} +`, + "bare helper call in grpc-free file": `package app + +func boot() { + RegisterMetricsServer(localRegistry) +} +`, + } + + for name, src := range cases { + contracts := ext.Extract("x.go", []byte(src), nil, nil) + for _, c := range contracts { + if c.Role == RoleProvider { + t.Errorf("%s: unexpected provider contract %+v — no gRPC evidence present", name, c) + } + } + } +} + +// TestGRPCExtractor_RegisterServerNoFalseBridgePartner is the end-to-end +// guard for the activation chain: a file with a `NewClient` consumer +// plus a `func RegisterServer` definition (the same service name) +// must NOT produce a same-ID provider, so the consumer stays an orphan +// and no false EdgeMatches / bridge can form. +func TestGRPCExtractor_RegisterServerNoFalseBridgePartner(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(`package app + +func wire() { + c := NewHTTPClient() + _ = c +} + +func RegisterHTTPServer(mux interface{}) {} +`) + contracts := ext.Extract("app.go", []byte(src), nil, nil) + var providers, consumers []Contract + for _, c := range contracts { + switch c.Role { + case RoleProvider: + providers = append(providers, c) + case RoleConsumer: + consumers = append(consumers, c) + } + } + if len(providers) != 0 { + t.Fatalf("expected no provider contracts from a func definition; got %+v", providers) + } + // The consumer side legitimately records grpc::HTTP, but with no + // provider it can never pair into a bridge. + for _, c := range consumers { + if c.ID == "grpc::HTTP" { + // Acceptable: an orphan consumer. Just assert no provider + // shares its ID (already checked above). + return + } + } +} + +// TestGRPCExtractor_RegisterServerPackageQualifiedIsProvider pins the +// positive path: a generated-stub registration call +// (`pb.RegisterUsersServer(grpcServer, impl)`) is package-qualified, so +// it remains a service-level provider even without a grpc import. +func TestGRPCExtractor_RegisterServerPackageQualifiedIsProvider(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(`package main + +import "context" + +type UsersServer struct{} + +func register(grpcServer interface{}) { + pb.RegisterUsersServer(grpcServer, &UsersServer{}) +} +`) + contracts := ext.Extract("server.go", []byte(src), nil, nil) + var found bool + for _, c := range contracts { + if c.Role == RoleProvider && c.ID == "grpc::Users" { + found = true + if reg, _ := c.Meta["registration"].(bool); !reg { + t.Errorf("expected registration=true on the provider contract, got %+v", c.Meta) + } + } + } + if !found { + t.Fatalf("expected provider contract grpc::Users from pb.RegisterUsersServer; got %+v", contracts) + } +} + +// TestGRPCExtractor_RegisterServerGRPCImportAllowsBareCall: a same- +// package registration with no selector still records when the file +// independently imports grpc. +func TestGRPCExtractor_RegisterServerGRPCImportAllowsBareCall(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(`package pb + +import "google.golang.org/grpc" + +func wire(s *grpc.Server, impl interface{}) { + RegisterUsersServer(s, impl) +} +`) + contracts := ext.Extract("wire.go", []byte(src), nil, nil) + var found bool + for _, c := range contracts { + if c.Role == RoleProvider && c.ID == "grpc::Users" { + found = true + } + } + if !found { + t.Fatalf("expected provider contract grpc::Users for a bare registration in a grpc-importing file; got %+v", contracts) + } +} + func assertContract(t *testing.T, c Contract, id string, ctype ContractType, role Role) { t.Helper() if c.ID != id { diff --git a/internal/contracts/matcher.go b/internal/contracts/matcher.go index a40ef72a..5579ec90 100644 --- a/internal/contracts/matcher.go +++ b/internal/contracts/matcher.go @@ -1,5 +1,7 @@ package contracts +import "strings" + // CrossLink represents a matched provider-consumer pair, possibly across repos. type CrossLink struct { ContractID string `json:"contract_id"` @@ -36,6 +38,15 @@ type MatchResult struct { // have different RepoPrefixes (legitimately so — two repos belonging // to one workspace, e.g. `tuck-api` provider matched with `tuck-app` // consumer when both declare WorkspaceID = "tuck"). +// +// After the exact-ID pairing, a second pass joins the RPC IDL family +// (gRPC + Thrift) by canonical service/method names — see +// joinRPCCanonical. IDL definitions and generated-stub call sites +// frequently disagree on the literal contract ID (package-qualified +// vs bare service names, camelCase vs PascalCase method casing, +// service-level registrations vs method-level calls); the canonical +// join recovers those pairs so cross-service traversal doesn't stop +// at a spelling difference. func Match(reg *Registry) MatchResult { var result MatchResult @@ -99,5 +110,258 @@ func Match(reg *Registry) MatchResult { result.OrphanConsumers = append(result.OrphanConsumers, cons...) } + joinRPCCanonical(&result) + return result } + +// isRPCFamily reports whether a contract belongs to the RPC IDL +// family the canonical-name join pairs across. gRPC and Thrift share +// the same generated-stub surface (`NewClient(...)`), so a +// code-side consumer detected as grpc legitimately pairs with a +// thrift IDL definition of the same service. +func isRPCFamily(c Contract) bool { + return c.Type == ContractGRPC || c.Type == ContractThrift +} + +// rpcServiceMethod extracts the canonical (service, method) join key +// from an RPC-family contract, both lowercased. The service name is +// stripped of any namespace/package qualifier (`billing.v1.Users` → +// `users`) and methods compare case-insensitively because generated +// stubs re-case them per language convention (Go GetUser vs TS +// getUser). Falls back to parsing the contract ID's +// `::[::]` segments when Meta is missing. An +// empty method means the contract is service-level (a client +// construction or a server registration without method granularity). +func rpcServiceMethod(c Contract) (service, method string) { + if c.Meta != nil { + service, _ = c.Meta["service"].(string) + method, _ = c.Meta["method"].(string) + } + if service == "" || method == "" { + parts := strings.Split(c.ID, "::") + if service == "" && len(parts) >= 2 { + service = parts[1] + } + if method == "" && len(parts) >= 3 { + method = parts[2] + } + } + if dot := strings.LastIndex(service, "."); dot >= 0 { + service = service[dot+1:] + } + return strings.ToLower(service), strings.ToLower(method) +} + +// rpcGroupID picks the contract ID that names the joined group: the +// method-level side wins (it is strictly more specific), provider +// first so two method-level sides group under the provider's ID — the +// ID every exact-matched link for the same RPC already uses. +func rpcGroupID(provider, consumer Contract) string { + if _, pm := rpcServiceMethod(provider); pm != "" { + return provider.ID + } + if _, cm := rpcServiceMethod(consumer); cm != "" { + return consumer.ID + } + return provider.ID +} + +// matcherIdentity is the per-record identity key the orphan-removal +// bookkeeping uses. Mirrors removeContract's field set so two registry +// entries that the Registry treats as distinct stay distinct here. +func matcherIdentity(c Contract) string { + return c.ID + "|" + c.FilePath + "|" + c.SymbolID + "|" + string(c.Role) + "|" + c.RepoPrefix +} + +// joinRPCCanonical pairs the RPC-family orphans left over from exact- +// ID matching by canonical service/method names, within the same +// (workspace, project) boundary the exact pass uses. Three shapes are +// recovered: +// +// - method-level consumer ↔ method-level provider whose IDs differ +// only in service qualification or method casing (TS camelCase +// stubs vs proto PascalCase RPCs); +// - service-level consumer (bare client construction) ↔ every +// provider of that service; +// - service-level provider (Go `RegisterServer` site) ↔ +// every consumer of that service. +// +// Joined contracts are removed from the orphan lists; the emitted +// CrossLinks group under the method-level side's contract ID (see +// rpcGroupID) so bridge materialisation keeps per-RPC granularity. +func joinRPCCanonical(result *MatchResult) { + type svcKey struct{ ws, proj, svc string } + type methodKey struct { + ws, proj, svc, method string + } + + // Index every RPC-family contract on BOTH sides of the existing + // result — matched and orphaned. A service-level orphan must be + // able to join contracts that already exact-matched (e.g. a TS + // client construction joining a proto RPC that a Go consumer + // already paired with). + var allProviders, allConsumers []Contract + for _, m := range result.Matched { + allProviders = append(allProviders, m.Provider) + allConsumers = append(allConsumers, m.Consumer) + } + allProviders = append(allProviders, result.OrphanProviders...) + allConsumers = append(allConsumers, result.OrphanConsumers...) + + provByMethod := make(map[methodKey][]Contract) + provBySvc := make(map[svcKey][]Contract) + provSeen := make(map[string]struct{}) + for _, p := range allProviders { + if !isRPCFamily(p) { + continue + } + // The matched list repeats a provider once per consumer it + // paired with; index each record once. + idKey := matcherIdentity(p) + if _, dup := provSeen[idKey]; dup { + continue + } + provSeen[idKey] = struct{}{} + svc, method := rpcServiceMethod(p) + if svc == "" { + continue + } + sk := svcKey{p.EffectiveWorkspace(), p.EffectiveProject(), svc} + provBySvc[sk] = append(provBySvc[sk], p) + if method != "" { + provByMethod[methodKey{sk.ws, sk.proj, svc, method}] = append( + provByMethod[methodKey{sk.ws, sk.proj, svc, method}], p) + } + } + + consByMethod := make(map[methodKey][]Contract) + consBySvc := make(map[svcKey][]Contract) + consSeen := make(map[string]struct{}) + for _, c := range allConsumers { + if !isRPCFamily(c) { + continue + } + idKey := matcherIdentity(c) + if _, dup := consSeen[idKey]; dup { + continue + } + consSeen[idKey] = struct{}{} + svc, method := rpcServiceMethod(c) + if svc == "" { + continue + } + sk := svcKey{c.EffectiveWorkspace(), c.EffectiveProject(), svc} + consBySvc[sk] = append(consBySvc[sk], c) + if method != "" { + consByMethod[methodKey{sk.ws, sk.proj, svc, method}] = append( + consByMethod[methodKey{sk.ws, sk.proj, svc, method}], c) + } + } + + joinedProv := make(map[string]struct{}) + joinedCons := make(map[string]struct{}) + linked := make(map[string]struct{}) + emit := func(p, c Contract) { + lk := matcherIdentity(p) + "->" + matcherIdentity(c) + if _, dup := linked[lk]; dup { + return + } + linked[lk] = struct{}{} + result.Matched = append(result.Matched, CrossLink{ + ContractID: rpcGroupID(p, c), + Provider: p, + Consumer: c, + CrossRepo: p.RepoPrefix != c.RepoPrefix, + }) + joinedProv[matcherIdentity(p)] = struct{}{} + joinedCons[matcherIdentity(c)] = struct{}{} + } + + // Orphan consumers seek providers. + for _, c := range result.OrphanConsumers { + if !isRPCFamily(c) { + continue + } + svc, method := rpcServiceMethod(c) + if svc == "" { + continue + } + sk := svcKey{c.EffectiveWorkspace(), c.EffectiveProject(), svc} + if method != "" { + provs := provByMethod[methodKey{sk.ws, sk.proj, svc, method}] + if len(provs) == 0 { + // No method-level provider — fall back to service- + // level providers only. Joining a different method's + // provider would be wrong. + for _, p := range provBySvc[sk] { + if _, pm := rpcServiceMethod(p); pm == "" { + provs = append(provs, p) + } + } + } + for _, p := range provs { + emit(p, c) + } + continue + } + // Service-level consumer joins every provider of the service. + for _, p := range provBySvc[sk] { + emit(p, c) + } + } + + // Orphan providers seek consumers (covers the registration-site + // provider whose consumers all exact-matched the IDL definition). + for _, p := range result.OrphanProviders { + if !isRPCFamily(p) { + continue + } + if _, done := joinedProv[matcherIdentity(p)]; done { + continue + } + svc, method := rpcServiceMethod(p) + if svc == "" { + continue + } + sk := svcKey{p.EffectiveWorkspace(), p.EffectiveProject(), svc} + if method != "" { + cons := consByMethod[methodKey{sk.ws, sk.proj, svc, method}] + if len(cons) == 0 { + for _, c := range consBySvc[sk] { + if _, cm := rpcServiceMethod(c); cm == "" { + cons = append(cons, c) + } + } + } + for _, c := range cons { + emit(p, c) + } + continue + } + for _, c := range consBySvc[sk] { + emit(p, c) + } + } + + if len(joinedProv) > 0 { + kept := result.OrphanProviders[:0] + for _, p := range result.OrphanProviders { + if _, done := joinedProv[matcherIdentity(p)]; done { + continue + } + kept = append(kept, p) + } + result.OrphanProviders = kept + } + if len(joinedCons) > 0 { + kept := result.OrphanConsumers[:0] + for _, c := range result.OrphanConsumers { + if _, done := joinedCons[matcherIdentity(c)]; done { + continue + } + kept = append(kept, c) + } + result.OrphanConsumers = kept + } +} diff --git a/internal/contracts/matcher_canonical_test.go b/internal/contracts/matcher_canonical_test.go new file mode 100644 index 00000000..2e7ce872 --- /dev/null +++ b/internal/contracts/matcher_canonical_test.go @@ -0,0 +1,293 @@ +package contracts + +import "testing" + +// TestMatch_RPCCanonicalJoin_MethodCasing: a TS stub call site emits +// camelCase method IDs while the proto IDL declares PascalCase. Exact +// ID pairing misses; the canonical join must pair them. +func TestMatch_RPCCanonicalJoin_MethodCasing(t *testing.T) { + reg := NewRegistry() + reg.Add(Contract{ + ID: "grpc::UserService::GetUser", + Type: ContractGRPC, + Role: RoleProvider, + FilePath: "proto/user.proto", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "GetUser", "package": "billing.v1"}, + }) + reg.Add(Contract{ + ID: "grpc::UserService::getUser", + Type: ContractGRPC, + Role: RoleConsumer, + SymbolID: "web/api.ts::loadUser", + FilePath: "web/api.ts", + RepoPrefix: "webapp", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "getUser", "lang": "typescript"}, + }) + + result := Match(reg) + if len(result.Matched) != 1 { + t.Fatalf("expected 1 canonical-joined match, got %d (orphan providers=%d consumers=%d)", + len(result.Matched), len(result.OrphanProviders), len(result.OrphanConsumers)) + } + m := result.Matched[0] + if m.ContractID != "grpc::UserService::GetUser" { + t.Errorf("group ID should be the provider's method-level ID, got %s", m.ContractID) + } + if !m.CrossRepo { + t.Error("expected cross-repo join") + } + if len(result.OrphanProviders) != 0 || len(result.OrphanConsumers) != 0 { + t.Errorf("joined contracts must leave the orphan lists: providers=%d consumers=%d", + len(result.OrphanProviders), len(result.OrphanConsumers)) + } +} + +// TestMatch_RPCCanonicalJoin_IDLPlusStub is the IDL↔generated-stub +// scenario end to end at registry level: the .proto definition, a Go +// server registration in the implementing repo, and a Go client stub +// call in a consuming repo must all collapse into one linked group. +func TestMatch_RPCCanonicalJoin_IDLPlusStub(t *testing.T) { + idl := Contract{ + ID: "grpc::UserService::GetUser", + Type: ContractGRPC, + Role: RoleProvider, + FilePath: "proto/user.proto", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "GetUser"}, + } + registration := Contract{ + ID: "grpc::UserService", + Type: ContractGRPC, + Role: RoleProvider, + SymbolID: "cmd/server/main.go::main", + FilePath: "cmd/server/main.go", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "registration": true}, + } + stubCall := Contract{ + ID: "grpc::UserService::GetUser", + Type: ContractGRPC, + Role: RoleConsumer, + SymbolID: "client/users.go::fetchUser", + FilePath: "client/users.go", + RepoPrefix: "gateway", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "GetUser", "lang": "go"}, + } + + reg := NewRegistry() + reg.Add(idl) + reg.Add(registration) + reg.Add(stubCall) + + result := Match(reg) + + // Exact pass: IDL provider ↔ stub consumer. Canonical pass: the + // service-level registration provider joins the same consumer. + if len(result.Matched) != 2 { + t.Fatalf("expected 2 links (exact + canonical), got %d: %+v", len(result.Matched), result.Matched) + } + groupIDs := map[string]int{} + providers := map[string]bool{} + for _, m := range result.Matched { + groupIDs[m.ContractID]++ + providers[m.Provider.FilePath] = true + if m.Consumer.SymbolID != "client/users.go::fetchUser" { + t.Errorf("unexpected consumer: %+v", m.Consumer) + } + } + // Both links group under the method-level contract ID — one + // bridge group for the RPC. + if groupIDs["grpc::UserService::GetUser"] != 2 { + t.Errorf("links should group under the method-level ID: %v", groupIDs) + } + if !providers["proto/user.proto"] || !providers["cmd/server/main.go"] { + t.Errorf("both IDL and registration providers must link: %v", providers) + } + if len(result.OrphanProviders) != 0 { + t.Errorf("registration provider should be joined, orphans: %+v", result.OrphanProviders) + } +} + +// TestMatch_RPCCanonicalJoin_ServiceLevelConsumer: a bare client +// construction (no resolvable method calls) joins every method the +// service provides. +func TestMatch_RPCCanonicalJoin_ServiceLevelConsumer(t *testing.T) { + reg := NewRegistry() + for _, method := range []string{"GetUser", "ListUsers"} { + reg.Add(Contract{ + ID: "grpc::UserService::" + method, + Type: ContractGRPC, + Role: RoleProvider, + FilePath: "proto/user.proto", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": method}, + }) + } + reg.Add(Contract{ + ID: "grpc::UserService", + Type: ContractGRPC, + Role: RoleConsumer, + SymbolID: "app/client.py::build_stub", + FilePath: "app/client.py", + RepoPrefix: "py-app", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "lang": "python"}, + }) + + result := Match(reg) + if len(result.Matched) != 2 { + t.Fatalf("service-level consumer should join both providers, got %d links", len(result.Matched)) + } + if len(result.OrphanConsumers) != 0 { + t.Errorf("service-level consumer should be joined: %+v", result.OrphanConsumers) + } +} + +// TestMatch_RPCCanonicalJoin_ThriftProviderGRPCStyleConsumer: thrift +// IDL providers pair with consumers detected through the shared +// generated-stub patterns (typed grpc by the code-side extractor). +func TestMatch_RPCCanonicalJoin_ThriftFamily(t *testing.T) { + reg := NewRegistry() + reg.Add(Contract{ + ID: "thrift::Calculator::add", + Type: ContractThrift, + Role: RoleProvider, + FilePath: "idl/calc.thrift", + RepoPrefix: "calc-svc", + WorkspaceID: "acme", + ProjectID: "calc", + Meta: map[string]any{"service": "Calculator", "method": "add"}, + }) + reg.Add(Contract{ + ID: "grpc::Calculator::add", + Type: ContractGRPC, + Role: RoleConsumer, + SymbolID: "main.go::compute", + FilePath: "main.go", + RepoPrefix: "calc-cli", + WorkspaceID: "acme", + ProjectID: "calc", + Meta: map[string]any{"service": "Calculator", "method": "add", "lang": "go"}, + }) + + result := Match(reg) + if len(result.Matched) != 1 { + t.Fatalf("expected thrift/grpc family join, got %d matches", len(result.Matched)) + } + if result.Matched[0].ContractID != "thrift::Calculator::add" { + t.Errorf("group ID should be the provider's method-level ID, got %s", result.Matched[0].ContractID) + } +} + +// TestMatch_RPCCanonicalJoin_RespectsBoundary: the canonical join must +// honour the same (workspace, project) boundary as exact matching. +func TestMatch_RPCCanonicalJoin_RespectsBoundary(t *testing.T) { + reg := NewRegistry() + reg.Add(Contract{ + ID: "grpc::UserService::GetUser", + Type: ContractGRPC, + Role: RoleProvider, + FilePath: "user.proto", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "GetUser"}, + }) + reg.Add(Contract{ + ID: "grpc::UserService::getUser", + Type: ContractGRPC, + Role: RoleConsumer, + FilePath: "api.ts", + RepoPrefix: "other-app", + WorkspaceID: "globex", // different workspace — must NOT join + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "getUser"}, + }) + + result := Match(reg) + if len(result.Matched) != 0 { + t.Fatalf("across-workspace contracts must not join: %+v", result.Matched) + } + if len(result.OrphanProviders) != 1 || len(result.OrphanConsumers) != 1 { + t.Errorf("both sides stay orphaned: providers=%d consumers=%d", + len(result.OrphanProviders), len(result.OrphanConsumers)) + } +} + +// TestMatch_RPCCanonicalJoin_NoWrongMethodJoin: a method-level +// consumer with no matching provider method must not join a different +// method's provider. +func TestMatch_RPCCanonicalJoin_NoWrongMethodJoin(t *testing.T) { + reg := NewRegistry() + reg.Add(Contract{ + ID: "grpc::UserService::DeleteUser", + Type: ContractGRPC, + Role: RoleProvider, + FilePath: "user.proto", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "DeleteUser"}, + }) + reg.Add(Contract{ + ID: "grpc::UserService::GetUser", + Type: ContractGRPC, + Role: RoleConsumer, + FilePath: "client.go", + RepoPrefix: "gateway", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "GetUser"}, + }) + + result := Match(reg) + if len(result.Matched) != 0 { + t.Fatalf("different methods must not join: %+v", result.Matched) + } +} + +// TestMatch_RPCCanonicalJoin_PackageQualifiedService: a provider whose +// Meta carries a package-qualified service name joins a bare-named +// consumer of the same service. +func TestMatch_RPCCanonicalJoin_PackageQualifiedService(t *testing.T) { + reg := NewRegistry() + reg.Add(Contract{ + ID: "grpc::billing.v1.UserService::GetUser", + Type: ContractGRPC, + Role: RoleProvider, + FilePath: "user.proto", + RepoPrefix: "svc-users", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "billing.v1.UserService", "method": "GetUser"}, + }) + reg.Add(Contract{ + ID: "grpc::UserService::GetUser", + Type: ContractGRPC, + Role: RoleConsumer, + FilePath: "client.go", + RepoPrefix: "gateway", + WorkspaceID: "acme", + ProjectID: "users", + Meta: map[string]any{"service": "UserService", "method": "GetUser"}, + }) + + result := Match(reg) + if len(result.Matched) != 1 { + t.Fatalf("package-qualified service should join bare-named consumer, got %d", len(result.Matched)) + } +} diff --git a/internal/contracts/proto_idl_test.go b/internal/contracts/proto_idl_test.go new file mode 100644 index 00000000..66432641 --- /dev/null +++ b/internal/contracts/proto_idl_test.go @@ -0,0 +1,182 @@ +package contracts + +import "testing" + +// TestGRPCExtractor_ProtoProvider_PackageAndCanonical covers the +// IDL-aware provider extraction: the proto package declaration rides +// on Meta["package"] and the fully-qualified canonical method name +// (`./` — the on-wire gRPC identity) on +// Meta["canonical"], while the contract ID stays package-free so +// exact-ID pairing against bare-named stub consumers keeps working. +func TestGRPCExtractor_ProtoProvider_PackageAndCanonical(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(` +syntax = "proto3"; + +package billing.v1; + +service UserService { + rpc GetUser(GetUserRequest) returns (GetUserResponse) {} + rpc WatchUsers(WatchUsersRequest) returns (stream UserEvent) {} +} +`) + out := ext.Extract("user.proto", src, nil, nil) + if len(out) != 2 { + t.Fatalf("expected 2 contracts, got %d: %+v", len(out), out) + } + + get := out[0] + assertContract(t, get, "grpc::UserService::GetUser", ContractGRPC, RoleProvider) + if get.Meta["package"] != "billing.v1" { + t.Errorf("Meta[package] = %v, want billing.v1", get.Meta["package"]) + } + if get.Meta["canonical"] != "billing.v1.UserService/GetUser" { + t.Errorf("Meta[canonical] = %v, want billing.v1.UserService/GetUser", get.Meta["canonical"]) + } + if get.Meta["service"] != "UserService" || get.Meta["method"] != "GetUser" { + t.Errorf("service/method meta = %v / %v", get.Meta["service"], get.Meta["method"]) + } + if get.Meta["request_type"] != "GetUserRequest" || get.Meta["response_type"] != "GetUserResponse" { + t.Errorf("request/response meta = %v / %v", get.Meta["request_type"], get.Meta["response_type"]) + } + + watch := out[1] + assertContract(t, watch, "grpc::UserService::WatchUsers", ContractGRPC, RoleProvider) + if watch.Meta["response_stream"] != true { + t.Errorf("Meta[response_stream] = %v, want true", watch.Meta["response_stream"]) + } + if watch.Meta["canonical"] != "billing.v1.UserService/WatchUsers" { + t.Errorf("Meta[canonical] = %v", watch.Meta["canonical"]) + } +} + +// TestGRPCExtractor_ProtoProvider_NoPackage: without a package +// declaration the canonical name degrades to `/` and +// Meta["package"] is absent. +func TestGRPCExtractor_ProtoProvider_NoPackage(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(` +service Pinger { + rpc Ping(PingRequest) returns (PingResponse); +} +`) + out := ext.Extract("ping.proto", src, nil, nil) + if len(out) != 1 { + t.Fatalf("expected 1 contract, got %d", len(out)) + } + if _, has := out[0].Meta["package"]; has { + t.Errorf("Meta[package] should be absent, got %v", out[0].Meta["package"]) + } + if out[0].Meta["canonical"] != "Pinger/Ping" { + t.Errorf("Meta[canonical] = %v, want Pinger/Ping", out[0].Meta["canonical"]) + } +} + +// TestGRPCExtractor_ProtoProvider_MultipleServicesBounded guards the +// brace-bounded service scan: a file declaring two services must not +// attribute the second service's RPCs to the first (the open-ended +// scan emitted OrderService's methods under UserService too). +func TestGRPCExtractor_ProtoProvider_MultipleServicesBounded(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(` +syntax = "proto3"; +package shop.v1; + +service UserService { + rpc GetUser(GetUserRequest) returns (GetUserResponse); +} + +service OrderService { + rpc PlaceOrder(PlaceOrderRequest) returns (PlaceOrderResponse); + rpc CancelOrder(CancelOrderRequest) returns (CancelOrderResponse); +} +`) + out := ext.Extract("shop.proto", src, nil, nil) + if len(out) != 3 { + t.Fatalf("expected 3 contracts (1 + 2), got %d: %+v", len(out), out) + } + got := map[string]bool{} + for _, c := range out { + got[c.ID] = true + } + for _, want := range []string{ + "grpc::UserService::GetUser", + "grpc::OrderService::PlaceOrder", + "grpc::OrderService::CancelOrder", + } { + if !got[want] { + t.Errorf("missing contract %s; got %v", want, got) + } + } + if got["grpc::UserService::PlaceOrder"] || got["grpc::UserService::CancelOrder"] { + t.Errorf("OrderService RPCs leaked into UserService: %v", got) + } +} + +// TestGRPCExtractor_GoServerRegistration covers the code-side provider +// anchor: a `pb.RegisterServer(...)` call emits one service- +// level provider contract bound to the enclosing function. +func TestGRPCExtractor_GoServerRegistration(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(`package main + +import ( + "google.golang.org/grpc" + + pb "example.com/gen/users" +) + +func main() { + s := grpc.NewServer() + pb.RegisterUserServiceServer(s, &userServer{}) + s.Serve(lis) +} +`) + nodes := makeNodes("main.go", []struct { + name string + start, end int + }{ + {"main", 9, 13}, + }) + + out := ext.Extract("main.go", src, nodes, nil) + var reg *Contract + for i := range out { + if out[i].Role == RoleProvider && out[i].ID == "grpc::UserService" { + reg = &out[i] + } + } + if reg == nil { + t.Fatalf("missing registration provider contract grpc::UserService; got %+v", out) + } + if reg.SymbolID != "main.go::main" { + t.Errorf("SymbolID = %q, want main.go::main", reg.SymbolID) + } + if reg.Meta["registration"] != true { + t.Errorf("Meta[registration] = %v, want true", reg.Meta["registration"]) + } + if reg.Meta["service"] != "UserService" { + t.Errorf("Meta[service] = %v, want UserService", reg.Meta["service"]) + } +} + +// TestGRPCExtractor_GoServerRegistration_NotDuplicated: a file with +// neither client constructions nor registrations stays contract-free +// even when it mentions Register-prefixed identifiers without the +// generated-server shape. +func TestGRPCExtractor_GoServerRegistration_NoFalsePositive(t *testing.T) { + ext := &GRPCExtractor{} + src := []byte(`package main + +func main() { + registry.RegisterAll(handlers) + prometheus.MustRegister(collector) +} +`) + out := ext.Extract("main.go", src, nil, nil) + for _, c := range out { + if c.Type == ContractGRPC { + t.Errorf("unexpected gRPC contract from non-gRPC Register call: %+v", c) + } + } +} diff --git a/internal/contracts/thrift.go b/internal/contracts/thrift.go new file mode 100644 index 00000000..c0edabfa --- /dev/null +++ b/internal/contracts/thrift.go @@ -0,0 +1,222 @@ +package contracts + +import ( + "fmt" + "regexp" + "strings" + + "github.com/zzet/gortex/internal/graph" +) + +// ThriftExtractor detects Apache Thrift IDL service definitions and +// emits one provider contract per declared service function. The +// consumer side rides on the generated-stub patterns the gRPC +// extractor already recognises (`NewClient(`), so the +// matcher's canonical-name join pairs thrift IDL providers with code +// that calls the generated client. +type ThriftExtractor struct{} + +var ( + // namespace go shared / namespace java com.example.shared / + // namespace * everything + thriftNamespaceRe = regexp.MustCompile(`(?m)^\s*namespace\s+([\w.*]+)\s+([\w.]+)`) + // service Calculator extends shared.SharedService { + thriftServiceRe = regexp.MustCompile(`(?m)^\s*service\s+(\w+)(?:\s+extends\s+([\w.]+))?\s*\{`) + // One function declaration inside a service block: + // void ping(), + // i32 add(1:i32 num1, 2:i32 num2), + // oneway void zip() + // list fetch(1: string id) throws (1: NotFound nf); + // Groups: 1 = oneway (or ""), 2 = return type (incl. container + // generics), 3 = function name. + thriftFunctionRe = regexp.MustCompile(`(?m)^\s*(oneway\s+)?([\w.]+(?:\s*<[^>{}]*>)?)\s+(\w+)\s*\(`) +) + +// thriftNamespacePreference orders the namespace scopes used to pick +// the single Meta["package"] value when a file declares several. "*" +// applies to every generator, so it wins; after that the order is an +// arbitrary-but-stable language preference. +var thriftNamespacePreference = []string{"*", "go", "java", "py", "python", "js", "rb", "cpp"} + +func (e *ThriftExtractor) SupportedLanguages() []string { + // "thrift" matches the parser registry's Language() for .thrift + // files (see internal/parser/languages forest registrations). + return []string{"thrift"} +} + +func (e *ThriftExtractor) Extract(filePath string, src []byte, nodes []*graph.Node, edges []*graph.Edge) []Contract { + if !strings.HasSuffix(filePath, ".thrift") { + return nil + } + var out []Contract + text := string(src) + lines := strings.Split(text, "\n") + + namespaces := make(map[string]string) + for _, m := range thriftNamespaceRe.FindAllStringSubmatch(text, -1) { + if _, exists := namespaces[m[1]]; !exists { + namespaces[m[1]] = m[2] + } + } + pkg := pickThriftNamespace(namespaces) + + for _, sMatch := range thriftServiceRe.FindAllStringSubmatchIndex(text, -1) { + serviceName := text[sMatch[2]:sMatch[3]] + // sMatch[1] points just past the `{`; bound the function scan + // to this service's block so sibling services stay separate. + blockEnd := matchCloseBrace(text, sMatch[1]) + if blockEnd < 0 { + blockEnd = len(text) + } + blockStart := sMatch[1] + block := text[blockStart:blockEnd] + qualService := serviceName + if pkg != "" { + qualService = pkg + "." + serviceName + } + + for _, fm := range thriftFunctionRe.FindAllStringSubmatchIndex(block, -1) { + oneway := fm[2] >= 0 + retType := strings.TrimSpace(block[fm[4]:fm[5]]) + name := block[fm[6]:fm[7]] + // Filter declarations the loose line regex can't tell + // apart from functions: keyword-led lines are struct / + // enum / typedef bodies that only appear in malformed + // files, but cheap to guard anyway. + if isThriftKeyword(retType) && retType != "void" { + continue + } + absOffset := blockStart + fm[0] + line := lineNumber(lines, absOffset) + + meta := map[string]any{ + "service": serviceName, + "method": name, + "canonical": qualService + "/" + name, + } + if pkg != "" { + meta["package"] = pkg + } + if len(namespaces) > 0 { + meta["namespaces"] = namespaces + } + if oneway { + meta["oneway"] = true + } + if retType != "void" { + meta["response_type"] = retType + meta["schema_source"] = "extracted" + } else { + meta["schema_source"] = "none" + } + // fm[7] points at the function name's end; the `(` + // follows after optional whitespace. Balance-scan the + // argument list and record each field as "name:type". + if args := thriftArgList(block, fm[7]); len(args) > 0 { + meta["args"] = args + } + + out = append(out, Contract{ + ID: fmt.Sprintf("thrift::%s::%s", serviceName, name), + Type: ContractThrift, + Role: RoleProvider, + FilePath: filePath, + Line: line, + Meta: meta, + Confidence: 0.95, + }) + } + } + + return out +} + +// pickThriftNamespace selects the Meta["package"] value from the +// declared namespaces, preferring the generator-agnostic "*" scope, +// then a stable per-language order, then any remaining scope. +func pickThriftNamespace(namespaces map[string]string) string { + for _, scope := range thriftNamespacePreference { + if ns, ok := namespaces[scope]; ok { + return ns + } + } + for _, ns := range namespaces { + return ns + } + return "" +} + +// thriftArgList parses the parenthesised field list that starts after +// nameEnd (the byte offset just past the function name) and returns +// one "name:type" entry per field. Thrift fields look like +// `1: required i32 num1` — the numeric id and requiredness qualifier +// are stripped, the declared type and name kept. +func thriftArgList(block string, nameEnd int) []string { + open := strings.Index(block[nameEnd:], "(") + if open < 0 { + return nil + } + openEnd := nameEnd + open + 1 + closeAt := matchCloseParen(block, openEnd) + if closeAt < 0 { + return nil + } + var out []string + for _, raw := range splitTopLevelArgs(block[openEnd:closeAt]) { + field := strings.TrimSpace(raw) + if field == "" { + continue + } + // Strip the leading `N:` field id. + if colon := strings.Index(field, ":"); colon >= 0 && isThriftFieldID(field[:colon]) { + field = strings.TrimSpace(field[colon+1:]) + } + // Strip requiredness qualifiers. + for _, q := range []string{"required ", "optional "} { + field = strings.TrimPrefix(field, q) + } + // Drop a default value (`= 42`). + if eq := strings.Index(field, "="); eq >= 0 { + field = strings.TrimSpace(field[:eq]) + } + // What remains is ` ` with the type possibly + // containing generics and spaces (map). The name + // is the final identifier. + if sp := strings.LastIndexAny(field, " \t>"); sp >= 0 && sp+1 < len(field) { + name := strings.TrimSpace(field[sp+1:]) + typ := strings.TrimSpace(field[:sp+1]) + if name != "" && typ != "" { + out = append(out, name+":"+typ) + continue + } + } + out = append(out, field) + } + return out +} + +// isThriftFieldID reports whether s is a numeric thrift field id (the +// `1` in `1: i32 num1`). +func isThriftFieldID(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} + +// isThriftKeyword reports whether the token is a thrift declaration +// keyword that cannot be a function return type. +func isThriftKeyword(s string) bool { + switch s { + case "struct", "enum", "union", "exception", "typedef", "const", + "service", "include", "namespace", "throws", "void": + return true + } + return false +} diff --git a/internal/contracts/thrift_test.go b/internal/contracts/thrift_test.go new file mode 100644 index 00000000..8f4232d8 --- /dev/null +++ b/internal/contracts/thrift_test.go @@ -0,0 +1,157 @@ +package contracts + +import "testing" + +func TestThriftExtractor_ServiceFunctions(t *testing.T) { + ext := &ThriftExtractor{} + src := []byte(`namespace go shared.calc +namespace java com.example.calc + +struct Work { + 1: i32 num1, + 2: i32 num2, +} + +service Calculator extends shared.SharedService { + void ping(), + + i32 add(1: i32 num1, 2: i32 num2), + + Work calculate(1: i32 logid, 2: Work w) throws (1: InvalidOperation ouch), + + oneway void zip() +} +`) + out := ext.Extract("calc.thrift", src, nil, nil) + if len(out) != 4 { + t.Fatalf("expected 4 contracts, got %d: %+v", len(out), out) + } + + byID := map[string]Contract{} + for _, c := range out { + byID[c.ID] = c + if c.Type != ContractThrift { + t.Errorf("%s: Type = %q, want thrift", c.ID, c.Type) + } + if c.Role != RoleProvider { + t.Errorf("%s: Role = %q, want provider", c.ID, c.Role) + } + } + + ping, ok := byID["thrift::Calculator::ping"] + if !ok { + t.Fatalf("missing thrift::Calculator::ping; got %v", byID) + } + if ping.Meta["canonical"] != "shared.calc.Calculator/ping" { + t.Errorf("ping canonical = %v", ping.Meta["canonical"]) + } + if ping.Meta["package"] != "shared.calc" { + t.Errorf("ping package = %v, want the go namespace", ping.Meta["package"]) + } + if _, hasResp := ping.Meta["response_type"]; hasResp { + t.Errorf("void function must not carry response_type: %v", ping.Meta["response_type"]) + } + + add, ok := byID["thrift::Calculator::add"] + if !ok { + t.Fatalf("missing thrift::Calculator::add") + } + if add.Meta["response_type"] != "i32" { + t.Errorf("add response_type = %v, want i32", add.Meta["response_type"]) + } + args, _ := add.Meta["args"].([]string) + if len(args) != 2 || args[0] != "num1:i32" || args[1] != "num2:i32" { + t.Errorf("add args = %v, want [num1:i32 num2:i32]", args) + } + + calc, ok := byID["thrift::Calculator::calculate"] + if !ok { + t.Fatalf("missing thrift::Calculator::calculate") + } + if calc.Meta["response_type"] != "Work" { + t.Errorf("calculate response_type = %v, want Work", calc.Meta["response_type"]) + } + + zip, ok := byID["thrift::Calculator::zip"] + if !ok { + t.Fatalf("missing thrift::Calculator::zip") + } + if zip.Meta["oneway"] != true { + t.Errorf("zip oneway = %v, want true", zip.Meta["oneway"]) + } + + // Struct fields must not be misread as service functions. + if _, leaked := byID["thrift::Calculator::num1"]; leaked { + t.Errorf("struct field leaked into service functions") + } +} + +func TestThriftExtractor_MultipleServicesBounded(t *testing.T) { + ext := &ThriftExtractor{} + src := []byte(`namespace py users + +service UserService { + User getUser(1: string id) +} + +service AdminService { + void purge(1: string id) +} +`) + out := ext.Extract("users.thrift", src, nil, nil) + if len(out) != 2 { + t.Fatalf("expected 2 contracts, got %d: %+v", len(out), out) + } + got := map[string]bool{} + for _, c := range out { + got[c.ID] = true + } + if !got["thrift::UserService::getUser"] || !got["thrift::AdminService::purge"] { + t.Errorf("missing expected contracts: %v", got) + } + if got["thrift::UserService::purge"] { + t.Errorf("AdminService function leaked into UserService") + } +} + +func TestThriftExtractor_ContainerTypesAndNamespacePreference(t *testing.T) { + ext := &ThriftExtractor{} + src := []byte(`namespace java com.example.inventory +namespace * inventory + +service Inventory { + list listItems(1: map filters), + map labels() +} +`) + out := ext.Extract("inv.thrift", src, nil, nil) + if len(out) != 2 { + t.Fatalf("expected 2 contracts, got %d: %+v", len(out), out) + } + byID := map[string]Contract{} + for _, c := range out { + byID[c.ID] = c + } + li, ok := byID["thrift::Inventory::listItems"] + if !ok { + t.Fatalf("missing listItems; got %v", byID) + } + // "*" namespace applies to every generator, so it wins over java. + if li.Meta["package"] != "inventory" { + t.Errorf("package = %v, want inventory (the * namespace)", li.Meta["package"]) + } + if li.Meta["response_type"] != "list" { + t.Errorf("listItems response_type = %v", li.Meta["response_type"]) + } + if _, ok := byID["thrift::Inventory::labels"]; !ok { + t.Errorf("generic return type with no args not extracted") + } +} + +func TestThriftExtractor_NonThriftFileIgnored(t *testing.T) { + ext := &ThriftExtractor{} + out := ext.Extract("main.go", []byte("service Foo { void bar() }"), nil, nil) + if len(out) != 0 { + t.Fatalf("non-.thrift file must produce no contracts, got %+v", out) + } +} diff --git a/internal/graph/edge.go b/internal/graph/edge.go index cc84d8f9..f7643e66 100644 --- a/internal/graph/edge.go +++ b/internal/graph/edge.go @@ -30,6 +30,16 @@ const ( // boundaries by hopping Consumer → EdgeConsumes⁻¹ → consumer-contract // → EdgeMatches → provider-contract → EdgeProvides⁻¹ → handler. EdgeMatches EdgeKind = "matches" + // EdgeBridges links a KindContractBridge group node to one of the + // KindContract nodes participating in the group. Direction: + // bridge → contract. Meta["side"] ∈ provider | consumer | both + // ("both" when the provider and consumer contracts share one ID + // and therefore collapse into a single contract node in the + // graph). Emitted by the contract-bridge materialisation pass in + // ReconcileContractEdges alongside EdgeMatches; the whole bridge + // generation is evicted and re-derived on every reconcile so the + // edges never outlive the contracts they group. + EdgeBridges EdgeKind = "bridges" // EdgeAnnotated links a symbol to a synthetic annotation node // representing a decorator / annotation / attribute applied to it // (e.g. @Component, @Test, @Deprecated, #[derive(Debug)], @@ -700,7 +710,7 @@ func DefaultOriginFor(kind EdgeKind, confidence float64, semanticSource string) // Structural AST edges are unambiguous by construction. switch kind { case EdgeDefines, EdgeImports, EdgeContains, EdgeExtends, EdgeMemberOf, - EdgeImplements, EdgeProvides, EdgeConsumes, EdgeMatches, + EdgeImplements, EdgeProvides, EdgeConsumes, EdgeMatches, EdgeBridges, // Coverage structural edges: the extractor produces an // unambiguous source→target binding for each, so they share // the AST-resolved tier. @@ -940,7 +950,7 @@ func ConfidenceLabelFor(kind EdgeKind, confidence float64) string { // Structural edges from AST are always extracted. switch kind { case EdgeDefines, EdgeImports, EdgeContains, EdgeExtends, EdgeMemberOf, EdgeImplements, - EdgeProvides, EdgeConsumes, EdgeMatches, + EdgeProvides, EdgeConsumes, EdgeMatches, EdgeBridges, EdgeParamOf, EdgeAliases, EdgeComposes, EdgeOverrides, EdgeLicensedAs, EdgeOwns, EdgeAuthored, EdgeGeneratedBy, EdgeDependsOnModule, EdgePackageWorkspaceMember, diff --git a/internal/graph/node.go b/internal/graph/node.go index 8ac1f829..3866c3f9 100644 --- a/internal/graph/node.go +++ b/internal/graph/node.go @@ -238,6 +238,23 @@ const ( // the persistent code graph; this kind names the entity so it reads as // first-class in tool output and query filters. KindAgent NodeKind = "agent" + // KindContractBridge represents one matched provider↔consumer + // contract group — an HTTP route, a gRPC/Thrift method, or a + // pub/sub topic — materialised as a single graph node that spans + // every repo participating in the group. ID convention: + // `bridge::` where contract-id is the canonical + // contract key (`http::GET::/v1/users`, `grpc::Users::GetUser`, + // `topic::kafka::orders`), so the bridge for any contract is + // addressable from the contract ID alone, across repos. Meta + // carries contract_type, canonical_key, repos (sorted slice of + // participating repo prefixes), provider_count, consumer_count + // and cross_repo. EdgeBridges links the bridge to each + // participating KindContract node. Bridge nodes are re-derived + // from the matcher result on every contract reconcile — all of + // them share the synthetic FilePath "contracts://bridges" so the + // reconcile pass can evict the stale generation with one + // EvictFile call before re-minting. + KindContractBridge NodeKind = "contract_bridge" ) var validNodeKinds = map[NodeKind]bool{ @@ -254,7 +271,7 @@ var validNodeKinds = map[NodeKind]bool{ KindRelease: true, KindLicense: true, KindString: true, KindResource: true, KindKustomization: true, KindImage: true, KindArtifact: true, KindDoc: true, KindTopic: true, - KindMacro: true, KindAgent: true, + KindMacro: true, KindAgent: true, KindContractBridge: true, } type Node struct { diff --git a/internal/graph/storetest/storetest.go b/internal/graph/storetest/storetest.go index 5edaeb2e..163d743e 100644 --- a/internal/graph/storetest/storetest.go +++ b/internal/graph/storetest/storetest.go @@ -102,6 +102,7 @@ func RunConformance(t *testing.T, factory Factory) { t.Run("CoverageEnrichmentSidecar", func(t *testing.T) { testCoverageEnrichmentSidecar(t, factory) }) t.Run("ReleaseEnrichmentSidecar", func(t *testing.T) { testReleaseEnrichmentSidecar(t, factory) }) t.Run("BlameEnrichmentSidecar", func(t *testing.T) { testBlameEnrichmentSidecar(t, factory) }) + t.Run("ContractBridgeRoundTrip", func(t *testing.T) { testContractBridgeRoundTrip(t, factory) }) } // -- fixture helpers --------------------------------------------------- @@ -3650,3 +3651,124 @@ func testBlameEnrichmentSidecar(t *testing.T, factory Factory) { t.Fatalf("delete must not touch repoB: %d", len(got)) } } + +// testContractBridgeRoundTrip verifies the contract-bridge subgraph +// survives the backend: a KindContractBridge node with its grouped +// Meta (incl. a []string repo spread), EdgeBridges edges carrying +// Meta["side"], kind-filtered retrieval, and the EvictFile path the +// bridge re-materialisation pass uses to drop the prior generation. +func testContractBridgeRoundTrip(t *testing.T, factory Factory) { + t.Helper() + s := factory(t) + + const bridgeFile = "contracts://bridges" + s.AddNode(&graph.Node{ + ID: "http::GET::/v1/users", Kind: graph.KindContract, + Name: "http::GET::/v1/users", FilePath: "svc-a/routes.go", + Language: "contract", RepoPrefix: "svc-a", + }) + s.AddNode(&graph.Node{ + ID: "grpc::Users::GetUser", Kind: graph.KindContract, + Name: "grpc::Users::GetUser", FilePath: "svc-b/client.go", + Language: "contract", RepoPrefix: "svc-b", + }) + s.AddNode(&graph.Node{ + ID: "bridge::http::GET::/v1/users", Kind: graph.KindContractBridge, + Name: "GET /v1/users", FilePath: bridgeFile, + Language: "contract", RepoPrefix: "svc-a", + Meta: map[string]any{ + "contract_type": "http", + "canonical_key": "GET /v1/users", + "repos": []string{"svc-a", "svc-b"}, + "provider_count": 1, + "consumer_count": 2, + "cross_repo": true, + }, + }) + s.AddEdge(&graph.Edge{ + From: "bridge::http::GET::/v1/users", To: "http::GET::/v1/users", + Kind: graph.EdgeBridges, FilePath: bridgeFile, + Meta: map[string]any{"side": "both"}, + }) + s.AddEdge(&graph.Edge{ + From: "bridge::http::GET::/v1/users", To: "grpc::Users::GetUser", + Kind: graph.EdgeBridges, FilePath: bridgeFile, + Meta: map[string]any{"side": "consumer"}, + }) + + got := s.GetNode("bridge::http::GET::/v1/users") + if got == nil { + t.Fatalf("bridge node did not round-trip") + } + if got.Kind != graph.KindContractBridge { + t.Fatalf("bridge kind = %q, want %q", got.Kind, graph.KindContractBridge) + } + if got.Meta == nil { + t.Fatalf("bridge Meta not preserved") + } + if got.Meta["canonical_key"] != "GET /v1/users" { + t.Fatalf("Meta[canonical_key] = %v", got.Meta["canonical_key"]) + } + switch repos := got.Meta["repos"].(type) { + case []string: + if len(repos) != 2 || repos[0] != "svc-a" || repos[1] != "svc-b" { + t.Fatalf("Meta[repos] = %v", repos) + } + case []any: + if len(repos) != 2 || repos[0] != "svc-a" || repos[1] != "svc-b" { + t.Fatalf("Meta[repos] = %v", repos) + } + default: + t.Fatalf("Meta[repos] has unexpected type %T (%v)", got.Meta["repos"], got.Meta["repos"]) + } + + var byKind []*graph.Node + for n := range s.NodesByKind(graph.KindContractBridge) { + byKind = append(byKind, n) + } + if len(byKind) != 1 || byKind[0].ID != "bridge::http::GET::/v1/users" { + t.Fatalf("NodesByKind(contract_bridge) = %v", byKind) + } + + out := s.GetOutEdges("bridge::http::GET::/v1/users") + sides := map[string]string{} + for _, e := range out { + if e.Kind != graph.EdgeBridges { + t.Fatalf("unexpected out-edge kind %q", e.Kind) + } + side, _ := e.Meta["side"].(string) + sides[e.To] = side + } + if sides["http::GET::/v1/users"] != "both" || sides["grpc::Users::GetUser"] != "consumer" { + t.Fatalf("EdgeBridges side meta did not round-trip: %v", sides) + } + + // The reverse direction the impact query walks. + var inKinds []graph.EdgeKind + for _, e := range s.GetInEdges("http::GET::/v1/users") { + inKinds = append(inKinds, e.Kind) + } + if len(inKinds) != 1 || inKinds[0] != graph.EdgeBridges { + t.Fatalf("GetInEdges(contract) = %v, want one bridges edge", inKinds) + } + + // Evicting the synthetic bridge file must drop the bridge node and + // its edges while leaving the contract nodes untouched — this is + // the idempotency mechanism the materialisation pass relies on. + nodesRemoved, edgesRemoved := s.EvictFile(bridgeFile) + if nodesRemoved != 1 { + t.Fatalf("EvictFile removed %d nodes, want 1", nodesRemoved) + } + if edgesRemoved != 2 { + t.Fatalf("EvictFile removed %d edges, want 2", edgesRemoved) + } + if s.GetNode("bridge::http::GET::/v1/users") != nil { + t.Fatalf("bridge node survived EvictFile") + } + if s.GetNode("http::GET::/v1/users") == nil || s.GetNode("grpc::Users::GetUser") == nil { + t.Fatalf("contract nodes must survive bridge eviction") + } + if got := s.GetInEdges("http::GET::/v1/users"); len(got) != 0 { + t.Fatalf("stale bridge in-edges survived eviction: %v", got) + } +} diff --git a/internal/indexer/contract_bridge.go b/internal/indexer/contract_bridge.go new file mode 100644 index 00000000..2f2ce0c1 --- /dev/null +++ b/internal/indexer/contract_bridge.go @@ -0,0 +1,267 @@ +package indexer + +import ( + "sort" + "strings" + + "github.com/zzet/gortex/internal/contracts" + "github.com/zzet/gortex/internal/graph" +) + +// ContractBridgeFilePath is the synthetic FilePath every +// KindContractBridge node (and its EdgeBridges edges) carries. Bridge +// nodes are derived state — re-computed from the matcher result on +// every contract reconcile — so they share one virtual "file" and the +// materialisation pass evicts the previous generation with a single +// EvictFile call before re-minting. That makes the pass idempotent +// and self-cleaning: a contract group that disappears (file deleted, +// repo untracked, route renamed) takes its bridge with it on the next +// reconcile. +const ContractBridgeFilePath = "contracts://bridges" + +// bridgeGroupKey is the identity a matched contract group materialises +// under. It mirrors the matcher's pairing boundary — Match buckets +// provider/consumer pairs by (EffectiveWorkspace, EffectiveProject, +// ContractID) and never pairs across that boundary, so two unrelated +// workspaces that each serve the same route (`GET /api/users`) produce +// two distinct groups, not one merged bridge. Keying the bridge on the +// bare ContractID alone collapsed them, summing counts and asserting a +// cross-repo blast radius the matcher never produced. +type bridgeGroupKey struct { + workspace string + project string + contractID string +} + +// bridgeGroup accumulates one matched provider↔consumer contract +// group while the materialisation pass walks the CrossLink list. +type bridgeGroup struct { + contractType contracts.ContractType + workspaceID string + projectID string + providerRepo string + repos map[string]struct{} + // side membership per participating contract node ID. + providerIDs map[string]struct{} + consumerIDs map[string]struct{} + // distinct provider/consumer records (a contract node collapses + // same-ID records, so counts come from registry identities). + providerKeys map[string]struct{} + consumerKeys map[string]struct{} + crossRepo bool + minLine int +} + +// MaterializeContractBridges persists the matcher's view of the +// contract surface as a queryable subgraph: one KindContractBridge +// node per matched provider↔consumer contract group (an HTTP route, +// a gRPC/Thrift method, a pub/sub topic), linked to every +// participating KindContract node via EdgeBridges (Meta["side"] = +// provider | consumer | both). +// +// Identity: the bridge node ID is +// `bridge::::::`, where contract-id is +// the canonical contract key (`http::GET::/v1/users`, +// `grpc::Users::GetUser`, `topic::kafka::orders`) and workspace/project +// are the matched group's effective slugs. Pinning the bridge to the +// match boundary keeps two unrelated workspaces that each serve the +// same route from collapsing into one bridge — the matcher already +// pairs only inside one (workspace, project), and the bridge identity +// must respect the same boundary. The key is repo-free within a +// boundary, so one bridge spans every repo of that workspace's group; +// the bridge node's RepoPrefix is the lexicographically-smallest +// provider repo (a deterministic owner for per-repo rollups) and +// Meta["repos"] carries the full sorted spread. +// +// The previous bridge generation is always evicted first (see +// ContractBridgeFilePath), even when matched is empty — that is what +// makes re-runs idempotent and removes bridges whose contracts +// disappeared. Returns the number of bridge nodes minted. +func MaterializeContractBridges(g graph.Store, matched []contracts.CrossLink) int { + if g == nil { + return 0 + } + g.EvictFile(ContractBridgeFilePath) + if len(matched) == 0 { + return 0 + } + + groups := make(map[bridgeGroupKey]*bridgeGroup) + for _, m := range matched { + if m.ContractID == "" { + continue + } + key := bridgeGroupKey{ + workspace: m.Provider.EffectiveWorkspace(), + project: m.Provider.EffectiveProject(), + contractID: m.ContractID, + } + grp, ok := groups[key] + if !ok { + grp = &bridgeGroup{ + contractType: m.Provider.Type, + workspaceID: key.workspace, + projectID: key.project, + repos: make(map[string]struct{}), + providerIDs: make(map[string]struct{}), + consumerIDs: make(map[string]struct{}), + providerKeys: make(map[string]struct{}), + consumerKeys: make(map[string]struct{}), + // minLine starts unset (0) and folds in a true min over + // every provider line below, so the persisted StartLine is + // independent of the (map-ordered) match iteration order. + minLine: 0, + } + groups[key] = grp + } + if m.Provider.RepoPrefix != "" { + grp.repos[m.Provider.RepoPrefix] = struct{}{} + if grp.providerRepo == "" || m.Provider.RepoPrefix < grp.providerRepo { + grp.providerRepo = m.Provider.RepoPrefix + } + } + if m.Consumer.RepoPrefix != "" { + grp.repos[m.Consumer.RepoPrefix] = struct{}{} + } + // True min over all provider lines so StartLine doesn't flap with + // the match-iteration order. A zero/negative line (spec-only + // provider with no resolved line) never lowers a real minimum. + if m.Provider.Line > 0 && (grp.minLine == 0 || m.Provider.Line < grp.minLine) { + grp.minLine = m.Provider.Line + } + grp.providerIDs[m.Provider.ID] = struct{}{} + grp.consumerIDs[m.Consumer.ID] = struct{}{} + grp.providerKeys[contractRecordKey(m.Provider)] = struct{}{} + grp.consumerKeys[contractRecordKey(m.Consumer)] = struct{}{} + if m.CrossRepo { + grp.crossRepo = true + } + } + + // Deterministic emit order keeps re-runs byte-stable on ordered + // backends and makes test assertions reproducible. + groupKeys := make([]bridgeGroupKey, 0, len(groups)) + for k := range groups { + groupKeys = append(groupKeys, k) + } + sort.Slice(groupKeys, func(i, j int) bool { + if groupKeys[i].workspace != groupKeys[j].workspace { + return groupKeys[i].workspace < groupKeys[j].workspace + } + if groupKeys[i].project != groupKeys[j].project { + return groupKeys[i].project < groupKeys[j].project + } + return groupKeys[i].contractID < groupKeys[j].contractID + }) + + minted := 0 + for _, key := range groupKeys { + grp := groups[key] + groupID := key.contractID + bridgeID := bridgeNodeID(key) + + repos := make([]string, 0, len(grp.repos)) + for r := range grp.repos { + repos = append(repos, r) + } + sort.Strings(repos) + + g.AddNode(&graph.Node{ + ID: bridgeID, + Kind: graph.KindContractBridge, + Name: bridgeCanonicalKey(groupID, grp.contractType), + FilePath: ContractBridgeFilePath, + StartLine: grp.minLine, + Language: "contract", + RepoPrefix: grp.providerRepo, + WorkspaceID: grp.workspaceID, + Meta: map[string]any{ + "contract_type": string(grp.contractType), + "canonical_key": bridgeCanonicalKey(groupID, grp.contractType), + "contract_id": groupID, + "workspace": grp.workspaceID, + "project": grp.projectID, + "repos": repos, + "provider_count": len(grp.providerKeys), + "consumer_count": len(grp.consumerKeys), + "cross_repo": grp.crossRepo, + }, + }) + minted++ + + // One EdgeBridges per participating contract node. A contract + // node that carries records on BOTH sides (exact-ID matches + // collapse provider and consumer into one node) gets a single + // edge with side="both" — two same-(from,to,kind) edges would + // collide in the adjacency dedup anyway. + contractIDs := make(map[string]struct{}, len(grp.providerIDs)+len(grp.consumerIDs)) + for id := range grp.providerIDs { + contractIDs[id] = struct{}{} + } + for id := range grp.consumerIDs { + contractIDs[id] = struct{}{} + } + ordered := make([]string, 0, len(contractIDs)) + for id := range contractIDs { + ordered = append(ordered, id) + } + sort.Strings(ordered) + for _, contractID := range ordered { + _, isProv := grp.providerIDs[contractID] + _, isCons := grp.consumerIDs[contractID] + side := "provider" + switch { + case isProv && isCons: + side = "both" + case isCons: + side = "consumer" + } + g.AddEdge(&graph.Edge{ + From: bridgeID, + To: contractID, + Kind: graph.EdgeBridges, + FilePath: ContractBridgeFilePath, + Confidence: 1.0, + ConfidenceLabel: "EXTRACTED", + Origin: graph.OriginASTResolved, + CrossRepo: grp.crossRepo, + Meta: map[string]any{"side": side}, + }) + } + } + + return minted +} + +// bridgeNodeID renders the persisted node ID for a contract-bridge +// group: `bridge::::::`. The boundary +// slugs are part of the identity so two unrelated workspaces serving +// the same contract never share a bridge node (see bridgeGroupKey). +func bridgeNodeID(key bridgeGroupKey) string { + return "bridge::" + key.workspace + "::" + key.project + "::" + key.contractID +} + +// contractRecordKey identifies one registry record (the same dedupe +// fields Registry.All uses) so provider/consumer counts reflect +// distinct call sites rather than distinct contract node IDs. +func contractRecordKey(c contracts.Contract) string { + return c.ID + "|" + c.FilePath + "|" + c.SymbolID + "|" + c.RepoPrefix +} + +// bridgeCanonicalKey renders the human-facing canonical key for a +// contract group ID: the `::` prefix is dropped and the +// remaining segments joined per protocol convention — "GET /v1/users" +// for HTTP, "Users.GetUser" for RPC, "kafka::orders" for topics. +func bridgeCanonicalKey(groupID string, t contracts.ContractType) string { + rest := groupID + if i := strings.Index(rest, "::"); i >= 0 { + rest = rest[i+2:] + } + switch t { + case contracts.ContractHTTP, contracts.ContractOpenAPI: + return strings.Replace(rest, "::", " ", 1) + case contracts.ContractGRPC, contracts.ContractThrift, contracts.ContractGraphQL: + return strings.Replace(rest, "::", ".", 1) + } + return rest +} diff --git a/internal/indexer/contract_bridge_test.go b/internal/indexer/contract_bridge_test.go new file mode 100644 index 00000000..47b10788 --- /dev/null +++ b/internal/indexer/contract_bridge_test.go @@ -0,0 +1,454 @@ +package indexer + +import ( + "context" + "path/filepath" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/config" + "github.com/zzet/gortex/internal/contracts" + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/search" +) + +// mkBridgeLink builds one matched CrossLink between hand-rolled +// provider/consumer contracts for the unit-level materialisation tests. +func mkBridgeLink(groupID string, provider, consumer contracts.Contract) contracts.CrossLink { + return contracts.CrossLink{ + ContractID: groupID, + Provider: provider, + Consumer: consumer, + CrossRepo: provider.RepoPrefix != consumer.RepoPrefix, + } +} + +func collectBridgeNodes(g graph.Store) []*graph.Node { + var out []*graph.Node + for n := range g.NodesByKind(graph.KindContractBridge) { + out = append(out, n) + } + return out +} + +func collectBridgeEdges(g graph.Store) []*graph.Edge { + var out []*graph.Edge + for e := range g.EdgesByKind(graph.EdgeBridges) { + out = append(out, e) + } + return out +} + +// TestMaterializeContractBridges_GroupsAndSides covers the core +// materialisation contract: one bridge node per matched group, repo +// spread + counts in Meta, and EdgeBridges fan-out with side meta — +// including the "both" collapse when an exact-ID match shares one +// contract node across roles. +func TestMaterializeContractBridges_GroupsAndSides(t *testing.T) { + g := graph.New() + + httpProvider := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleProvider, + SymbolID: "svc-a/routes.go::listUsers", FilePath: "svc-a/routes.go", Line: 10, + RepoPrefix: "svc-a", WorkspaceID: "acme", ProjectID: "users", + } + httpConsumer := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleConsumer, + SymbolID: "svc-b/client.go::fetchUsers", FilePath: "svc-b/client.go", Line: 7, + RepoPrefix: "svc-b", WorkspaceID: "acme", ProjectID: "users", + } + grpcIDL := contracts.Contract{ + ID: "grpc::Users::GetUser", Type: contracts.ContractGRPC, Role: contracts.RoleProvider, + FilePath: "svc-a/users.proto", Line: 5, + RepoPrefix: "svc-a", WorkspaceID: "acme", ProjectID: "users", + Meta: map[string]any{"service": "Users", "method": "GetUser"}, + } + grpcStub := contracts.Contract{ + ID: "grpc::Users::getUser", Type: contracts.ContractGRPC, Role: contracts.RoleConsumer, + SymbolID: "web/api.ts::loadUser", FilePath: "web/api.ts", Line: 3, + RepoPrefix: "webapp", WorkspaceID: "acme", ProjectID: "users", + Meta: map[string]any{"service": "Users", "method": "getUser"}, + } + + // Contract nodes as commitContracts would mint them. + for _, id := range []string{"http::GET::/api/users", "grpc::Users::GetUser", "grpc::Users::getUser"} { + g.AddNode(&graph.Node{ID: id, Kind: graph.KindContract, Name: id, Language: "contract"}) + } + + matched := []contracts.CrossLink{ + mkBridgeLink("http::GET::/api/users", httpProvider, httpConsumer), + mkBridgeLink("grpc::Users::GetUser", grpcIDL, grpcStub), + } + + minted := MaterializeContractBridges(g, matched) + require.Equal(t, 2, minted, "one bridge per matched group") + + httpBridge := g.GetNode("bridge::acme::users::http::GET::/api/users") + require.NotNil(t, httpBridge, "http bridge node missing") + assert.Equal(t, graph.KindContractBridge, httpBridge.Kind) + assert.Equal(t, ContractBridgeFilePath, httpBridge.FilePath) + assert.Equal(t, "GET /api/users", httpBridge.Meta["canonical_key"]) + assert.Equal(t, "http", httpBridge.Meta["contract_type"]) + assert.Equal(t, []string{"svc-a", "svc-b"}, httpBridge.Meta["repos"]) + assert.Equal(t, 1, httpBridge.Meta["provider_count"]) + assert.Equal(t, 1, httpBridge.Meta["consumer_count"]) + assert.Equal(t, true, httpBridge.Meta["cross_repo"]) + assert.Equal(t, "svc-a", httpBridge.RepoPrefix, "bridge owner is the provider repo") + assert.Equal(t, "acme", httpBridge.WorkspaceID) + + // Exact-ID match: provider and consumer collapse into one contract + // node — a single side="both" edge. + httpEdges := g.GetOutEdges(httpBridge.ID) + require.Len(t, httpEdges, 1) + assert.Equal(t, graph.EdgeBridges, httpEdges[0].Kind) + assert.Equal(t, "http::GET::/api/users", httpEdges[0].To) + assert.Equal(t, "both", httpEdges[0].Meta["side"]) + + // Canonical join: provider and consumer keep distinct contract + // nodes — one edge per side. + grpcBridge := g.GetNode("bridge::acme::users::grpc::Users::GetUser") + require.NotNil(t, grpcBridge, "grpc bridge node missing") + assert.Equal(t, "Users.GetUser", grpcBridge.Meta["canonical_key"]) + sides := map[string]string{} + for _, e := range g.GetOutEdges(grpcBridge.ID) { + side, _ := e.Meta["side"].(string) + sides[e.To] = side + } + assert.Equal(t, map[string]string{ + "grpc::Users::GetUser": "provider", + "grpc::Users::getUser": "consumer", + }, sides) +} + +// TestMaterializeContractBridges_IdempotentAndEvicting: re-running +// with the same matches replaces the prior generation 1:1; running +// with a shrunken match set drops the stale bridge; running with no +// matches clears the subgraph entirely. +func TestMaterializeContractBridges_IdempotentAndEvicting(t *testing.T) { + g := graph.New() + + provider := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleProvider, + SymbolID: "a/routes.go::listUsers", FilePath: "a/routes.go", RepoPrefix: "a", + } + consumer := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleConsumer, + SymbolID: "b/client.go::fetchUsers", FilePath: "b/client.go", RepoPrefix: "b", + } + provider2 := contracts.Contract{ + ID: "topic::kafka::orders", Type: contracts.ContractTopic, Role: contracts.RoleProvider, + SymbolID: "a/pub.go::publish", FilePath: "a/pub.go", RepoPrefix: "a", + } + consumer2 := contracts.Contract{ + ID: "topic::kafka::orders", Type: contracts.ContractTopic, Role: contracts.RoleConsumer, + SymbolID: "b/sub.go::consume", FilePath: "b/sub.go", RepoPrefix: "b", + } + + full := []contracts.CrossLink{ + mkBridgeLink("http::GET::/api/users", provider, consumer), + mkBridgeLink("topic::kafka::orders", provider2, consumer2), + } + + require.Equal(t, 2, MaterializeContractBridges(g, full)) + require.Len(t, collectBridgeNodes(g), 2) + edgesBefore := len(collectBridgeEdges(g)) + require.Greater(t, edgesBefore, 0) + + // Idempotent re-run: same input, same persisted state. + require.Equal(t, 2, MaterializeContractBridges(g, full)) + assert.Len(t, collectBridgeNodes(g), 2, "re-run must not duplicate bridge nodes") + assert.Equal(t, edgesBefore, len(collectBridgeEdges(g)), "re-run must not duplicate bridge edges") + + // The topic group disappears (e.g. its file was deleted): its + // bridge must go with it. + require.Equal(t, 1, MaterializeContractBridges(g, full[:1])) + assert.Nil(t, g.GetNode("bridge::a::a::topic::kafka::orders"), "stale bridge must be evicted") + assert.NotNil(t, g.GetNode("bridge::a::a::http::GET::/api/users")) + + // Everything disappears. + require.Equal(t, 0, MaterializeContractBridges(g, nil)) + assert.Empty(t, collectBridgeNodes(g)) + assert.Empty(t, collectBridgeEdges(g)) +} + +// TestContractBridge_TwoRepoIntegration drives the full pipeline over +// two tracked repos sharing one workspace: a .proto IDL provider repo +// (with a Go server registration) and a Go stub-consumer repo. The +// reconcile pass must persist one bridge spanning both repos, and a +// second reconcile must leave the subgraph unchanged. +func TestContractBridge_TwoRepoIntegration(t *testing.T) { + providerRoot := setupGRPCProtoProviderRepo(t, "auth-service") + consumerRoot := setupGRPCGoConsumerRepo(t, "client-svc") + + tmpCfg := filepath.Join(t.TempDir(), "config.yaml") + gc := &config.GlobalConfig{ + Repos: []config.RepoEntry{ + {Path: providerRoot, Name: "auth-service"}, + {Path: consumerRoot, Name: "client-svc"}, + }, + } + gc.SetConfigPath(tmpCfg) + require.NoError(t, gc.Save()) + cm, err := config.NewConfigManager(tmpCfg) + require.NoError(t, err) + + g := graph.New() + mi := NewMultiIndexer(g, newMultiLangRegistry(), search.NewBM25(), cm, zap.NewNop()) + for _, entry := range cm.Global().Repos { + _, err := mi.TrackRepoCtx(context.Background(), entry) + require.NoError(t, err) + } + + bridge := g.GetNode("bridge::shared-test::shared-test::grpc::Users::GetUser") + require.NotNil(t, bridge, + "expected persisted bridge for the matched gRPC group; bridges: %v", bridgeIDs(g)) + assert.Equal(t, graph.KindContractBridge, bridge.Kind) + assert.Equal(t, "grpc", bridge.Meta["contract_type"]) + assert.Equal(t, "Users.GetUser", bridge.Meta["canonical_key"]) + assert.Equal(t, true, bridge.Meta["cross_repo"]) + + repos, _ := bridge.Meta["repos"].([]string) + assert.Equal(t, []string{"auth-service", "client-svc"}, repos, + "bridge must span both participating repos") + + // EdgeBridges fan-out: the shared-ID contract node carries both + // roles; the registration-site provider contract (grpc::Users, + // joined canonically) rides as a separate provider edge. + sides := map[string]string{} + for _, e := range g.GetOutEdges(bridge.ID) { + require.Equal(t, graph.EdgeBridges, e.Kind) + side, _ := e.Meta["side"].(string) + sides[e.To] = side + } + assert.Equal(t, "both", sides["grpc::Users::GetUser"], + "exact-ID provider+consumer collapse into one contract node: %v", sides) + assert.Equal(t, "provider", sides["grpc::Users"], + "registration-site provider must join the group: %v", sides) + + // Idempotency across reconciles: re-running the pass replaces the + // generation in place. + nodesBefore := len(collectBridgeNodes(g)) + edgesBefore := len(collectBridgeEdges(g)) + mi.ReconcileContractEdges() + assert.Equal(t, nodesBefore, len(collectBridgeNodes(g)), "reconcile must not duplicate bridges") + assert.Equal(t, edgesBefore, len(collectBridgeEdges(g)), "reconcile must not duplicate bridge edges") + + // Untracking the consumer dissolves the group: the next reconcile + // rebuilds bridges from the remaining contracts only. + mi.UntrackRepo("client-svc") + assert.Nil(t, g.GetNode("bridge::shared-test::shared-test::grpc::Users::GetUser"), + "bridge must dissolve when the consumer repo is untracked; bridges: %v", bridgeIDs(g)) +} + +// TestMaterializeContractBridges_BoundaryIsolation: two unrelated +// workspaces that each serve the same contract (`GET /api/users`) must +// materialise as TWO distinct bridge nodes, never one merged bridge. +// The matcher pairs only inside a (workspace, project) boundary, so a +// bridge keyed on the bare contract ID would assert a provider_count / +// cross-repo blast radius the matcher never produced. +func TestMaterializeContractBridges_BoundaryIsolation(t *testing.T) { + g := graph.New() + + // Workspace "acme": one repo internally serving + consuming the route. + acmeProv := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleProvider, + SymbolID: "acme-api/routes.go::list", FilePath: "acme-api/routes.go", Line: 10, + RepoPrefix: "acme-api", WorkspaceID: "acme", ProjectID: "acme", + } + acmeCons := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleConsumer, + SymbolID: "acme-api/client.go::fetch", FilePath: "acme-api/client.go", Line: 4, + RepoPrefix: "acme-api", WorkspaceID: "acme", ProjectID: "acme", + } + // Workspace "globex": a completely unrelated service that happens to + // expose and consume the identical route string. + globexProv := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleProvider, + SymbolID: "globex-api/routes.go::list", FilePath: "globex-api/routes.go", Line: 22, + RepoPrefix: "globex-api", WorkspaceID: "globex", ProjectID: "globex", + } + globexCons := contracts.Contract{ + ID: "http::GET::/api/users", Type: contracts.ContractHTTP, Role: contracts.RoleConsumer, + SymbolID: "globex-api/client.go::fetch", FilePath: "globex-api/client.go", Line: 8, + RepoPrefix: "globex-api", WorkspaceID: "globex", ProjectID: "globex", + } + + for _, id := range []string{"http::GET::/api/users"} { + g.AddNode(&graph.Node{ID: id, Kind: graph.KindContract, Name: id, Language: "contract"}) + } + + matched := []contracts.CrossLink{ + mkBridgeLink("http::GET::/api/users", acmeProv, acmeCons), + mkBridgeLink("http::GET::/api/users", globexProv, globexCons), + } + + minted := MaterializeContractBridges(g, matched) + require.Equal(t, 2, minted, "two unrelated workspaces must not collapse into one bridge") + + acmeBridge := g.GetNode("bridge::acme::acme::http::GET::/api/users") + require.NotNil(t, acmeBridge, "acme bridge missing; bridges: %v", bridgeIDs(g)) + globexBridge := g.GetNode("bridge::globex::globex::http::GET::/api/users") + require.NotNil(t, globexBridge, "globex bridge missing; bridges: %v", bridgeIDs(g)) + + // Each bridge counts only its own workspace's provider — never the + // summed count a merged bridge would assert. + assert.Equal(t, 1, acmeBridge.Meta["provider_count"]) + assert.Equal(t, 1, globexBridge.Meta["provider_count"]) + assert.Equal(t, "acme-api", acmeBridge.RepoPrefix) + assert.Equal(t, "globex-api", globexBridge.RepoPrefix) + assert.Equal(t, "acme", acmeBridge.Meta["workspace"]) + assert.Equal(t, "globex", globexBridge.Meta["workspace"]) + // Neither is cross-repo: each pairs inside its own single repo. + assert.Equal(t, false, acmeBridge.Meta["cross_repo"]) + assert.Equal(t, false, globexBridge.Meta["cross_repo"]) +} + +// TestMaterializeContractBridges_StartLineIsOrderIndependent: a group +// with multiple provider records at different lines must pin its bridge +// StartLine to the true minimum regardless of the (map-ordered) match +// iteration, so reconciles stay byte-stable instead of flapping the +// persisted field. +func TestMaterializeContractBridges_StartLineIsOrderIndependent(t *testing.T) { + provLow := contracts.Contract{ + ID: "grpc::Users::GetUser", Type: contracts.ContractGRPC, Role: contracts.RoleProvider, + SymbolID: "svc/a.go::A", FilePath: "svc/a.go", Line: 5, + RepoPrefix: "svc", WorkspaceID: "w", ProjectID: "p", + Meta: map[string]any{"service": "Users", "method": "GetUser"}, + } + provHigh := contracts.Contract{ + ID: "grpc::Users::GetUser", Type: contracts.ContractGRPC, Role: contracts.RoleProvider, + SymbolID: "svc/b.go::B", FilePath: "svc/b.go", Line: 99, + RepoPrefix: "svc", WorkspaceID: "w", ProjectID: "p", + Meta: map[string]any{"service": "Users", "method": "GetUser"}, + } + consumer := contracts.Contract{ + ID: "grpc::Users::GetUser", Type: contracts.ContractGRPC, Role: contracts.RoleConsumer, + SymbolID: "web/api.ts::load", FilePath: "web/api.ts", Line: 3, + RepoPrefix: "web", WorkspaceID: "w", ProjectID: "p", + Meta: map[string]any{"service": "Users", "method": "GetUser"}, + } + + // Two orderings of the same matched group: high-line link first, then + // low-line link first. Both must yield StartLine = 5 (the true min). + orderings := [][]contracts.CrossLink{ + { + mkBridgeLink("grpc::Users::GetUser", provHigh, consumer), + mkBridgeLink("grpc::Users::GetUser", provLow, consumer), + }, + { + mkBridgeLink("grpc::Users::GetUser", provLow, consumer), + mkBridgeLink("grpc::Users::GetUser", provHigh, consumer), + }, + } + for i, matched := range orderings { + g := graph.New() + require.Equal(t, 1, MaterializeContractBridges(g, matched), "ordering %d", i) + bridge := g.GetNode("bridge::w::p::grpc::Users::GetUser") + require.NotNil(t, bridge, "ordering %d: bridge missing", i) + assert.Equal(t, 5, bridge.StartLine, + "ordering %d: StartLine must be the true minimum provider line, order-independent", i) + } +} + +// TestReconcileContractEdges_ConcurrentNoRaceOrTear stresses the +// serialisation that ReconcileContractEdges needs. The janitor, the +// file-watcher, and MCP-triggered track/index all drive it on +// independent goroutines, and the pass evicts the prior EdgeMatches / +// topic / bridge generation and mints a fresh one across many +// non-atomic store writes. Concurrent runs whose registry snapshots +// disagree (one mid-flight while another track/untrack mutates the +// registry) can interleave an evict over the other's freshly-minted +// bridge and persist a stale generation. Here a writer goroutine +// repeatedly untracks and re-tracks the consumer repo — flipping the +// matched set between "bridge present" and "bridge absent" — while many +// reader goroutines reconcile concurrently. Under -race this surfaces +// any unsynchronised access; the terminal state (both repos tracked) +// must hold the single complete bridge generation, never a torn one. +func TestReconcileContractEdges_ConcurrentNoRaceOrTear(t *testing.T) { + providerRoot := setupGRPCProtoProviderRepo(t, "auth-service") + consumerRoot := setupGRPCGoConsumerRepo(t, "client-svc") + + tmpCfg := filepath.Join(t.TempDir(), "config.yaml") + consumerEntry := config.RepoEntry{Path: consumerRoot, Name: "client-svc"} + gc := &config.GlobalConfig{ + Repos: []config.RepoEntry{ + {Path: providerRoot, Name: "auth-service"}, + consumerEntry, + }, + } + gc.SetConfigPath(tmpCfg) + require.NoError(t, gc.Save()) + cm, err := config.NewConfigManager(tmpCfg) + require.NoError(t, err) + + g := graph.New() + mi := NewMultiIndexer(g, newMultiLangRegistry(), search.NewBM25(), cm, zap.NewNop()) + for _, entry := range cm.Global().Repos { + _, err := mi.TrackRepoCtx(context.Background(), entry) + require.NoError(t, err) + } + + bridgeID := "bridge::shared-test::shared-test::grpc::Users::GetUser" + require.NotNil(t, g.GetNode(bridgeID), "baseline bridge missing; bridges: %v", bridgeIDs(g)) + + var wg sync.WaitGroup + + // Writer: flip the matched set under the readers. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 12; i++ { + mi.UntrackRepo("client-svc") + _, _ = mi.TrackRepoCtx(context.Background(), consumerEntry) + } + }() + + // Readers: reconcile concurrently with the flips. + const readers = 12 + wg.Add(readers) + for i := 0; i < readers; i++ { + go func() { + defer wg.Done() + for j := 0; j < 8; j++ { + mi.ReconcileContractEdges() + } + }() + } + wg.Wait() + + // Settle on the final registry state with one last reconcile, then + // assert the terminal generation is complete and not duplicated. + mi.ReconcileContractEdges() + + got := bridgeIDs(g) + count := 0 + for _, id := range got { + if id == bridgeID { + count++ + } + } + require.Equal(t, 1, count, "expected exactly one bridge after concurrent reconciles; bridges: %v", got) + + // The fan-out must be complete: the exact-ID provider+consumer node + // plus the registration-site provider. A torn generation would miss + // one of these. + sides := map[string]string{} + for _, e := range g.GetOutEdges(bridgeID) { + side, _ := e.Meta["side"].(string) + sides[e.To] = side + } + assert.Equal(t, "both", sides["grpc::Users::GetUser"], "edge set torn by interleave: %v", sides) + assert.Equal(t, "provider", sides["grpc::Users"], "edge set torn by interleave: %v", sides) +} + +func bridgeIDs(g graph.Store) []string { + var out []string + for _, n := range collectBridgeNodes(g) { + out = append(out, n.ID) + } + return out +} diff --git a/internal/indexer/indexer.go b/internal/indexer/indexer.go index a53cbe10..67215aef 100644 --- a/internal/indexer/indexer.go +++ b/internal/indexer/indexer.go @@ -4302,6 +4302,7 @@ func (idx *Indexer) buildPerFileContractExtractors() ([]contracts.Extractor, map extractors := []contracts.Extractor{ &contracts.HTTPExtractor{ClientAliases: idx.config.HTTPClientAliases}, &contracts.GRPCExtractor{}, + &contracts.ThriftExtractor{}, &contracts.GraphQLExtractor{}, &contracts.OpenAPIExtractor{}, &contracts.TopicExtractor{}, @@ -4639,6 +4640,7 @@ func isRouteContractType(t contracts.ContractType) bool { switch t { case contracts.ContractHTTP, contracts.ContractGRPC, + contracts.ContractThrift, contracts.ContractGraphQL, contracts.ContractTopic, contracts.ContractWS: diff --git a/internal/indexer/multi.go b/internal/indexer/multi.go index ec0fa8d2..c99151ca 100644 --- a/internal/indexer/multi.go +++ b/internal/indexer/multi.go @@ -65,6 +65,19 @@ type MultiIndexer struct { logger *zap.Logger mu sync.RWMutex + // reconcileMu serialises ReconcileContractEdges end-to-end. The pass + // evicts the prior EdgeMatches / topic / bridge generation and mints + // a fresh one across many independent graph-store writes — it is NOT + // atomic. Several goroutines drive it concurrently (the periodic + // janitor's ReconcileAll, the file-watcher's IncrementalReindex, + // MCP-triggered track / untrack / index), and mi.mu is only taken in + // fine-grained spots inside, not across the whole pass. Without this + // lock two overlapping reconciles can interleave evict and mint and + // persist a stale generation (a bridge wiped by the other run's + // EvictFile after it was minted). A dedicated outer mutex keeps the + // pass self-consistent without widening mi.mu's scope. + reconcileMu sync.Mutex + // stitchProber / proxyBudget wire the cross-daemon proxy-edge feature: // when set by the daemon entry point (flag on), every CrossRepoResolver // this MultiIndexer builds mints proxy edges to remote-owned symbols @@ -2178,6 +2191,12 @@ func (mi *MultiIndexer) wrapperSourceReader() contracts.SourceReader { // via the `contracts check` tool and traversals stop at each service's // boundary. func (mi *MultiIndexer) ReconcileContractEdges() int { + // Serialise the whole pass: the evict-then-mint of EdgeMatches, topic + // edges, and the bridge subgraph spans many non-atomic store writes, + // and several goroutines call this concurrently (see reconcileMu). + mi.reconcileMu.Lock() + defer mi.reconcileMu.Unlock() + g := mi.Graph() if g == nil { return 0 @@ -2382,6 +2401,13 @@ func (mi *MultiIndexer) ReconcileContractEdges() int { } } + // Persist the matched contract groups as the bridge subgraph: one + // KindContractBridge node per group plus EdgeBridges fan-out to + // the participating contract nodes. The pass evicts the previous + // bridge generation internally, so it stays idempotent across + // reconciles and drops bridges whose contracts disappeared. + MaterializeContractBridges(g, result.Matched) + // Topic nodes whose producer and consumer edges all evaporated // since the previous reconcile remain in the graph as leaf // nodes — Graph has no public RemoveNode and the next reconcile diff --git a/internal/mcp/resources.go b/internal/mcp/resources.go index 1abe1f8a..04a514fb 100644 --- a/internal/mcp/resources.go +++ b/internal/mcp/resources.go @@ -182,6 +182,14 @@ func (s *Server) handleResourceSchema(_ context.Context, req mcp.ReadResourceReq - doc — a heading-delimited Markdown prose section; Name is the breadcrumb heading path, Meta["section_text"] holds the section body. Searchable via search_symbols corpus:docs. +- contract — an API contract record (HTTP route, gRPC/Thrift method, + topic, env var, …); ID is the canonical contract key, + Meta carries type/role/symbol_id. +- contract_bridge — one matched provider↔consumer contract group + (route / RPC method / topic) spanning every participating + repo; ID is bridge::, Meta carries + canonical_key, repos, provider_count, consumer_count. + Queried via the contracts tool's action=bridge. ## Edge Kinds - calls — function/method A calls function/method B @@ -193,6 +201,10 @@ func (s *Server) handleResourceSchema(_ context.Context, req mcp.ReadResourceReq - member_of — method/field A belongs to type B - instantiates — function A creates instance of type B - similar_to — function/method A is a near-duplicate (clone) of B +- provides — symbol A provides contract B; consumes is the inverse role +- matches — consumer symbol A resolves to provider symbol B across services +- bridges — contract_bridge A groups contract B (edge Meta["side"] = + provider|consumer|both) - package_workspace_member — package-manager workspace root A owns member package B - cross_repo_calls — calls edge whose target lives in another repo - cross_repo_implements — implements edge crossing a repo boundary diff --git a/internal/mcp/tools_contract_bridge.go b/internal/mcp/tools_contract_bridge.go new file mode 100644 index 00000000..e74c3ecf --- /dev/null +++ b/internal/mcp/tools_contract_bridge.go @@ -0,0 +1,676 @@ +package mcp + +import ( + "context" + "fmt" + "path" + "sort" + "strings" + + mcp "github.com/mark3labs/mcp-go/mcp" + + "github.com/zzet/gortex/internal/contracts" + "github.com/zzet/gortex/internal/graph" +) + +// rrfK is the standard reciprocal-rank-fusion smoothing constant: the +// fused contribution of an item ranked r in one signal is 1/(k+r). +// k=60 keeps single-signal rank-1 hits from drowning out items that +// place moderately well across several signals. +const rrfK = 60.0 + +// bridgeSideEntry is one provider- or consumer-side participant of a +// contract bridge group: where the contract lives and which symbol +// carries it. +type bridgeSideEntry struct { + ContractID string `json:"contract_id"` + Repo string `json:"repo,omitempty"` + SymbolID string `json:"symbol_id,omitempty"` + FilePath string `json:"file_path,omitempty"` + Line int `json:"line,omitempty"` +} + +// bridgeGroupResult is one ranked contract-bridge group: the bridge +// node plus its provider side, consumer side, and (in rank mode) the +// fused score with per-signal ranks. +type bridgeGroupResult struct { + BridgeID string `json:"bridge_id"` + CanonicalKey string `json:"canonical_key"` + ContractType string `json:"contract_type"` + Repos []string `json:"repos,omitempty"` + CrossRepo bool `json:"cross_repo,omitempty"` + ProviderCount int `json:"provider_count"` + ConsumerCount int `json:"consumer_count"` + Providers []bridgeSideEntry `json:"providers"` + Consumers []bridgeSideEntry `json:"consumers"` + FusedScore float64 `json:"fused_score,omitempty"` + SignalRanks map[string]int `json:"signal_ranks,omitempty"` + // MatchedVia lists the anchor contract IDs that reached this + // bridge in impact mode. + MatchedVia []string `json:"matched_via,omitempty"` +} + +// handleContractBridges serves `contracts action=bridge`: queries the +// persisted contract-bridge subgraph (KindContractBridge nodes + +// EdgeBridges fan-out materialised by the indexer's contract +// reconcile). +// +// Two modes: +// +// - rank (default): order bridge groups by reciprocal rank fusion +// over independent signal rankings — text match on the canonical +// key/contract names, path+repo match, graph adjacency to the +// given symbol, and bridge consumer degree. Pass `query` and/or +// `symbol`. +// - impact: given `symbol`, return every bridge reachable from the +// symbol's contracts (its own and its file's) — the cross-service +// blast radius of changing that symbol. +func (s *Server) handleContractBridges(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mode := req.GetString("mode", "rank") + query := strings.TrimSpace(req.GetString("query", "")) + symbolID := strings.TrimSpace(req.GetString("symbol", "")) + limit := req.GetInt("limit", 10) + if limit <= 0 { + limit = 10 + } + + allowed, err := s.resolveRepoFilter(ctx, req) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + groups := s.collectBridgeGroups(allowed) + if len(groups) == 0 { + return mcp.NewToolResultError("no contract bridges materialized — index repositories with matched provider/consumer contracts first"), nil + } + + switch mode { + case "impact": + return s.bridgeImpact(ctx, req, groups, symbolID) + case "rank", "": + return s.bridgeRank(ctx, req, groups, query, symbolID, limit) + default: + return mcp.NewToolResultError("unknown bridge mode: " + mode + " (expected: rank or impact)"), nil + } +} + +// collectBridgeGroups materialises the queryable view of every +// persisted bridge node, resolving the provider/consumer sides +// through the live contract registry (with the contract node's own +// Meta as fallback so the view survives a daemon restart that hasn't +// rehydrated the registry yet). A non-nil allowed set scopes the +// result to bridges touching at least one allowed repo. +func (s *Server) collectBridgeGroups(allowed map[string]bool) []*bridgeGroupResult { + registry := s.effectiveContractRegistry() + var out []*bridgeGroupResult + for n := range s.graph.NodesByKind(graph.KindContractBridge) { + if n == nil || n.Meta == nil { + continue + } + grp := &bridgeGroupResult{ + BridgeID: n.ID, + CanonicalKey: bridgeMetaString(n.Meta, "canonical_key"), + ContractType: bridgeMetaString(n.Meta, "contract_type"), + Repos: bridgeMetaStringSlice(n.Meta, "repos"), + } + if v, ok := n.Meta["cross_repo"].(bool); ok { + grp.CrossRepo = v + } + grp.ProviderCount = bridgeMetaInt(n.Meta, "provider_count") + grp.ConsumerCount = bridgeMetaInt(n.Meta, "consumer_count") + + // The bridge node is pinned to one (workspace, project) match + // boundary; the registry lookup below must be scoped to it so a + // same-ID contract record in an unrelated workspace doesn't leak + // into this bridge's participant list. WorkspaceID lives on the + // node; project rides in Meta. Both default to the node's repo + // prefix when unset (the same "missing → repo-name" rule the + // matcher uses), so older nodes without the Meta keys still scope. + bnd := bridgeBoundary{ + workspace: bridgeBoundarySlug(n.WorkspaceID, n.Meta, "workspace", n.RepoPrefix), + project: bridgeBoundarySlug("", n.Meta, "project", n.RepoPrefix), + } + + for _, e := range s.graph.GetOutEdges(n.ID) { + if e.Kind != graph.EdgeBridges { + continue + } + side := "provider" + if e.Meta != nil { + if v, _ := e.Meta["side"].(string); v != "" { + side = v + } + } + provs, cons := s.bridgeSideEntries(registry, e.To, side, bnd) + grp.Providers = append(grp.Providers, provs...) + grp.Consumers = append(grp.Consumers, cons...) + } + sortBridgeSide(grp.Providers) + sortBridgeSide(grp.Consumers) + + if allowed != nil && !bridgeTouchesRepos(grp, allowed) { + continue + } + out = append(out, grp) + } + sort.Slice(out, func(i, j int) bool { return out[i].BridgeID < out[j].BridgeID }) + return out +} + +// bridgeBoundary is the (workspace, project) the matcher paired a +// bridge's contracts inside. bridgeSideEntries scopes its registry +// lookup to it so a same-ID record in an unrelated workspace can't be +// listed as a participant of this bridge. +type bridgeBoundary struct { + workspace string + project string +} + +// matches reports whether a contract belongs to the boundary. An empty +// boundary slug (older bridge node without the Meta keys, or a +// genuinely empty workspace) matches everything so the read path stays +// backward-compatible. +func (b bridgeBoundary) matches(c contracts.Contract) bool { + if b.workspace != "" && c.EffectiveWorkspace() != b.workspace { + return false + } + if b.project != "" && c.EffectiveProject() != b.project { + return false + } + return true +} + +// bridgeBoundarySlug resolves a boundary slug from the bridge node: +// prefer the explicit value, then Meta[key], then the repo-prefix +// default the matcher falls back to when the slug is unset. +func bridgeBoundarySlug(explicit string, meta map[string]any, key, repoPrefix string) string { + if explicit != "" { + return explicit + } + if v := bridgeMetaString(meta, key); v != "" { + return v + } + return repoPrefix +} + +// bridgeSideEntries resolves the participant records for one +// EdgeBridges endpoint. side="both" expands to records on both roles +// (an exact-ID match collapses provider and consumer into one +// contract node). Records outside the bridge's match boundary are +// filtered out so a same-ID contract in an unrelated workspace is not +// listed as a participant. +func (s *Server) bridgeSideEntries(registry *contracts.Registry, contractID, side string, bnd bridgeBoundary) (provs, cons []bridgeSideEntry) { + wantProv := side == "provider" || side == "both" + wantCons := side == "consumer" || side == "both" + + if registry != nil { + records := registry.ByID(contractID) + for _, c := range records { + if !bnd.matches(c) { + continue + } + entry := bridgeSideEntry{ + ContractID: contractID, + Repo: c.RepoPrefix, + SymbolID: c.SymbolID, + FilePath: c.FilePath, + Line: c.Line, + } + switch { + case c.Role == contracts.RoleProvider && wantProv: + provs = append(provs, entry) + case c.Role == contracts.RoleConsumer && wantCons: + cons = append(cons, entry) + } + } + if len(provs) > 0 || len(cons) > 0 { + return provs, cons + } + } + + // Fallback: the contract node itself. It carries a single role's + // Meta (same-ID records collapse), so this is best-effort — the + // registry path above is authoritative whenever it has data. + n := s.graph.GetNode(contractID) + if n == nil { + return nil, nil + } + entry := bridgeSideEntry{ + ContractID: contractID, + Repo: n.RepoPrefix, + FilePath: n.FilePath, + } + if n.Meta != nil { + entry.SymbolID = bridgeMetaString(n.Meta, "symbol_id") + entry.Line = bridgeMetaInt(n.Meta, "line") + } + if wantProv { + provs = append(provs, entry) + } + if wantCons { + cons = append(cons, entry) + } + return provs, cons +} + +// bridgeRank orders bridge groups by reciprocal rank fusion over the +// independent signal rankings that apply to the request. +func (s *Server) bridgeRank(ctx context.Context, req mcp.CallToolRequest, groups []*bridgeGroupResult, query, symbolID string, limit int) (*mcp.CallToolResult, error) { + rankings := make(map[string][]string) + byID := make(map[string]*bridgeGroupResult, len(groups)) + for _, g := range groups { + byID[g.BridgeID] = g + } + + if query != "" { + tokens := bridgeQueryTokens(query) + rankings["text"] = rankBridges(groups, func(g *bridgeGroupResult) float64 { + return bridgeTextScore(g, tokens) + }) + rankings["path_repo"] = rankBridges(groups, func(g *bridgeGroupResult) float64 { + return bridgePathRepoScore(g, tokens) + }) + } + + if symbolID != "" { + symNode := s.graph.GetNode(symbolID) + if symNode == nil { + return mcp.NewToolResultError("symbol not found: " + symbolID), nil + } + anchors := s.bridgeAnchorContracts(symNode) + rankings["adjacency"] = rankBridges(groups, func(g *bridgeGroupResult) float64 { + return bridgeAdjacencyScore(g, symNode, anchors) + }) + } + + // Degree always participates: a heavily-consumed contract group is + // the more load-bearing answer at equal text/graph relevance. + rankings["degree"] = rankBridges(groups, func(g *bridgeGroupResult) float64 { + return float64(g.ConsumerCount) + }) + + fused, perSignal := reciprocalRankFusion(rankings, rrfK) + + ranked := make([]*bridgeGroupResult, 0, len(groups)) + for id, score := range fused { + g := byID[id] + if g == nil { + continue + } + g.FusedScore = score + g.SignalRanks = perSignal[id] + ranked = append(ranked, g) + } + sort.Slice(ranked, func(i, j int) bool { + if ranked[i].FusedScore != ranked[j].FusedScore { + return ranked[i].FusedScore > ranked[j].FusedScore + } + return ranked[i].BridgeID < ranked[j].BridgeID + }) + total := len(ranked) + if len(ranked) > limit { + ranked = ranked[:limit] + } + + if isCompact(req) { + var b strings.Builder + fmt.Fprintf(&b, "bridges: %d (showing %d)\n", total, len(ranked)) + for _, g := range ranked { + fmt.Fprintf(&b, " %.4f %s [%s] providers=%d consumers=%d repos=%s\n", + g.FusedScore, g.CanonicalKey, g.ContractType, + g.ProviderCount, g.ConsumerCount, strings.Join(g.Repos, ",")) + } + return mcp.NewToolResultText(b.String()), nil + } + + payload := map[string]any{ + "mode": "rank", + "groups": ranked, + "total": total, + "signals": signalNames(rankings), + } + return s.respondJSONOrTOON(ctx, req, payload) +} + +// bridgeImpact returns every bridge reachable from the symbol's +// contract surface: contracts attached to the symbol itself plus the +// contracts declared in its file, expanded through the persisted +// EdgeBridges in-edges. +func (s *Server) bridgeImpact(ctx context.Context, req mcp.CallToolRequest, groups []*bridgeGroupResult, symbolID string) (*mcp.CallToolResult, error) { + if symbolID == "" { + return mcp.NewToolResultError("symbol is required for bridge impact mode"), nil + } + symNode := s.graph.GetNode(symbolID) + if symNode == nil { + return mcp.NewToolResultError("symbol not found: " + symbolID), nil + } + anchors := s.bridgeAnchorContracts(symNode) + if len(anchors) == 0 { + payload := map[string]any{ + "mode": "impact", + "symbol": symbolID, + "groups": []*bridgeGroupResult{}, + "total": 0, + "note": "no contracts attached to this symbol or its file", + } + return s.respondJSONOrTOON(ctx, req, payload) + } + + byID := make(map[string]*bridgeGroupResult, len(groups)) + for _, g := range groups { + byID[g.BridgeID] = g + } + + matchedVia := make(map[string][]string) + for contractID := range anchors { + for _, e := range s.graph.GetInEdges(contractID) { + if e.Kind != graph.EdgeBridges { + continue + } + if _, ok := byID[e.From]; !ok { + continue + } + matchedVia[e.From] = append(matchedVia[e.From], contractID) + } + } + + impacted := make([]*bridgeGroupResult, 0, len(matchedVia)) + for bridgeID, via := range matchedVia { + g := byID[bridgeID] + sort.Strings(via) + g.MatchedVia = dedupeSortedStrings(via) + impacted = append(impacted, g) + } + sort.Slice(impacted, func(i, j int) bool { + if impacted[i].ConsumerCount != impacted[j].ConsumerCount { + return impacted[i].ConsumerCount > impacted[j].ConsumerCount + } + return impacted[i].BridgeID < impacted[j].BridgeID + }) + + if isCompact(req) { + var b strings.Builder + fmt.Fprintf(&b, "impacted bridges: %d (symbol %s)\n", len(impacted), symbolID) + for _, g := range impacted { + fmt.Fprintf(&b, " %s [%s] consumers=%d repos=%s via=%s\n", + g.CanonicalKey, g.ContractType, g.ConsumerCount, + strings.Join(g.Repos, ","), strings.Join(g.MatchedVia, ",")) + } + return mcp.NewToolResultText(b.String()), nil + } + + payload := map[string]any{ + "mode": "impact", + "symbol": symbolID, + "groups": impacted, + "total": len(impacted), + } + return s.respondJSONOrTOON(ctx, req, payload) +} + +// bridgeAnchorContracts returns the contract IDs anchored to a +// symbol: contracts attached to the symbol itself plus every contract +// declared in the symbol's file. This is the entry set a change to +// the symbol can reach without leaving its file. +func (s *Server) bridgeAnchorContracts(symNode *graph.Node) map[string]bool { + anchors := make(map[string]bool) + registry := s.effectiveContractRegistry() + if registry != nil { + for _, c := range registry.BySymbol(symNode.ID) { + anchors[c.ID] = true + } + if symNode.FilePath != "" { + for _, c := range registry.ByFile(symNode.FilePath) { + anchors[c.ID] = true + } + } + } + // Graph fallback: provides/consumes out-edges land on contract + // nodes directly. + for _, e := range s.graph.GetOutEdges(symNode.ID) { + if e.Kind == graph.EdgeProvides || e.Kind == graph.EdgeConsumes { + anchors[e.To] = true + } + } + return anchors +} + +// reciprocalRankFusion fuses independent per-signal rankings into one +// score per item: fused(i) = Σ_s 1/(k + rank_s(i)) over the signals +// that ranked the item (1-based ranks). Items absent from a signal +// contribute nothing for it. Returns the fused scores plus each +// item's per-signal rank for explainability. +func reciprocalRankFusion(rankings map[string][]string, k float64) (map[string]float64, map[string]map[string]int) { + fused := make(map[string]float64) + perSignal := make(map[string]map[string]int) + for signal, ids := range rankings { + for i, id := range ids { + rank := i + 1 + fused[id] += 1.0 / (k + float64(rank)) + if perSignal[id] == nil { + perSignal[id] = make(map[string]int) + } + perSignal[id][signal] = rank + } + } + return fused, perSignal +} + +// rankBridges scores every group and returns the IDs of those with a +// positive score, best-first. Ties break on bridge ID so rankings are +// deterministic. +func rankBridges(groups []*bridgeGroupResult, score func(*bridgeGroupResult) float64) []string { + type scored struct { + id string + score float64 + } + var hits []scored + for _, g := range groups { + if sc := score(g); sc > 0 { + hits = append(hits, scored{g.BridgeID, sc}) + } + } + sort.Slice(hits, func(i, j int) bool { + if hits[i].score != hits[j].score { + return hits[i].score > hits[j].score + } + return hits[i].id < hits[j].id + }) + ids := make([]string, len(hits)) + for i, h := range hits { + ids[i] = h.id + } + return ids +} + +// bridgeQueryTokens lowercases and splits a free-text query into +// alphanumeric terms. +func bridgeQueryTokens(query string) []string { + return strings.FieldsFunc(strings.ToLower(query), func(r rune) bool { + return (r < 'a' || r > 'z') && (r < '0' || r > '9') + }) +} + +// bridgeTextScore matches query tokens against the bridge's canonical +// key, contract type, and the contract IDs on both sides (those embed +// the service/method/path/topic vocabulary). Token-boundary hits +// weigh double a bare substring hit. +func bridgeTextScore(g *bridgeGroupResult, tokens []string) float64 { + var docParts []string + docParts = append(docParts, g.CanonicalKey, g.ContractType) + for _, e := range g.Providers { + docParts = append(docParts, e.ContractID, e.SymbolID) + } + for _, e := range g.Consumers { + docParts = append(docParts, e.ContractID, e.SymbolID) + } + doc := strings.ToLower(strings.Join(docParts, " ")) + docTokens := make(map[string]bool) + for _, t := range bridgeQueryTokens(doc) { + docTokens[t] = true + } + score := 0.0 + for _, tok := range tokens { + switch { + case docTokens[tok]: + score += 2 + case strings.Contains(doc, tok): + score++ + } + } + return score +} + +// bridgePathRepoScore matches query tokens against the bridge's repo +// spread and the file paths of its participants. +func bridgePathRepoScore(g *bridgeGroupResult, tokens []string) float64 { + var docParts []string + docParts = append(docParts, g.Repos...) + for _, e := range g.Providers { + docParts = append(docParts, e.FilePath, e.Repo) + } + for _, e := range g.Consumers { + docParts = append(docParts, e.FilePath, e.Repo) + } + doc := strings.ToLower(strings.Join(docParts, " ")) + docTokens := make(map[string]bool) + for _, t := range bridgeQueryTokens(doc) { + docTokens[t] = true + } + score := 0.0 + for _, tok := range tokens { + switch { + case docTokens[tok]: + score += 2 + case strings.Contains(doc, tok): + score++ + } + } + return score +} + +// bridgeAdjacencyScore scores a bridge by graph proximity to the +// query symbol: bridges directly anchored to one of the symbol's +// contracts rank above same-file participants, which rank above +// same-directory, which rank above same-repo. +func bridgeAdjacencyScore(g *bridgeGroupResult, symNode *graph.Node, anchors map[string]bool) float64 { + best := 0.0 + consider := func(e bridgeSideEntry) { + score := 0.0 + switch { + case anchors[e.ContractID]: + score = 8 + case e.FilePath != "" && e.FilePath == symNode.FilePath: + score = 4 + case e.FilePath != "" && symNode.FilePath != "" && + path.Dir(e.FilePath) == path.Dir(symNode.FilePath): + score = 2 + case e.Repo != "" && e.Repo == symNode.RepoPrefix: + score = 1 + } + if score > best { + best = score + } + } + for _, e := range g.Providers { + consider(e) + } + for _, e := range g.Consumers { + consider(e) + } + return best +} + +// bridgeTouchesRepos reports whether the bridge group involves at +// least one repo in the allowed set. +func bridgeTouchesRepos(g *bridgeGroupResult, allowed map[string]bool) bool { + for _, r := range g.Repos { + if allowed[r] { + return true + } + } + for _, e := range g.Providers { + if allowed[e.Repo] { + return true + } + } + for _, e := range g.Consumers { + if allowed[e.Repo] { + return true + } + } + return false +} + +func sortBridgeSide(entries []bridgeSideEntry) { + sort.Slice(entries, func(i, j int) bool { + if entries[i].ContractID != entries[j].ContractID { + return entries[i].ContractID < entries[j].ContractID + } + if entries[i].Repo != entries[j].Repo { + return entries[i].Repo < entries[j].Repo + } + if entries[i].FilePath != entries[j].FilePath { + return entries[i].FilePath < entries[j].FilePath + } + return entries[i].Line < entries[j].Line + }) +} + +func signalNames(rankings map[string][]string) []string { + names := make([]string, 0, len(rankings)) + for name := range rankings { + names = append(names, name) + } + sort.Strings(names) + return names +} + +func dedupeSortedStrings(in []string) []string { + out := in[:0] + var prev string + for i, s := range in { + if i > 0 && s == prev { + continue + } + out = append(out, s) + prev = s + } + return out +} + +// metaString / metaInt / metaStringSlice read loosely-typed Node.Meta +// values that may have round-tripped through a persistence backend +// (gob restores []string as-is; JSON paths may yield []any / float64). +func bridgeMetaString(meta map[string]any, key string) string { + v, _ := meta[key].(string) + return v +} + +func bridgeMetaInt(meta map[string]any, key string) int { + switch v := meta[key].(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + } + return 0 +} + +func bridgeMetaStringSlice(meta map[string]any, key string) []string { + switch v := meta[key].(type) { + case []string: + return v + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return out + } + return nil +} diff --git a/internal/mcp/tools_contract_bridge_test.go b/internal/mcp/tools_contract_bridge_test.go new file mode 100644 index 00000000..401ab87f --- /dev/null +++ b/internal/mcp/tools_contract_bridge_test.go @@ -0,0 +1,358 @@ +package mcp + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/config" + "github.com/zzet/gortex/internal/contracts" + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/indexer" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" + "github.com/zzet/gortex/internal/query" + "github.com/zzet/gortex/internal/search" +) + +// TestReciprocalRankFusion_HandComputed pins the fusion math to a +// hand-worked example with k=60: +// +// signal a ranks: x(1), y(2), z(3) +// signal b ranks: y(1), x(2) +// signal c ranks: z(1) +// +// fused(x) = 1/61 + 1/62 +// fused(y) = 1/62 + 1/61 +// fused(z) = 1/63 + 1/61 +func TestReciprocalRankFusion_HandComputed(t *testing.T) { + rankings := map[string][]string{ + "a": {"x", "y", "z"}, + "b": {"y", "x"}, + "c": {"z"}, + } + fused, perSignal := reciprocalRankFusion(rankings, 60) + + wantX := 1.0/61 + 1.0/62 + wantY := 1.0/62 + 1.0/61 + wantZ := 1.0/63 + 1.0/61 + + assert.InDelta(t, wantX, fused["x"], 1e-12) + assert.InDelta(t, wantY, fused["y"], 1e-12) + assert.InDelta(t, wantZ, fused["z"], 1e-12) + + // x and y tie exactly; z trails because its second-best rank is 3. + assert.Equal(t, fused["x"], fused["y"]) + assert.Less(t, fused["z"], fused["x"]) + + assert.Equal(t, map[string]int{"a": 1, "b": 2}, perSignal["x"]) + assert.Equal(t, map[string]int{"a": 2, "b": 1}, perSignal["y"]) + assert.Equal(t, map[string]int{"a": 3, "c": 1}, perSignal["z"]) +} + +func TestReciprocalRankFusion_Empty(t *testing.T) { + fused, perSignal := reciprocalRankFusion(nil, 60) + assert.Empty(t, fused) + assert.Empty(t, perSignal) +} + +// setupBridgeWorkspaceRepo writes a repo dir with a shared-workspace +// .gortex.yaml so the two repos' contracts pair across the boundary. +func setupBridgeWorkspaceRepo(t *testing.T, root, name string, files map[string]string) string { + t.Helper() + dir := filepath.Join(root, name) + require.NoError(t, os.MkdirAll(dir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dir, ".gortex.yaml"), + []byte("workspace: acme\nproject: acme\n"), 0o644)) + for rel, content := range files { + full := filepath.Join(dir, rel) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, []byte(content), 0o644)) + } + return dir +} + +// newBridgeTestServer indexes a provider repo (Gin routes) and a +// consumer repo (http.Get of the same path) into one multi-repo graph +// and returns an MCP server whose reconcile pass has materialized the +// bridge subgraph. +func newBridgeTestServer(t *testing.T) *Server { + t.Helper() + root := t.TempDir() + + providerRepo := setupBridgeWorkspaceRepo(t, root, "provider-svc", map[string]string{ + "go.mod": "module example.com/provider-svc\n\ngo 1.21\n", + "main.go": `package main + +import "github.com/gin-gonic/gin" + +func setupRoutes(r *gin.Engine) { + r.GET("/api/users", listUsers) +} + +func listUsers() {} +`, + }) + consumerRepo := setupBridgeWorkspaceRepo(t, root, "consumer-svc", map[string]string{ + "go.mod": "module example.com/consumer-svc\n\ngo 1.21\n", + "client.go": `package main + +import "net/http" + +func fetchUsers() { + http.Get("http://api.example.com/api/users") +} +`, + }) + + tmpCfg := filepath.Join(t.TempDir(), "config.yaml") + gc := &config.GlobalConfig{ + Repos: []config.RepoEntry{ + {Path: providerRepo, Name: "provider-svc"}, + {Path: consumerRepo, Name: "consumer-svc"}, + }, + } + gc.SetConfigPath(tmpCfg) + require.NoError(t, gc.Save()) + + cm, err := config.NewConfigManager(tmpCfg) + require.NoError(t, err) + + preg := parser.NewRegistry() + languages.RegisterAll(preg) + + g := graph.New() + mi := indexer.NewMultiIndexer(g, preg, search.NewBM25(), cm, zap.NewNop()) + _, err = mi.IndexAll() + require.NoError(t, err) + + eng := query.NewEngine(g) + return NewServer(eng, g, nil, nil, zap.NewNop(), nil, MultiRepoOptions{ + ConfigManager: cm, + MultiIndexer: mi, + }) +} + +type bridgeTestPayload struct { + Mode string `json:"mode"` + Total int `json:"total"` + Symbol string `json:"symbol"` + Groups []struct { + BridgeID string `json:"bridge_id"` + CanonicalKey string `json:"canonical_key"` + ContractType string `json:"contract_type"` + Repos []string `json:"repos"` + CrossRepo bool `json:"cross_repo"` + ProviderCount int `json:"provider_count"` + ConsumerCount int `json:"consumer_count"` + FusedScore float64 `json:"fused_score"` + SignalRanks map[string]int `json:"signal_ranks"` + MatchedVia []string `json:"matched_via"` + Providers []struct { + ContractID string `json:"contract_id"` + Repo string `json:"repo"` + SymbolID string `json:"symbol_id"` + FilePath string `json:"file_path"` + } `json:"providers"` + Consumers []struct { + ContractID string `json:"contract_id"` + Repo string `json:"repo"` + SymbolID string `json:"symbol_id"` + } `json:"consumers"` + } `json:"groups"` +} + +func callBridge(t *testing.T, srv *Server, args map[string]any) bridgeTestPayload { + t.Helper() + req := mcplib.CallToolRequest{} + if args == nil { + args = map[string]any{} + } + args["action"] = "bridge" + req.Params.Arguments = args + res, err := srv.handleContracts(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, res) + require.False(t, res.IsError, "bridge call failed: %+v", res.Content) + + var payload bridgeTestPayload + require.NoError(t, json.Unmarshal([]byte(extractTextFromContent(t, res.Content)), &payload)) + return payload +} + +// TestHandleContractBridges_RankMode runs the fused ranking over a +// real two-repo index and asserts the grouped response shape: bridge +// identity, both sides resolved with repo + symbol + location, and +// per-signal ranks riding on the fused score. +func TestHandleContractBridges_RankMode(t *testing.T) { + srv := newBridgeTestServer(t) + + payload := callBridge(t, srv, map[string]any{"query": "users api"}) + require.Equal(t, "rank", payload.Mode) + require.GreaterOrEqual(t, payload.Total, 1) + require.NotEmpty(t, payload.Groups) + + top := payload.Groups[0] + assert.Equal(t, "bridge::acme::acme::http::GET::/api/users", top.BridgeID) + assert.Equal(t, "GET /api/users", top.CanonicalKey) + assert.Equal(t, "http", top.ContractType) + assert.Equal(t, []string{"consumer-svc", "provider-svc"}, top.Repos) + assert.True(t, top.CrossRepo) + assert.Greater(t, top.FusedScore, 0.0) + assert.Contains(t, top.SignalRanks, "text") + assert.Contains(t, top.SignalRanks, "degree") + + require.NotEmpty(t, top.Providers, "provider side must be resolved") + prov := top.Providers[0] + assert.Equal(t, "provider-svc", prov.Repo) + assert.Equal(t, "provider-svc/main.go::listUsers", prov.SymbolID) + assert.Equal(t, "provider-svc/main.go", prov.FilePath) + + require.NotEmpty(t, top.Consumers, "consumer side must be resolved") + cons := top.Consumers[0] + assert.Equal(t, "consumer-svc", cons.Repo) + assert.Equal(t, "consumer-svc/client.go::fetchUsers", cons.SymbolID) +} + +// TestHandleContractBridges_RankMode_SymbolAdjacency: passing a symbol +// adds the graph-adjacency signal to the fusion. +func TestHandleContractBridges_RankMode_SymbolAdjacency(t *testing.T) { + srv := newBridgeTestServer(t) + + payload := callBridge(t, srv, map[string]any{ + "query": "users", + "symbol": "consumer-svc/client.go::fetchUsers", + }) + require.NotEmpty(t, payload.Groups) + top := payload.Groups[0] + assert.Equal(t, "bridge::acme::acme::http::GET::/api/users", top.BridgeID) + assert.Contains(t, top.SignalRanks, "adjacency", + "symbol-anchored call must rank the adjacency signal: %v", top.SignalRanks) + assert.Equal(t, 1, top.SignalRanks["adjacency"]) +} + +// TestHandleContractBridges_ImpactMode: from the consumer symbol, the +// bridge it participates in must surface as cross-service blast +// radius, with the anchoring contract recorded. +func TestHandleContractBridges_ImpactMode(t *testing.T) { + srv := newBridgeTestServer(t) + + payload := callBridge(t, srv, map[string]any{ + "mode": "impact", + "symbol": "consumer-svc/client.go::fetchUsers", + }) + require.Equal(t, "impact", payload.Mode) + require.Equal(t, 1, payload.Total) + top := payload.Groups[0] + assert.Equal(t, "bridge::acme::acme::http::GET::/api/users", top.BridgeID) + assert.Equal(t, []string{"http::GET::/api/users"}, top.MatchedVia) + + // The provider symbol works symmetrically — its route's bridge is + // its blast radius too. + payload = callBridge(t, srv, map[string]any{ + "mode": "impact", + "symbol": "provider-svc/main.go::listUsers", + }) + require.Equal(t, 1, payload.Total) + assert.Equal(t, "bridge::acme::acme::http::GET::/api/users", payload.Groups[0].BridgeID) +} + +// TestHandleContractBridges_Errors covers the argument-validation +// paths: impact without symbol, unknown mode, unknown symbol. +func TestHandleContractBridges_Errors(t *testing.T) { + srv := newBridgeTestServer(t) + + for name, args := range map[string]map[string]any{ + "impact without symbol": {"action": "bridge", "mode": "impact"}, + "unknown mode": {"action": "bridge", "mode": "sideways"}, + "unknown symbol": {"action": "bridge", "mode": "impact", "symbol": "nope.go::missing"}, + } { + req := mcplib.CallToolRequest{} + req.Params.Arguments = args + res, err := srv.handleContracts(context.Background(), req) + require.NoError(t, err, name) + assert.True(t, res.IsError, "%s should return a tool error", name) + } +} + +// TestHandleContractBridges_RepoScope: the repo filter keeps bridges +// touching the named repo and drops the rest. +func TestHandleContractBridges_RepoScope(t *testing.T) { + srv := newBridgeTestServer(t) + + payload := callBridge(t, srv, map[string]any{"query": "users", "repo": "provider-svc"}) + require.NotEmpty(t, payload.Groups, "bridge touches provider-svc, must stay in scope") + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"action": "bridge", "query": "users", "repo": "unrelated-repo"} + res, err := srv.handleContracts(context.Background(), req) + require.NoError(t, err) + assert.True(t, res.IsError, "no bridges touch unrelated-repo — expect the empty-scope error") +} + +// TestCollectBridgeGroups_BoundaryScopedParticipants guards the read- +// side half of the workspace-isolation fix: registry.ByID returns every +// same-ID record across all workspaces, so the participant resolver must +// filter to the bridge's own (workspace, project) boundary. Without the +// filter, a bridge for workspace "acme" would list "globex"'s provider +// as a participant. +func TestCollectBridgeGroups_BoundaryScopedParticipants(t *testing.T) { + g := graph.New() + + // One contract node ID, but two registry records in different + // workspaces — exactly the shape registry.ByID returns. + contractID := "http::GET::/api/users" + g.AddNode(&graph.Node{ID: contractID, Kind: graph.KindContract, Name: contractID, Language: "contract"}) + + reg := contracts.NewRegistry() + reg.Add(contracts.Contract{ + ID: contractID, Type: contracts.ContractHTTP, Role: contracts.RoleProvider, + SymbolID: "acme-api/routes.go::list", FilePath: "acme-api/routes.go", Line: 10, + RepoPrefix: "acme-api", WorkspaceID: "acme", ProjectID: "acme", + }) + reg.Add(contracts.Contract{ + ID: contractID, Type: contracts.ContractHTTP, Role: contracts.RoleProvider, + SymbolID: "globex-api/routes.go::list", FilePath: "globex-api/routes.go", Line: 22, + RepoPrefix: "globex-api", WorkspaceID: "globex", ProjectID: "globex", + }) + + // A bridge node pinned to the acme boundary, linked to the shared + // contract node — the materialiser's output for the acme group. + g.AddNode(&graph.Node{ + ID: "bridge::acme::acme::" + contractID, + Kind: graph.KindContractBridge, + Name: "GET /api/users", + FilePath: indexer.ContractBridgeFilePath, + Language: "contract", + RepoPrefix: "acme-api", + WorkspaceID: "acme", + Meta: map[string]any{ + "contract_type": "http", "canonical_key": "GET /api/users", + "workspace": "acme", "project": "acme", + "repos": []string{"acme-api"}, "provider_count": 1, "consumer_count": 0, + }, + }) + g.AddEdge(&graph.Edge{ + From: "bridge::acme::acme::" + contractID, To: contractID, + Kind: graph.EdgeBridges, Meta: map[string]any{"side": "provider"}, + }) + + eng := query.NewEngine(g) + srv := NewServer(eng, g, nil, nil, zap.NewNop(), nil) + srv.SetContractRegistry(reg) + + groups := srv.collectBridgeGroups(nil) + require.Len(t, groups, 1) + grp := groups[0] + require.Len(t, grp.Providers, 1, + "only the acme-boundary provider must be listed, not globex's same-ID record: %+v", grp.Providers) + assert.Equal(t, "acme-api", grp.Providers[0].Repo) + assert.Equal(t, "acme-api/routes.go::list", grp.Providers[0].SymbolID) +} diff --git a/internal/mcp/tools_enhancements.go b/internal/mcp/tools_enhancements.go index f2d25875..f73a17d3 100644 --- a/internal/mcp/tools_enhancements.go +++ b/internal/mcp/tools_enhancements.go @@ -274,14 +274,17 @@ func (s *Server) registerEnhancementTools() { // contracts — unified contracts tool (list + check + validate) s.addTool( mcp.NewTool("contracts", - mcp.WithDescription("API contracts tool. action=list (default): lists detected contracts (HTTP, gRPC, GraphQL, topics, WebSocket, env, OpenAPI). action=check: detects orphan providers/consumers across repos. action=validate: diffs provider↔consumer request/response shapes and flags breaking/warning/info issues.\n\nDEFAULT SCOPE for list: auto-scopes to the active project's repos and hides dependency-origin contracts (type=dependency, vendored paths like vendor/, node_modules/). The response reports other_repos (count of contracts filtered out of scope) and dependencies_skipped (count of dep contracts hidden). To widen scope, pass repo=, project=, ref=, or all_repos=true. To include dependency contracts, pass include_deps=true."), - mcp.WithString("action", mcp.Description("list (default), check, or validate")), + mcp.WithDescription("API contracts tool. action=list (default): lists detected contracts (HTTP, gRPC, Thrift, GraphQL, topics, WebSocket, env, OpenAPI). action=check: detects orphan providers/consumers across repos. action=validate: diffs provider↔consumer request/response shapes and flags breaking/warning/info issues. action=bridge: queries the persisted contract-bridge subgraph — one node per matched provider↔consumer group (HTTP route, gRPC/Thrift method, pub/sub topic) — ranked by reciprocal rank fusion over text, path/repo, graph-adjacency, and consumer-degree signals (mode=rank, pass query and/or symbol), or expanded from a symbol into its cross-service blast radius (mode=impact, pass symbol).\n\nDEFAULT SCOPE for list: auto-scopes to the active project's repos and hides dependency-origin contracts (type=dependency, vendored paths like vendor/, node_modules/). The response reports other_repos (count of contracts filtered out of scope) and dependencies_skipped (count of dep contracts hidden). To widen scope, pass repo=, project=, ref=, or all_repos=true. To include dependency contracts, pass include_deps=true."), + mcp.WithString("action", mcp.Description("list (default), check, validate, or bridge")), mcp.WithString("repo", mcp.Description("Filter by repository prefix")), mcp.WithString("project", mcp.Description("Filter to repositories in a specific project (resolves to the project's repo set)")), mcp.WithString("ref", mcp.Description("Filter to repositories tagged with this ref")), mcp.WithBoolean("all_repos", mcp.Description("(list) Disable active-project auto-scope; return contracts from every indexed repo. Default false.")), mcp.WithBoolean("include_deps", mcp.Description("(list) Include type=dependency contracts and contracts from vendored paths (vendor/, node_modules/, Pods/, .venv/). Default false.")), - mcp.WithString("type", mcp.Description("(list) Filter by type: http, grpc, graphql, topic, ws, env, openapi, dependency")), + mcp.WithString("type", mcp.Description("(list) Filter by type: http, grpc, thrift, graphql, topic, ws, env, openapi, dependency")), + mcp.WithString("query", mcp.Description("(bridge) Free-text query ranked against bridge canonical keys, contract names, repos, and file paths")), + mcp.WithString("symbol", mcp.Description("(bridge) Symbol ID anchoring the graph-adjacency signal (mode=rank) or the blast-radius expansion (mode=impact)")), + mcp.WithString("mode", mcp.Description("(bridge) rank (default) or impact")), mcp.WithString("role", mcp.Description("(list) Filter by role: provider or consumer")), mcp.WithNumber("limit", mcp.Description("(list) Max contracts per page (default: 200)")), mcp.WithString("cursor", mcp.Description("(list) Opaque pagination cursor from a previous `next_cursor` to fetch the next page.")), @@ -3339,8 +3342,10 @@ func (s *Server) handleContracts(ctx context.Context, req mcp.CallToolRequest) ( return s.handleCheckContracts(ctx, req) case "validate": return s.handleValidateContracts(ctx, req) + case "bridge": + return s.handleContractBridges(ctx, req) default: - return mcp.NewToolResultError("unknown contracts action: " + action + " (expected: list, check, or validate)"), nil + return mcp.NewToolResultError("unknown contracts action: " + action + " (expected: list, check, validate, or bridge)"), nil } } From 68edf88dbf5a3c911cfd06d6b26e7462b2363874 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Sat, 13 Jun 2026 08:10:50 +0200 Subject: [PATCH 5/5] feat(semantic): in-process tree-sitter type resolvers for six languages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New internal/semantic/tstypes package: per-language type resolvers for Java, Python, Ruby, Rust, TypeScript/JavaScript, and C# that run fully in-process over the shared tree-sitter AST — no external language server. A table-driven engine builds per-file scope graphs, binds declared and constructor types, propagates them through local assignments, resolves receivers against the graph's method sets via import-aware cross-file lookup, and synthesizes implements/extends edges per language. Resolutions are stamped at the ast_resolved tier with semantic_source -types, never downgrading a stronger edge; ambiguous receivers are skipped. Enrichment is scoped to the repo being enriched, runs its graph-apply phase under the resolve mutex, persists full edge provenance on disk backends, and wires single-file incremental enrichment. Providers register as supplemental in the semantic manager and coexist with LSP providers. --- internal/graph/store.go | 15 + internal/graph/store_sqlite/store.go | 31 +- internal/indexer/indexer.go | 24 +- internal/indexer/semantic_incremental_test.go | 95 +++ internal/semantic/manager.go | 96 ++- internal/semantic/provider.go | 13 + internal/semantic/tstypes/csharp.go | 231 ++++++ internal/semantic/tstypes/csharp_test.go | 192 +++++ internal/semantic/tstypes/engine.go | 737 ++++++++++++++++++ .../semantic/tstypes/fixes_multirepo_test.go | 152 ++++ .../semantic/tstypes/fixes_sqlite_test.go | 137 ++++ internal/semantic/tstypes/fixes_test.go | 169 ++++ internal/semantic/tstypes/harness_test.go | 131 ++++ internal/semantic/tstypes/java.go | 242 ++++++ internal/semantic/tstypes/java_test.go | 303 +++++++ internal/semantic/tstypes/manager_test.go | 196 +++++ internal/semantic/tstypes/provider.go | 211 +++++ internal/semantic/tstypes/python.go | 283 +++++++ internal/semantic/tstypes/python_test.go | 194 +++++ internal/semantic/tstypes/ruby.go | 237 ++++++ internal/semantic/tstypes/ruby_test.go | 161 ++++ internal/semantic/tstypes/rust.go | 285 +++++++ internal/semantic/tstypes/rust_test.go | 232 ++++++ internal/semantic/tstypes/scope.go | 394 ++++++++++ internal/semantic/tstypes/spec.go | 235 ++++++ internal/semantic/tstypes/typescript.go | 286 +++++++ internal/semantic/tstypes/typescript_test.go | 216 +++++ internal/serverstack/shared_server.go | 10 + 28 files changed, 5498 insertions(+), 10 deletions(-) create mode 100644 internal/indexer/semantic_incremental_test.go create mode 100644 internal/semantic/tstypes/csharp.go create mode 100644 internal/semantic/tstypes/csharp_test.go create mode 100644 internal/semantic/tstypes/engine.go create mode 100644 internal/semantic/tstypes/fixes_multirepo_test.go create mode 100644 internal/semantic/tstypes/fixes_sqlite_test.go create mode 100644 internal/semantic/tstypes/fixes_test.go create mode 100644 internal/semantic/tstypes/harness_test.go create mode 100644 internal/semantic/tstypes/java.go create mode 100644 internal/semantic/tstypes/java_test.go create mode 100644 internal/semantic/tstypes/manager_test.go create mode 100644 internal/semantic/tstypes/provider.go create mode 100644 internal/semantic/tstypes/python.go create mode 100644 internal/semantic/tstypes/python_test.go create mode 100644 internal/semantic/tstypes/ruby.go create mode 100644 internal/semantic/tstypes/ruby_test.go create mode 100644 internal/semantic/tstypes/rust.go create mode 100644 internal/semantic/tstypes/rust_test.go create mode 100644 internal/semantic/tstypes/scope.go create mode 100644 internal/semantic/tstypes/spec.go create mode 100644 internal/semantic/tstypes/typescript.go create mode 100644 internal/semantic/tstypes/typescript_test.go diff --git a/internal/graph/store.go b/internal/graph/store.go index e8bd259e..6a20309f 100644 --- a/internal/graph/store.go +++ b/internal/graph/store.go @@ -959,6 +959,21 @@ type FileMtimeDeleter interface { DeleteFileMtimes(repoPrefix string, paths []string) error } +// EdgePersister is an optional capability backends MAY implement to +// durably rewrite the mutable attribute columns (Confidence, +// ConfidenceLabel, Origin, Tier, Meta) of an edge already present in the +// graph, identified by its full logical key (From, To, Kind, FilePath, +// Line). The in-memory backend never needs it — GetOutEdges hands back +// the live *Edge pointer, so an in-place field mutation is already +// durable. A disk backend, by contrast, returns a detached row copy: +// mutating Confidence / Meta on that copy and calling SetEdgeProvenance +// (which only writes Origin + Tier) silently drops the rest. A pass that +// confirms an edge's full provenance bundle calls PersistEdgeAttributes +// so every backend keeps it. A no matching row is a no-op. +type EdgePersister interface { + PersistEdgeAttributes(e *Edge) +} + // CloneShingleWriter is an optional capability backends MAY implement // to persist each function/method node's MinHash shingle set (a // []uint64) keyed by node id. Lifting this state into the same backend diff --git a/internal/graph/store_sqlite/store.go b/internal/graph/store_sqlite/store.go index a4448fa8..2fd62bb4 100644 --- a/internal/graph/store_sqlite/store.go +++ b/internal/graph/store_sqlite/store.go @@ -106,6 +106,7 @@ type Store struct { stmtEdgeCount *sql.Stmt stmtRemoveEdge *sql.Stmt stmtUpdateEdgeOrigin *sql.Stmt + stmtUpdateEdgeAttrs *sql.Stmt stmtSelectEdgeOrigin *sql.Stmt stmtDeleteEdgeByKey *sql.Stmt @@ -276,7 +277,7 @@ func (s *Store) Close() error { s.stmtInsertEdge, s.stmtOutEdges, s.stmtInEdges, s.stmtRepoEdges, s.stmtAllEdges, s.stmtEdgeCount, s.stmtRemoveEdge, - s.stmtUpdateEdgeOrigin, s.stmtSelectEdgeOrigin, s.stmtDeleteEdgeByKey, + s.stmtUpdateEdgeOrigin, s.stmtUpdateEdgeAttrs, s.stmtSelectEdgeOrigin, s.stmtDeleteEdgeByKey, s.stmtSelectFileNodeIDs, s.stmtSelectRepoNodeIDs, s.stmtDeleteNodeByFile, s.stmtDeleteNodeByRepo, } @@ -381,6 +382,8 @@ func (s *Store) prepare() error { `SELECT origin FROM edges WHERE from_id = ? AND to_id = ? AND kind = ? AND file_path = ? AND line = ?`) prep(&s.stmtUpdateEdgeOrigin, `UPDATE edges SET origin = ?, tier = ? WHERE from_id = ? AND to_id = ? AND kind = ? AND file_path = ? AND line = ?`) + prep(&s.stmtUpdateEdgeAttrs, + `UPDATE edges SET confidence = ?, confidence_label = ?, origin = ?, tier = ?, meta = ? WHERE from_id = ? AND to_id = ? AND kind = ? AND file_path = ? AND line = ?`) prep(&s.stmtDeleteEdgeByKey, `DELETE FROM edges WHERE from_id = ? AND to_id = ? AND kind = ? AND file_path = ? AND line = ?`) @@ -663,6 +666,32 @@ func (s *Store) SetEdgeProvenance(e *graph.Edge, newOrigin string) bool { return true } +// PersistEdgeAttributes durably rewrites the mutable attribute columns +// (confidence, confidence_label, origin, tier, meta) of the edge row +// identified by e's full logical key. It is the disk-backend counterpart +// to the in-memory store's "mutate the live *Edge in place" behaviour: a +// pass that confirms an edge's full provenance bundle (not just origin) +// calls this so the confidence / label / meta survive a reload. A missing +// row is a silent no-op (UPDATE ... WHERE matches nothing). +func (s *Store) PersistEdgeAttributes(e *graph.Edge) { + if e == nil { + return + } + metaBlob, err := encodeMeta(e.Meta) + if err != nil { + panicOnFatal(err) + return + } + s.writeMu.Lock() + defer s.writeMu.Unlock() + if _, err := s.stmtUpdateEdgeAttrs.Exec( + e.Confidence, e.ConfidenceLabel, e.Origin, e.Tier, metaBlob, + e.From, e.To, string(e.Kind), e.FilePath, e.Line, + ); err != nil { + panicOnFatal(err) + } +} + // ReindexEdge updates the stored row after e.To has been mutated from // oldTo to e.To. Implemented as delete-old + insert-new under the // same write lock (SQLite's UNIQUE constraint on (from,to,kind,file, diff --git a/internal/indexer/indexer.go b/internal/indexer/indexer.go index 67215aef..51258f78 100644 --- a/internal/indexer/indexer.go +++ b/internal/indexer/indexer.go @@ -842,7 +842,9 @@ func (idx *Indexer) RunDeferredPasses(ctx context.Context) { if idx.semanticMgr != nil && idx.semanticMgr.Enabled() && idx.semanticMgr.HasProviders() { reporter.Report("semantic enrichment", 0, 0) - roots := map[string]string{"default": idx.rootPath} + // Key by the repo prefix so a repo-scoped provider can scope file + // selection to this repo (empty in single-repo mode). + roots := map[string]string{idx.repoPrefix: idx.rootPath} results, err := idx.semanticMgr.EnrichAll(idx.graph, roots) if err != nil { idx.logger.Warn("semantic enrichment failed", zap.Error(err)) @@ -2429,7 +2431,9 @@ func (idx *Indexer) IndexCtx(ctx context.Context, root string) (result *IndexRes // Semantic enrichment (SCIP, go/types, LSP). if idx.semanticMgr != nil && idx.semanticMgr.Enabled() && idx.semanticMgr.HasProviders() { reporter.Report("semantic enrichment", 0, 0) - roots := map[string]string{"default": absRoot} + // Key by the repo prefix so a repo-scoped provider can scope + // file selection to this repo (empty in single-repo mode). + roots := map[string]string{idx.repoPrefix: absRoot} results, err := idx.semanticMgr.EnrichAll(idx.graph, roots) if err != nil { idx.logger.Warn("semantic enrichment failed", zap.Error(err)) @@ -2841,8 +2845,6 @@ func (idx *Indexer) indexFile(filePath string, resolve bool) error { detectClonesAndEmitEdges(idx.graph, idx.repoPrefix, idx.cloneThreshold()) } } - // Persist this file's resolved-reference facts to the durable sidecar - // (delete-then-set so removed references don't linger). No-op on the // in-memory backend. Skipped for a quarantined / timed-out / // minified file: its synthetic result yields no facts, so a // delete-then-set would durably drop the file's real facts on a @@ -2856,6 +2858,20 @@ func (idx *Indexer) indexFile(filePath string, resolve bool) error { // that referenced it — bounded, synchronous, and gated on the // signature delta so a body-only edit fans out to nothing. idx.reresolveAffectedBy(graphPath, abSnap, result.Nodes) + + // Incremental semantic enrichment for this single file. Mirrors the + // full-index EnrichAll call but scoped to the saved file, so a + // watcher save re-runs the type resolvers (and any watch-enabled + // LSP / compiler provider) instead of leaving the file's edges at + // their pre-enrichment tier until the next full reindex. Gated + // internally on Config.EnrichOnWatch; a no-op when disabled. + if idx.semanticMgr != nil && idx.semanticMgr.Enabled() && idx.semanticMgr.HasProviders() { + if _, err := idx.semanticMgr.EnrichFile(idx.graph, idx.rootPath, graphPath); err != nil { + idx.logger.Debug("indexer: incremental semantic enrichment failed", + zap.String("file", graphPath), + zap.Error(err)) + } + } } } diff --git a/internal/indexer/semantic_incremental_test.go b/internal/indexer/semantic_incremental_test.go new file mode 100644 index 00000000..f3cf5146 --- /dev/null +++ b/internal/indexer/semantic_incremental_test.go @@ -0,0 +1,95 @@ +package indexer + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/config" + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" + "github.com/zzet/gortex/internal/semantic" + "github.com/zzet/gortex/internal/semantic/tstypes" +) + +// A single-file save must re-run incremental semantic enrichment so the +// file's edges are re-confirmed by the in-process type resolvers, rather +// than staying at their pre-enrichment tier until the next full reindex. +// This exercises the indexFile -> Manager.EnrichFile wiring end to end. +func TestIndexFile_RunsIncrementalSemanticEnrichment(t *testing.T) { + dir := t.TempDir() + + svc := filepath.Join(dir, "a", "Svc.java") + app := filepath.Join(dir, "b", "App.java") + require.NoError(t, os.MkdirAll(filepath.Dir(svc), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Dir(app), 0o755)) + require.NoError(t, os.WriteFile(svc, []byte(`package a; + +public class Svc { + public void run() {} +} +`), 0o644)) + require.NoError(t, os.WriteFile(app, []byte(`package b; + +import a.Svc; + +public class App { + public void handle(Svc s) { + } +} +`), 0o644)) + + g := graph.New() + reg := parser.NewRegistry() + languages.RegisterAll(reg) + cfg := config.Default().Index + cfg.Workers = 2 + idx := New(g, reg, cfg, zap.NewNop()) + + // Semantic manager with the in-process type resolvers and watch-mode + // enrichment enabled. + mgr := semantic.NewManager(semantic.Config{Enabled: true, EnrichOnWatch: true}, zap.NewNop()) + for _, p := range tstypes.DefaultProviders(zap.NewNop()) { + mgr.RegisterProvider(p) + } + idx.SetSemanticManager(mgr) + + if _, err := idx.Index(dir); err != nil { + t.Fatalf("index: %v", err) + } + + caller := "b/App.java::App.handle" + target := "a/Svc.java::Svc.run" + + // At this point handle() has no body call. Now edit App.java to add a + // receiver-qualified call and re-index just that file. + require.NoError(t, os.WriteFile(app, []byte(`package b; + +import a.Svc; + +public class App { + public void handle(Svc s) { + s.run(); + } +} +`), 0o644)) + + require.NoError(t, idx.IndexFile(app)) + + // The incremental semantic pass must have resolved + stamped the call. + var e *graph.Edge + for _, oe := range g.GetOutEdges(caller) { + if oe.Kind == graph.EdgeCalls && oe.To == target { + e = oe + break + } + } + require.NotNilf(t, e, "incremental semantic enrichment did not resolve the call; edges: %v", g.GetOutEdges(caller)) + require.Equal(t, "ast_resolved", e.Origin, "edge not stamped by the in-process type resolver") + require.NotNil(t, e.Meta) + require.Equal(t, "java-types", e.Meta["semantic_source"]) +} diff --git a/internal/semantic/manager.go b/internal/semantic/manager.go index e251e15b..25e73cda 100644 --- a/internal/semantic/manager.go +++ b/internal/semantic/manager.go @@ -44,6 +44,23 @@ type LSPRouter interface { Close() error } +// SupplementalProvider is an optional interface a Provider MAY +// implement to opt out of the per-language arbitration: instead of +// competing for a language slot it always runs (when available and not +// config-disabled) in addition to whichever provider won the slot. +// The in-process tree-sitter type resolvers implement it so they +// coexist with LSP / SCIP providers — their AST-grade provenance never +// downgrades a compiler-grade edge, and a language with no external +// tooling still gets type-aware enrichment. +type SupplementalProvider interface { + Supplemental() bool +} + +func isSupplemental(p Provider) bool { + sp, ok := p.(SupplementalProvider) + return ok && sp.Supplemental() +} + // Manager orchestrates multiple semantic providers and coordinates enrichment. type Manager struct { providers []Provider @@ -183,9 +200,35 @@ func (m *Manager) EnrichAll(g graph.Store, roots map[string]string) ([]*EnrichRe } } + // Supplemental providers run last, outside arbitration: they only + // hold AST-grade provenance, so running after a compiler-grade + // winner can confirm-but-never-downgrade what it stamped. + for _, p := range m.providers { + if !isSupplemental(p) || !p.Available() || m.providerDisabled(p.Name()) { + continue + } + langs := p.Languages() + if len(langs) == 0 { + continue + } + results = m.runEnrichForProvider(g, roots, langs[0], p, results) + } + return results, nil } +// providerDisabled reports an explicit `enabled: false` config entry +// for the named provider. Used by the supplemental run loop, which +// never passes through selectProviders' config gate. +func (m *Manager) providerDisabled(name string) bool { + for _, pc := range m.config.Providers { + if pc.Name == name { + return !pc.Enabled + } + } + return false +} + // configPriorityFor returns the user's config-overridden priority for // the named provider, if any. Used to let `.gortex.yaml` take // precedence over the spec's built-in default. @@ -211,7 +254,17 @@ func (m *Manager) runEnrichForProvider(g graph.Store, roots map[string]string, l zap.String("repo", repoName), ) - result, err := provider.Enrich(g, repoRoot) + // repoName is the roots-map key. In multi-repo mode it carries the + // repo prefix (the MultiIndexer keys roots by prefix; the per-repo + // indexer passes its own RepoPrefix()); a repo-scoped provider uses + // it to scope file selection to the repo actually being enriched. + var result *EnrichResult + var err error + if rsp, ok := provider.(RepoScopedProvider); ok { + result, err = rsp.EnrichRepo(g, repoName, repoRoot) + } else { + result, err = provider.Enrich(g, repoRoot) + } if err != nil { m.logger.Warn("semantic enrichment failed", zap.String("provider", provider.Name()), @@ -259,12 +312,39 @@ func (m *Manager) EnrichFile(g graph.Store, repoRoot, filePath string) (*EnrichR } lang := nodes[0].Language - provider, ok := langProviders[lang] - if !ok || !provider.Available() { - return nil, nil + var primary *EnrichResult + var primaryErr error + if provider, ok := langProviders[lang]; ok && provider.Available() { + primary, primaryErr = provider.EnrichFile(g, repoRoot, filePath) } - return provider.EnrichFile(g, repoRoot, filePath) + // Supplemental providers for this language run regardless of the + // arbitration outcome — same contract as EnrichAll. + for _, p := range m.providers { + if !isSupplemental(p) || !p.Available() || m.providerDisabled(p.Name()) { + continue + } + for _, l := range p.Languages() { + if l != lang { + continue + } + res, err := p.EnrichFile(g, repoRoot, filePath) + if err != nil { + m.logger.Debug("supplemental incremental enrichment failed", + zap.String("provider", p.Name()), + zap.String("file", filePath), + zap.Error(err), + ) + break + } + if primary == nil { + primary = res + } + break + } + } + + return primary, primaryErr } // selectProviders returns the highest-priority available provider per language. @@ -292,6 +372,12 @@ func (m *Manager) selectProviders() map[string]Provider { langCandidates := make(map[string][]langCandidate) for _, p := range m.providers { + // Supplemental providers never occupy a language slot — they + // run unconditionally after arbitration (see EnrichAll), so a + // router-backed LSP spec can still win the language. + if isSupplemental(p) { + continue + } ce, ok := configMap[p.Name()] if ok && !ce.enabled { continue diff --git a/internal/semantic/provider.go b/internal/semantic/provider.go index 20ff262f..b2ab00d7 100644 --- a/internal/semantic/provider.go +++ b/internal/semantic/provider.go @@ -32,6 +32,19 @@ type Provider interface { Close() error } +// RepoScopedProvider is an optional interface a Provider MAY implement to +// receive the repo prefix of the enrichment root alongside the root path. +// In a multi-repo daemon the shared graph holds file nodes from every +// tracked repo, and two repos can share a relative path; a provider that +// selects its work by walking graph file nodes needs the prefix to scope +// to the repo actually being enriched rather than guessing from disk +// existence (which a path collision defeats). The Manager calls EnrichRepo +// when the provider implements it, passing the repo's prefix (empty in +// single-repo mode); otherwise it falls back to Enrich. +type RepoScopedProvider interface { + EnrichRepo(g graph.Store, repoPrefix, repoRoot string) (*EnrichResult, error) +} + // EnrichResult contains statistics from an enrichment pass. type EnrichResult struct { Provider string `json:"provider"` diff --git a/internal/semantic/tstypes/csharp.go b/internal/semantic/tstypes/csharp.go new file mode 100644 index 00000000..9156c3b9 --- /dev/null +++ b/internal/semantic/tstypes/csharp.go @@ -0,0 +1,231 @@ +package tstypes + +import ( + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" + "github.com/zzet/gortex/internal/parser/tsitter/csharp" +) + +// CSharpSpec adapts the engine to tree-sitter-c-sharp. Like Java the +// types are explicit; the one quirk is the base list, which does not +// syntactically distinguish the base class from interfaces — those +// SuperRefs carry an empty Kind and the apply phase discriminates by +// the resolved node's kind (strictly better than the extractor's +// I-prefix heuristic). `using` directives import namespaces, not +// names, so cross-file types resolve by repo-unique name only. +func CSharpSpec() *LangSpec { + grammar := csharp.GetLanguage() + return &LangSpec{ + ProviderName: "csharp-types", + Languages: []string{"csharp"}, + GrammarFor: func(string) *sitter.Language { return grammar }, + TypeDeclTypes: map[string]bool{ + "class_declaration": true, + "interface_declaration": true, + "struct_declaration": true, + "record_declaration": true, + }, + FuncDeclTypes: map[string]bool{ + "method_declaration": true, + "constructor_declaration": true, + "local_function_statement": true, + }, + SelfName: "this", + TypeDeclName: nameField, + Supertypes: csharpSupertypes, + Fields: csharpFields, + Params: csharpParams, + ReturnType: func(fn *sitter.Node, src []byte) string { + switch fn.Type() { + case "method_declaration", "local_function_statement": + if t := fieldText(fn, "returns", src); t != "" { + return t + } + return fieldText(fn, "type", src) + } + return "" + }, + LocalBinding: csharpLocalBinding, + Call: csharpCall, + NewExprType: func(n *sitter.Node, src []byte) string { + if n.Type() != "object_creation_expression" { + return "" + } + return fieldText(n, "type", src) + }, + FieldRef: func(n *sitter.Node, src []byte) (string, bool) { + if n.Type() != "member_access_expression" { + return "", false + } + obj := n.ChildByFieldName("expression") + // `this` is a bare keyword node in the current grammar, + // this_expression in older revisions. + if obj == nil || (obj.Type() != "this" && obj.Type() != "this_expression") { + return "", false + } + return fieldText(n, "name", src), true + }, + // using directives bind namespaces, not type names. + Imports: nil, + } +} + +func csharpSupertypes(n *sitter.Node, src []byte) []SuperRef { + baseList := n.ChildByFieldName("bases") + if baseList == nil { + for i := 0; i < int(n.ChildCount()); i++ { + if c := n.Child(i); c != nil && c.Type() == "base_list" { + baseList = c + break + } + } + } + if baseList == nil { + return nil + } + kind := graph.EdgeKind("") // apply phase discriminates by node kind + if n.Type() == "interface_declaration" { + // An interface's bases can only be interfaces. + kind = graph.EdgeExtends + } + var out []SuperRef + for i := 0; i < int(baseList.NamedChildCount()); i++ { + entry := baseList.NamedChild(i) + name := entry.Content(src) + if entry.Type() == "primary_constructor_base_type" { + // `: Base(args)` — the base-class constructor invocation. + if t := firstChildOfType(entry, "identifier"); t != nil { + name = t.Content(src) + } + } + if name == "" { + continue + } + out = append(out, SuperRef{Name: name, Kind: kind, Line: nodeLine(entry)}) + } + return out +} + +func csharpFields(n *sitter.Node, src []byte) []Binding { + body := n.ChildByFieldName("body") + if body == nil { + return nil + } + var out []Binding + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + switch c.Type() { + case "field_declaration": + decl := firstChildOfType(c, "variable_declaration") + if decl == nil { + continue + } + typ := fieldText(decl, "type", src) + for j := 0; j < int(decl.NamedChildCount()); j++ { + d := decl.NamedChild(j) + if d.Type() != "variable_declarator" { + continue + } + name := csharpDeclaratorName(d, src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: typ, Line: nodeLine(d)}) + } + case "property_declaration": + name := fieldText(c, "name", src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: fieldText(c, "type", src), Line: nodeLine(c)}) + } + } + return out +} + +func csharpParams(fn *sitter.Node, src []byte) []Binding { + params := fn.ChildByFieldName("parameters") + if params == nil { + return nil + } + var out []Binding + for i := 0; i < int(params.NamedChildCount()); i++ { + p := params.NamedChild(i) + if p.Type() != "parameter" { + continue + } + name := fieldText(p, "name", src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: fieldText(p, "type", src), Line: nodeLine(p)}) + } + return out +} + +func csharpLocalBinding(n *sitter.Node, src []byte) (LocalBind, bool) { + switch n.Type() { + case "local_declaration_statement": + decl := firstChildOfType(n, "variable_declaration") + if decl == nil { + return LocalBind{}, false + } + d := firstChildOfType(decl, "variable_declarator") + if d == nil { + return LocalBind{}, false + } + var init *sitter.Node + if eq := firstChildOfType(d, "equals_value_clause"); eq != nil { + // Older grammar revisions wrap the initializer. + init = eq.NamedChild(int(eq.NamedChildCount()) - 1) + } else if count := int(d.NamedChildCount()); count > 1 { + // Current grammar: `s = ` puts the initializer as the + // declarator's trailing named child. + init = d.NamedChild(count - 1) + } + // Target-typed `new()` carries no type of its own — the + // declared type is authoritative either way; the engine only + // falls back to the initializer for `var`. + return LocalBind{ + Name: csharpDeclaratorName(d, src), + DeclType: fieldText(decl, "type", src), + Init: init, + }, true + case "assignment_expression": + left := n.ChildByFieldName("left") + if left == nil || left.Type() != "identifier" { + return LocalBind{}, false + } + return LocalBind{Name: left.Content(src), Init: n.ChildByFieldName("right")}, true + } + return LocalBind{}, false +} + +// csharpDeclaratorName handles both grammar revisions: a `name` field +// or a bare identifier child. +func csharpDeclaratorName(d *sitter.Node, src []byte) string { + if name := fieldText(d, "name", src); name != "" { + return name + } + if id := firstChildOfType(d, "identifier"); id != nil { + return id.Content(src) + } + return "" +} + +func csharpCall(n *sitter.Node, src []byte) (*sitter.Node, string, bool) { + if n.Type() != "invocation_expression" { + return nil, "", false + } + fn := n.ChildByFieldName("function") + if fn == nil || fn.Type() != "member_access_expression" { + return nil, "", false + } + obj := fn.ChildByFieldName("expression") + if obj == nil { + return nil, "", false + } + // A `this` receiver needs no special case: its content matches + // SelfName, so it resolves against the enclosing type. + return obj, fieldText(fn, "name", src), true +} diff --git a/internal/semantic/tstypes/csharp_test.go b/internal/semantic/tstypes/csharp_test.go new file mode 100644 index 00000000..2ba8775d --- /dev/null +++ b/internal/semantic/tstypes/csharp_test.go @@ -0,0 +1,192 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +const csSvc = `namespace A { + public class Svc { + public void Run() {} + public void Stop() {} + } +} +` + +// `using` directives bind namespaces, not names, so the cross-file +// case rides on repo-unique name resolution. +func TestCSharp_DeclaredParamTypeResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "A/Svc.cs": csSvc, + "B/App.cs": `namespace B { + public class App { + public void Handle(Svc s) { + s.Run(); + } + } +} +`, + }) + p := NewProvider(CSharpSpec(), zap.NewNop()) + res, err := p.Enrich(g, dir) + if err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "Handle", graph.KindMethod) + target := nodeByNameKind(t, g, "Run", graph.KindMethod) + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatalf("annotated-param call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } + assertASTProvenance(t, e, "csharp-types") + if res.EdgesConfirmed+res.EdgesAdded == 0 { + t.Errorf("result reported no edge work: %+v", res) + } +} + +// Both `Svc s = new Svc()` and `var s = new Svc()` ground the receiver. +func TestCSharp_ConstructorInferenceResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "A/Svc.cs": csSvc, + "B/App.cs": `namespace B { + public class App { + public void Declared() { + Svc s = new Svc(); + s.Run(); + } + + public void Inferred() { + var s = new Svc(); + s.Stop(); + } + } +} +`, + }) + p := NewProvider(CSharpSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + declared := nodeByNameKind(t, g, "Declared", graph.KindMethod) + run := nodeByNameKind(t, g, "Run", graph.KindMethod) + if callEdgeTo(g, declared.ID, run.ID) == nil { + t.Fatalf("declared-type call not resolved; edges: %v", g.GetOutEdges(declared.ID)) + } + inferred := nodeByNameKind(t, g, "Inferred", graph.KindMethod) + stop := nodeByNameKind(t, g, "Stop", graph.KindMethod) + if callEdgeTo(g, inferred.ID, stop.ID) == nil { + t.Fatalf("var-inferred call not resolved; edges: %v", g.GetOutEdges(inferred.ID)) + } +} + +// The base list does not syntactically split the base class from +// interfaces — the resolved node's kind must discriminate. +func TestCSharp_BaseListImplementsAndExtends(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "A/Svc.cs": csSvc, + "A/IGreeter.cs": `namespace A { + public interface IGreeter { + void Greet(); + } +} +`, + "B/Impl.cs": `namespace B { + public class Impl : Svc, IGreeter { + public void Greet() {} + } +} +`, + }) + p := NewProvider(CSharpSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + impl := nodeByNameKind(t, g, "Impl", graph.KindType) + svc := nodeByNameKind(t, g, "Svc", graph.KindType) + iface := nodeByNameKind(t, g, "IGreeter", graph.KindInterface) + + ee := edgeBetween(g, impl.ID, graph.EdgeExtends, svc.ID) + if ee == nil { + t.Fatalf("extends edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } + assertASTProvenance(t, ee, "csharp-types") + + ie := edgeBetween(g, impl.ID, graph.EdgeImplements, iface.ID) + if ie == nil { + t.Fatalf("implements edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } + assertASTProvenance(t, ie, "csharp-types") +} + +// this-qualified calls, typed fields, and typed auto-properties all +// resolve in-class. +func TestCSharp_ThisFieldAndPropertyReceivers(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "A/Svc.cs": csSvc, + "B/App.cs": `namespace B { + public class App { + private Svc worker; + public Svc Backup { get; set; } + + public void Direct() { + this.Helper(); + } + + public void Helper() { + this.worker.Run(); + this.Backup.Stop(); + } + } +} +`, + }) + p := NewProvider(CSharpSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + direct := nodeByNameKind(t, g, "Direct", graph.KindMethod) + helper := nodeByNameKind(t, g, "Helper", graph.KindMethod) + if callEdgeTo(g, direct.ID, helper.ID) == nil { + t.Fatalf("this.Helper() not resolved; edges: %v", g.GetOutEdges(direct.ID)) + } + run := nodeByNameKind(t, g, "Run", graph.KindMethod) + if callEdgeTo(g, helper.ID, run.ID) == nil { + t.Fatalf("this.worker.Run() not resolved through field type; edges: %v", g.GetOutEdges(helper.ID)) + } + stop := nodeByNameKind(t, g, "Stop", graph.KindMethod) + if callEdgeTo(g, helper.ID, stop.ID) == nil { + t.Fatalf("this.Backup.Stop() not resolved through property type; edges: %v", g.GetOutEdges(helper.ID)) + } +} + +func TestCSharp_AmbiguousReceiverStaysUntouched(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "A/Svc.cs": csSvc, + "A/Alt.cs": `namespace A { + public class Alt { + public void Run() {} + } +} +`, + "B/App.cs": `namespace B { + public class App { + public void Main() { + object s; + s = new Svc(); + s = new Alt(); + s.Run(); + } + } +} +`, + }) + p := NewProvider(CSharpSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "Main", graph.KindMethod) + assertUntouched(t, g, caller.ID, "Run", "csharp-types") +} diff --git a/internal/semantic/tstypes/engine.go b/internal/semantic/tstypes/engine.go new file mode 100644 index 00000000..d358c0ec --- /dev/null +++ b/internal/semantic/tstypes/engine.go @@ -0,0 +1,737 @@ +package tstypes + +import ( + "os" + "path/filepath" + "sort" + "strings" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/semantic" +) + +// maxFileBytes guards the enrichment pass against pathological +// generated sources; files above the cap are skipped, same spirit as +// the indexer's own size gates. +const maxFileBytes = 4 << 20 + +// astConfidence is the confidence stamped on edges this engine +// confirms or adds. Deliberately below the 1.0 the compiler-grade +// ConfirmEdge uses: tree-sitter scope analysis is structurally +// grounded but not type-checked. +const astConfidence = 0.95 + +// extendsWalkDepth bounds the inherited-method lookup walk up the +// resolved EdgeExtends chain. +const extendsWalkDepth = 3 + +// fileRef is one graph file node selected for analysis plus its +// on-disk location. +type fileRef struct { + node *graph.Node + absPath string +} + +// languageFiles selects the graph file nodes for the spec's languages +// that belong to the repo identified by repoPrefix and exist on disk +// under repoRoot. +// +// Disk existence alone is NOT a safe repo-membership test in multi-repo +// mode: the shared graph holds file nodes from every tracked repo, and +// two repos can share a relative path (both have `src/Svc.java`). Joining +// a foreign repo's node onto repoRoot would then stat-hit and read THIS +// repo's bytes for that repo's node, contaminating its graph. Selection is +// therefore gated on the node's own RepoPrefix matching the prefix of the +// repo being enriched. In single-repo mode every real node carries the +// empty prefix, so repoPrefix == "" selects them all. +func languageFiles(g graph.Store, spec *LangSpec, repoPrefix, repoRoot string) []fileRef { + langs := make(map[string]bool, len(spec.Languages)) + for _, l := range spec.Languages { + langs[l] = true + } + var out []fileRef + for n := range g.NodesByKind(graph.KindFile) { + if !langs[n.Language] || n.RepoPrefix != repoPrefix { + continue + } + ref, ok := fileRefFor(n, repoRoot) + if !ok { + continue + } + out = append(out, ref) + } + return out +} + +// fileRefFor maps a graph file node to its on-disk location under repoRoot +// (stripping the node's own RepoPrefix from the path) and reports whether +// it is an existing, in-cap regular file. The single point that turns a +// graph file key into bytes-on-disk for both the full and incremental +// passes. +func fileRefFor(n *graph.Node, repoRoot string) (fileRef, bool) { + rel := n.FilePath + if n.RepoPrefix != "" { + rel = strings.TrimPrefix(rel, n.RepoPrefix+"/") + } + abs := filepath.Join(repoRoot, filepath.FromSlash(rel)) + if fi, err := os.Stat(abs); err != nil || fi.IsDir() || fi.Size() > maxFileBytes { + return fileRef{}, false + } + return fileRef{node: n, absPath: abs}, true +} + +// analyzeFile parses one file and runs the binder walk. Pure with +// respect to the graph — safe to fan out across workers. +func analyzeFile(spec *LangSpec, ref fileRef) (*fileFacts, error) { + src, err := os.ReadFile(ref.absPath) + if err != nil { + return nil, err + } + grammar := spec.GrammarFor(ref.node.FilePath) + if grammar == nil { + return nil, nil + } + tree, err := parser.ParseFile(src, grammar) + if err != nil { + return nil, err + } + defer tree.Close() + + facts := &fileFacts{file: ref.node.FilePath, repoPrefix: ref.node.RepoPrefix} + newBinder(spec, src, facts).run(tree.RootNode()) + return facts, nil +} + +// applier owns every graph interaction of an enrichment pass. It runs +// single-goroutine so in-place edge mutations never race. +type applier struct { + g graph.Store + spec *LangSpec + provider string + stampedNodes map[string]*graph.Node // collected for one AddBatch round-trip +} + +func newApplier(g graph.Store, spec *LangSpec, provider string) *applier { + return &applier{g: g, spec: spec, provider: provider, stampedNodes: make(map[string]*graph.Node)} +} + +// receiverTypeKinds is the node-kind set a call receiver's type may +// resolve to — methods only hang off types and interfaces. +var receiverTypeKinds = map[graph.NodeKind]bool{ + graph.KindType: true, + graph.KindInterface: true, +} + +// supertypeKinds returns the node-kind set declared supertypes may +// resolve to: the receiver default unless the spec widens it. +func (a *applier) supertypeKinds() map[graph.NodeKind]bool { + if a.spec.SupertypeKinds != nil { + return a.spec.SupertypeKinds + } + return receiverTypeKinds +} + +// fileIndex is the per-file view of the graph the apply phase joins +// facts against. +type fileIndex struct { + facts *fileFacts + imports map[string]string // local name → path hint + types map[string]*graph.Node + // superTypes additionally holds same-file nodes of the spec's + // widened supertype kinds (Ruby modules); aliases types when the + // spec doesn't widen. + superTypes map[string]*graph.Node + funcs []*graph.Node // function/method nodes, for line containment +} + +func (a *applier) buildIndex(facts *fileFacts) *fileIndex { + idx := &fileIndex{ + facts: facts, + imports: make(map[string]string, len(facts.imports)), + types: make(map[string]*graph.Node), + } + idx.superTypes = idx.types + superKinds := a.supertypeKinds() + if a.spec.SupertypeKinds != nil { + idx.superTypes = make(map[string]*graph.Node) + } + for _, imp := range facts.imports { + if imp.Local != "" { + idx.imports[imp.Local] = imp.Path + } + } + for _, n := range a.g.GetFileNodes(facts.file) { + if receiverTypeKinds[n.Kind] { + if _, dup := idx.types[n.Name]; !dup { + idx.types[n.Name] = n + } + } + if a.spec.SupertypeKinds != nil && superKinds[n.Kind] { + if _, dup := idx.superTypes[n.Name]; !dup { + idx.superTypes[n.Name] = n + } + } + if n.Kind == graph.KindFunction || n.Kind == graph.KindMethod { + idx.funcs = append(idx.funcs, n) + } + } + return idx +} + +// applyAll joins every analyzed file's facts against the graph in +// three phases: supertype edges and meta fills first, calls last — +// a call in one file may resolve through an extends edge (or a +// return_type stamp) another file's facts just synthesized. +func (a *applier) applyAll(all []*fileFacts, res *semantic.EnrichResult) { + sort.Slice(all, func(i, j int) bool { return all[i].file < all[j].file }) + idxs := make([]*fileIndex, len(all)) + for i, facts := range all { + idxs[i] = a.buildIndex(facts) + } + for i, facts := range all { + for _, sf := range facts.supers { + a.applySuper(idxs[i], sf, res) + } + } + for i, facts := range all { + for _, mf := range facts.metas { + a.applyMeta(idxs[i], mf, res) + } + } + for i, facts := range all { + for _, cf := range facts.calls { + a.applyCall(idxs[i], cf, res) + } + } +} + +// flush round-trips the stamped nodes through the store in one batch — +// on disk backends an in-place Meta mutation is otherwise discarded. +func (a *applier) flush() { + if len(a.stampedNodes) == 0 { + return + } + nodes := make([]*graph.Node, 0, len(a.stampedNodes)) + for _, n := range a.stampedNodes { + nodes = append(nodes, n) + } + a.g.AddBatch(nodes, nil) +} + +// --- Type / method resolution ---------------------------------------- + +// resolveTypeNode grounds a bare type name to a graph type node: +// same-file declaration first, then import-hinted cross-file match, +// then a repo-unique name match. Returns nil when the name stays +// ambiguous — the engine never guesses among candidates. +func (a *applier) resolveTypeNode(idx *fileIndex, name string) *graph.Node { + return a.resolveNodeOfKinds(idx, name, idx.types, receiverTypeKinds) +} + +// resolveSuperNode is resolveTypeNode over the spec's supertype kind +// set — identical strategy, wider target kinds where the language +// needs it (Ruby modules). +func (a *applier) resolveSuperNode(idx *fileIndex, name string) *graph.Node { + return a.resolveNodeOfKinds(idx, name, idx.superTypes, a.supertypeKinds()) +} + +func (a *applier) resolveNodeOfKinds(idx *fileIndex, name string, sameFile map[string]*graph.Node, kinds map[graph.NodeKind]bool) *graph.Node { + if name == "" { + return nil + } + if n, ok := sameFile[name]; ok { + return n + } + candidates := a.typeCandidates(idx, name, kinds) + if len(candidates) == 0 { + return nil + } + if hint, ok := idx.imports[name]; ok && hint != "" { + var matched []*graph.Node + for _, c := range candidates { + if importMatches(c.FilePath, c.RepoPrefix, hint, idx.facts.file) { + matched = append(matched, c) + } + } + if len(matched) == 1 { + return matched[0] + } + // The hint named a definition site; when it matches several + // candidates the receiver stays ambiguous, and when it matches + // none the real target is an external / stdlib dependency the + // graph doesn't hold. Either way the engine must not fall back + // to a repo-local same-named type — that would mint a false edge + // shadowing the dependency. A missing edge beats a wrong one. + return nil + } + if len(candidates) == 1 { + return candidates[0] + } + return nil +} + +func (a *applier) typeCandidates(idx *fileIndex, name string, kinds map[graph.NodeKind]bool) []*graph.Node { + var raw []*graph.Node + if idx.facts.repoPrefix != "" { + raw = a.g.FindNodesByNameInRepo(name, idx.facts.repoPrefix) + } else { + raw = a.g.FindNodesByName(name) + } + lang := a.languageSet() + var out []*graph.Node + for _, c := range raw { + if !kinds[c.Kind] { + continue + } + if !lang[c.Language] { + continue + } + out = append(out, c) + } + return out +} + +func (a *applier) languageSet() map[string]bool { + set := make(map[string]bool, len(a.spec.Languages)) + for _, l := range a.spec.Languages { + set[l] = true + } + return set +} + +// importMatches reports whether a candidate definition file plausibly +// backs the import-path hint. Relative hints resolve against the +// importing file's directory; absolute (package-style) hints match as +// a path-segment suffix of the candidate's extension-less path. +func importMatches(candidateFile, candidatePrefix, hint, importerFile string) bool { + cand := strings.TrimSuffix(candidateFile, filepath.Ext(candidateFile)) + if candidatePrefix != "" { + cand = strings.TrimPrefix(cand, candidatePrefix+"/") + } + if strings.HasPrefix(hint, "./") || strings.HasPrefix(hint, "../") { + base := importerFile + if i := strings.LastIndex(base, "/"); i >= 0 { + base = base[:i] + } else { + base = "" + } + resolved := filepath.ToSlash(filepath.Join(base, hint)) + return cand == resolved || cand == resolved+"/index" + } + hint = strings.Trim(hint, "/") + if hint == "" { + return false + } + // Package files: Python's __init__.py and Rust's mod.rs name the + // directory, not the file. + return pathSegSuffix(cand, hint) || + pathSegSuffix(cand, hint+"/__init__") || + pathSegSuffix(cand, hint+"/mod") +} + +// pathSegSuffix reports whether want equals cand or a slash-aligned +// suffix of it. +func pathSegSuffix(cand, want string) bool { + return cand == want || strings.HasSuffix(cand, "/"+want) +} + +// methodOn resolves a method name against a type's member set, +// following resolved EdgeExtends links for inherited methods. Returns +// nil when the type (and its ancestry) declares zero or several +// same-named members — overload sets stay untouched rather than +// half-guessed. +func (a *applier) methodOn(typeNode *graph.Node, method string, depth int) *graph.Node { + if typeNode == nil || depth > extendsWalkDepth { + return nil + } + var fromIDs []string + for _, e := range a.g.GetInEdges(typeNode.ID) { + if e.Kind == graph.EdgeMemberOf { + fromIDs = append(fromIDs, e.From) + } + } + var matches []*graph.Node + if len(fromIDs) > 0 { + for _, n := range a.g.GetNodesByIDs(fromIDs) { + if n.Kind == graph.KindMethod && n.Name == method { + matches = append(matches, n) + } + } + } + switch len(matches) { + case 1: + return matches[0] + case 0: + for _, e := range a.g.GetOutEdges(typeNode.ID) { + if e.Kind != graph.EdgeExtends || graph.IsUnresolvedTarget(e.To) { + continue + } + parent := a.g.GetNode(e.To) + if parent == nil || (parent.Kind != graph.KindType && parent.Kind != graph.KindInterface) { + continue + } + if m := a.methodOn(parent, method, depth+1); m != nil { + return m + } + } + } + return nil +} + +// callableReturnType resolves a bare callee name to its graph +// return_type: same-file declaration first, then a repo-unique +// function. The returned name is normalized to the bare type name. +func (a *applier) callableReturnType(idx *fileIndex, callee string) string { + var match *graph.Node + for _, n := range idx.funcs { + if n.Name == callee { + if match != nil { + return "" // same-file overloads: ambiguous + } + match = n + } + } + if match == nil { + var raw []*graph.Node + if idx.facts.repoPrefix != "" { + raw = a.g.FindNodesByNameInRepo(callee, idx.facts.repoPrefix) + } else { + raw = a.g.FindNodesByName(callee) + } + lang := a.languageSet() + for _, c := range raw { + if c.Kind != graph.KindFunction && c.Kind != graph.KindMethod { + continue + } + if !lang[c.Language] { + continue + } + if match != nil { + return "" + } + match = c + } + } + if match == nil || match.Meta == nil { + return "" + } + rt, _ := match.Meta["return_type"].(string) + return a.spec.normalize(rt) +} + +// enclosingCallable returns the innermost function/method node +// containing line. +func (idx *fileIndex) enclosingCallable(line int) *graph.Node { + var best *graph.Node + bestSize := int(^uint(0) >> 1) + for _, n := range idx.funcs { + if n.StartLine <= line && line <= n.EndLine { + if size := n.EndLine - n.StartLine; size < bestSize { + best = n + bestSize = size + } + } + } + return best +} + +// --- Call application ------------------------------------------------- + +func (a *applier) applyCall(idx *fileIndex, cf callFact, res *semantic.EnrichResult) { + recvType := cf.recvType + if recvType == "" && cf.recvPendingCallee != "" { + recvType = a.callableReturnType(idx, cf.recvPendingCallee) + } + var typeNode *graph.Node + if recvType != "" { + typeNode = a.resolveTypeNode(idx, recvType) + } else if cf.recvIdent != "" { + // Static / type-qualified call: only when the identifier is a + // real type in scope of this file's imports. + typeNode = a.resolveTypeNode(idx, cf.recvIdent) + } + if typeNode == nil { + return + } + target := a.methodOn(typeNode, cf.method, 0) + if target == nil { + return + } + caller := idx.enclosingCallable(cf.line) + if caller == nil || caller.ID == target.ID { + return + } + a.upgradeOrCreateCall(caller, target, cf, idx.facts.file, res) +} + +// upgradeOrCreateCall lands a grounded call resolution on the graph: +// confirm the edge when it already points at the target, claim a +// weaker-tier or still-unresolved edge at the same line, otherwise add +// a fresh edge. Edges that already carry compiler/AST-grade provenance +// pointing elsewhere are never overridden. +func (a *applier) upgradeOrCreateCall(caller, target *graph.Node, cf callFact, file string, res *semantic.EnrichResult) { + outs := a.g.GetOutEdges(caller.ID) + for _, e := range outs { + if e.Kind == graph.EdgeCalls && e.To == target.ID { + if a.confirmAST(e) { + res.EdgesConfirmed++ + } + return + } + } + for _, e := range outs { + if e.Kind != graph.EdgeCalls || e.Line != cf.line { + continue + } + if !trailingNameMatches(e.To, cf.method) { + continue + } + if !a.claimable(e) { + // A same-line edge for this name already carries + // equal-or-stronger evidence for a different target — + // leave it alone and don't double the call site. + return + } + oldTo := e.To + e.To = target.ID + a.g.ReindexEdge(e, oldTo) + a.confirmAST(e) + res.EdgesConfirmed++ + return + } + a.addASTEdge(caller.ID, target.ID, graph.EdgeCalls, file, cf.line) + res.EdgesAdded++ +} + +// --- Supertype application -------------------------------------------- + +func (a *applier) applySuper(idx *fileIndex, sf superFact, res *semantic.EnrichResult) { + typeNode, ok := idx.superTypes[sf.typeName] + if !ok { + return + } + superNode := a.resolveSuperNode(idx, sf.superName) + if superNode == nil || superNode.ID == typeNode.ID { + return + } + kind := sf.kind + if kind == "" { + // Syntax didn't discriminate (C# base list): the resolved + // target's node kind decides. + if superNode.Kind == graph.KindInterface { + kind = graph.EdgeImplements + } else { + kind = graph.EdgeExtends + } + } + outs := a.g.GetOutEdges(typeNode.ID) + for _, e := range outs { + if e.Kind == kind && e.To == superNode.ID { + if a.confirmAST(e) { + res.EdgesConfirmed++ + } + return + } + } + for _, e := range outs { + if e.Kind != graph.EdgeExtends && e.Kind != graph.EdgeImplements { + continue + } + if !a.claimable(e) || !trailingNameMatches(e.To, sf.superName) { + continue + } + if e.Kind == kind { + // Same relation kind, only the target changes — an in-place + // retarget + ReindexEdge is safe because the edge's logical + // key (which folds Kind) keeps the same Kind on both sides. + oldTo := e.To + e.To = superNode.ID + a.g.ReindexEdge(e, oldTo) + a.confirmAST(e) + res.EdgesConfirmed++ + return + } + // The relation kind itself changes (a C#-style base list whose + // member turned out to be an interface, not a base class). + // Mutating Kind in place corrupts the adjacency index: ReindexEdge + // reconstructs the old logical key from the already-mutated Kind, + // so the original entry is never removed — the in-memory store + // leaks a stale index slot and the sqlite store ends up with two + // contradictory rows. Drop the old edge and add a fresh one of the + // correct kind instead, mirroring how the compiler-grade providers + // only ever add new edges rather than flip an existing one's kind. + a.g.RemoveEdge(e.From, e.To, e.Kind) + a.addASTEdge(typeNode.ID, superNode.ID, kind, idx.facts.file, sf.line) + res.EdgesAdded++ + return + } + a.addASTEdge(typeNode.ID, superNode.ID, kind, idx.facts.file, sf.line) + res.EdgesAdded++ +} + +// --- Node meta application -------------------------------------------- + +func (a *applier) applyMeta(idx *fileIndex, mf metaFact, res *semantic.EnrichResult) { + var node *graph.Node + if mf.owner != "" { + node = a.findMember(idx, mf.owner, mf.name) + } else if mf.line > 0 { + node = idx.enclosingCallable(mf.line) + if node != nil && node.StartLine != mf.line { + node = nil + } + } + if node == nil { + return + } + if node.Meta != nil { + if existing, ok := node.Meta[mf.key].(string); ok && existing != "" { + return // never overwrite an existing (possibly stronger) stamp + } + } + semantic.EnrichNodeMeta(node, mf.key, mf.value, a.provider) + a.stampedNodes[node.ID] = node + res.NodesEnriched++ +} + +// findMember locates the field/variable node for owner.name in the +// file (extractor convention: Meta["receiver"] carries the owner). +func (a *applier) findMember(idx *fileIndex, owner, name string) *graph.Node { + for _, n := range a.g.GetFileNodes(idx.facts.file) { + if n.Name != name { + continue + } + if n.Kind != graph.KindField && n.Kind != graph.KindVariable { + continue + } + if recv, _ := n.Meta["receiver"].(string); recv == owner { + return n + } + } + return nil +} + +// --- Edge provenance helpers ------------------------------------------- + +// confirmAST stamps tree-sitter-grade provenance on an edge the engine +// grounded: OriginASTResolved (deliberately NOT the lsp_* tiers — +// these resolutions are scope-grounded but not compiler-verified), +// confidence raised to the AST ceiling, and the provider recorded as +// semantic_source. Never downgrades an edge that already carries +// AST-or-better provenance; returns whether anything changed. +func (a *applier) confirmAST(e *graph.Edge) bool { + // Never downgrade. The comparison runs against the EFFECTIVE origin, + // which backfills legacy edges that carry their compiler-grade + // provenance only in Meta["semantic_source"] (Origin unset). Requiring + // a non-empty Origin here would wrongly let those edges through and + // clobber both their tier and their semantic_source — so the only + // gate is the effective-rank comparison. + if graph.OriginRank(effectiveOrigin(e)) >= graph.OriginRank(graph.OriginASTResolved) { + return false + } + a.persistConfirmedAST(e) + return true +} + +// persistConfirmedAST stamps the AST-grade provenance bundle (origin, +// confidence, label, semantic_source) on e and makes it durable on every +// backend. SetEdgeProvenance only writes origin+tier; on a disk backend e +// is a detached row copy, so the confidence / label / Meta mutations would +// be lost unless the full edge is round-tripped — persistEdgeRow does that +// through the backend's edge-attribute write path. +func (a *applier) persistConfirmedAST(e *graph.Edge) { + a.g.SetEdgeProvenance(e, graph.OriginASTResolved) + if e.Confidence < astConfidence { + e.Confidence = astConfidence + } + e.ConfidenceLabel = graph.ConfidenceLabelFor(e.Kind, e.Confidence) + if e.Meta == nil { + e.Meta = make(map[string]any) + } + e.Meta["semantic_source"] = a.provider + a.persistEdgeRow(e) +} + +// persistEdgeRow makes a confirmed edge's full attribute bundle durable. +// On the in-memory backend GetOutEdges returns the live *Edge pointer, so +// the field mutations are already persisted and this is a no-op. A disk +// backend returns a detached row copy; SetEdgeProvenance only wrote +// origin+tier, so the confidence / label / Meta mutations need an explicit +// round-trip through the backend's edge-attribute write path. +func (a *applier) persistEdgeRow(e *graph.Edge) { + if w, ok := a.g.(graph.EdgePersister); ok { + w.PersistEdgeAttributes(e) + } +} + +func (a *applier) addASTEdge(from, to string, kind graph.EdgeKind, file string, line int) *graph.Edge { + e := &graph.Edge{ + From: from, + To: to, + Kind: kind, + FilePath: file, + Line: line, + Confidence: astConfidence, + ConfidenceLabel: graph.ConfidenceLabelFor(kind, astConfidence), + Origin: graph.OriginASTResolved, + Meta: map[string]any{ + "semantic_source": a.provider, + }, + } + a.g.AddEdge(e) + return e +} + +// claimable reports whether the engine may rewire this edge's target: +// still-unresolved / external stub targets always are; resolved +// targets only when their effective provenance ranks below AST-grade +// (a name-locality guess this engine's type evidence outranks). +func (a *applier) claimable(e *graph.Edge) bool { + if isStubTarget(e.To) { + return true + } + return graph.OriginRank(effectiveOrigin(e)) < graph.OriginRank(graph.OriginASTResolved) +} + +// effectiveOrigin returns the edge's provenance tier, backfilling the +// legacy default for edges minted before Origin stamping. +func effectiveOrigin(e *graph.Edge) string { + if e.Origin != "" { + return e.Origin + } + sem := "" + if e.Meta != nil { + sem, _ = e.Meta["semantic_source"].(string) + } + return graph.DefaultOriginFor(e.Kind, e.Confidence, sem) +} + +func isStubTarget(to string) bool { + if graph.IsUnresolvedTarget(to) { + return true + } + for _, p := range []string{"external::", "stdlib::", "dep::"} { + if strings.HasPrefix(to, p) || strings.Contains(to, "::"+p) { + return true + } + } + return false +} + +// trailingNameMatches reports whether a target id's final name segment +// equals name — across the unresolved / stub / resolved id shapes +// (`unresolved::*.m`, `unresolved::m`, `a/b.go::T.m`). +func trailingNameMatches(to, name string) bool { + if name == "" { + return false + } + s := to + if i := strings.LastIndex(s, "::"); i >= 0 { + s = s[i+2:] + } + if i := strings.LastIndex(s, "."); i >= 0 { + s = s[i+1:] + } + return s == name +} diff --git a/internal/semantic/tstypes/fixes_multirepo_test.go b/internal/semantic/tstypes/fixes_multirepo_test.go new file mode 100644 index 00000000..8e625247 --- /dev/null +++ b/internal/semantic/tstypes/fixes_multirepo_test.go @@ -0,0 +1,152 @@ +package tstypes + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" +) + +// indexRepoInto extracts every file under one repo root and adds the +// nodes/edges to g under the given repo prefix — the same prefixing the +// MultiIndexer applies (node ID / FilePath / edge endpoints gain a +// `prefix/` and RepoPrefix is stamped). Returns the on-disk root. +func indexRepoInto(t *testing.T, g *graph.Graph, prefix string, files map[string]string) string { + t.Helper() + dir := t.TempDir() + reg := parser.NewRegistry() + languages.RegisterAll(reg) + for rel, content := range files { + abs := filepath.Join(dir, filepath.FromSlash(rel)) + if err := os.MkdirAll(filepath.Dir(abs), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(abs, []byte(content), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + lang, ok := reg.DetectLanguage(rel) + if !ok { + t.Fatalf("no language for %s", rel) + } + ext, ok := reg.GetByLanguage(lang) + if !ok { + t.Fatalf("no extractor for %s", lang) + } + res, err := ext.Extract(rel, []byte(content)) + if err != nil { + t.Fatalf("extract %s: %v", rel, err) + } + if res.Tree != nil { + res.Tree.Close() + } + prefixNodesEdges(prefix, res.Nodes, res.Edges) + g.AddBatch(res.Nodes, res.Edges) + } + return dir +} + +// prefixNodesEdges mirrors the indexer's repo-prefixing for test graphs. +func prefixNodesEdges(prefix string, nodes []*graph.Node, edges []*graph.Edge) { + if prefix == "" { + return + } + p := prefix + "/" + for _, n := range nodes { + n.ID = p + n.ID + n.FilePath = p + n.FilePath + n.RepoPrefix = prefix + } + for _, e := range edges { + e.From = p + e.From + if !strings.HasPrefix(e.To, "unresolved::") { + e.To = p + e.To + } + e.FilePath = p + e.FilePath + } +} + +// In multi-repo mode two repos can share a relative path. languageFiles +// must scope file selection to the repo actually being enriched — never +// read repo A's bytes for repo B's node just because the relative path +// happens to exist under both roots. +func TestEnrich_MultiRepoPathCollisionDoesNotContaminate(t *testing.T) { + g := graph.New() + + // Both repos define pkg/Svc.java + pkg/App.java at the SAME relative + // paths, but with different method names. Repo A's App calls a.run(); + // repo B's App calls b.go(). If file selection leaked across repos, + // enriching one root would parse the other repo's bytes for the + // colliding path. + repoA := map[string]string{ + "pkg/Svc.java": `package pkg; +public class Svc { + public void run() {} +} +`, + "pkg/App.java": `package pkg; +public class App { + public void main() { + Svc s = new Svc(); + s.run(); + } +} +`, + } + repoB := map[string]string{ + "pkg/Svc.java": `package pkg; +public class Svc { + public void go() {} +} +`, + "pkg/App.java": `package pkg; +public class App { + public void main() { + Svc s = new Svc(); + s.go(); + } +} +`, + } + rootA := indexRepoInto(t, g, "repoA", repoA) + rootB := indexRepoInto(t, g, "repoB", repoB) + + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.EnrichRepo(g, "repoA", rootA); err != nil { + t.Fatal(err) + } + if _, err := p.EnrichRepo(g, "repoB", rootB); err != nil { + t.Fatal(err) + } + + // Repo A's main must call repoA Svc.run, and nothing in repo A may + // point at a repoB target. + mainA := "repoA/pkg/App.java::App.main" + if callEdgeTo(g, mainA, "repoA/pkg/Svc.java::Svc.run") == nil { + t.Fatalf("repo A call run() not resolved within repo A; edges: %v", g.GetOutEdges(mainA)) + } + for _, e := range g.GetOutEdges(mainA) { + if strings.HasPrefix(e.To, "repoB/") { + t.Fatalf("repo A edge leaked into repo B target: %s -> %s", e.From, e.To) + } + } + + // Repo B's main must call repoB Svc.go, and nothing in repo B may + // point at a repoA target. + mainB := "repoB/pkg/App.java::App.main" + if callEdgeTo(g, mainB, "repoB/pkg/Svc.java::Svc.go") == nil { + t.Fatalf("repo B call go() not resolved within repo B; edges: %v", g.GetOutEdges(mainB)) + } + for _, e := range g.GetOutEdges(mainB) { + if strings.HasPrefix(e.To, "repoA/") { + t.Fatalf("repo B edge leaked into repo A target: %s -> %s", e.From, e.To) + } + } + + _ = rootB +} diff --git a/internal/semantic/tstypes/fixes_sqlite_test.go b/internal/semantic/tstypes/fixes_sqlite_test.go new file mode 100644 index 00000000..9649c243 --- /dev/null +++ b/internal/semantic/tstypes/fixes_sqlite_test.go @@ -0,0 +1,137 @@ +package tstypes + +import ( + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/graph/store_sqlite" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" +) + +// buildSQLiteFixture writes the fixture to disk, extracts each file with +// the real per-language extractors, and loads the nodes/edges into a +// fresh on-disk SQLite store at dbPath. Returns the store and the on-disk +// source root. +func buildSQLiteFixture(t *testing.T, dbPath string, files map[string]string) (*store_sqlite.Store, string) { + t.Helper() + dir := t.TempDir() + s, err := store_sqlite.Open(dbPath) + if err != nil { + t.Fatalf("open sqlite store: %v", err) + } + reg := parser.NewRegistry() + languages.RegisterAll(reg) + for rel, content := range files { + abs := filepath.Join(dir, filepath.FromSlash(rel)) + if err := os.MkdirAll(filepath.Dir(abs), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(abs, []byte(content), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + lang, ok := reg.DetectLanguage(rel) + if !ok { + t.Fatalf("no language for %s", rel) + } + ext, ok := reg.GetByLanguage(lang) + if !ok { + t.Fatalf("no extractor for %s", lang) + } + res, err := ext.Extract(rel, []byte(content)) + if err != nil { + t.Fatalf("extract %s: %v", rel, err) + } + if res.Tree != nil { + res.Tree.Close() + } + s.AddBatch(res.Nodes, res.Edges) + } + return s, dir +} + +// On a disk backend GetOutEdges returns a detached row copy; confirming an +// edge mutates Confidence / ConfidenceLabel / Meta on that copy, and +// SetEdgeProvenance only writes origin+tier. Those extra attributes must +// still survive a reload — the engine round-trips the full edge through +// the backend's edge-attribute write path. +func TestEnrich_SQLiteConfirmationPersistsFullProvenance(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "store.sqlite") + + files := map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void handle(Svc s) { + s.run(); + } +} +`, + } + s, root := buildSQLiteFixture(t, dbPath, files) + + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.EnrichRepo(s, "", root); err != nil { + t.Fatalf("enrich: %v", err) + } + + callerID := "b/App.java::App.handle" + targetID := "a/Svc.java::Svc.run" + + // Sanity: the edge is confirmed in the live store. + if e := outEdgeTo(s, callerID, targetID); e == nil { + t.Fatalf("call edge not present after enrich; edges: %v", s.GetOutEdges(callerID)) + } + + if err := s.Close(); err != nil { + t.Fatalf("close: %v", err) + } + + // Reopen from disk: every confirmed attribute must have persisted, not + // just origin/tier. + s2, err := store_sqlite.Open(dbPath) + if err != nil { + t.Fatalf("reopen: %v", err) + } + defer func() { _ = s2.Close() }() + + e := outEdgeTo(s2, callerID, targetID) + if e == nil { + t.Fatalf("call edge lost across reopen; edges: %v", s2.GetOutEdges(callerID)) + } + if e.Origin != graph.OriginASTResolved { + t.Errorf("origin = %q after reload, want %q", e.Origin, graph.OriginASTResolved) + } + if e.Confidence < astConfidence { + t.Errorf("confidence = %v after reload, want >= %v (lost on disk write-back)", e.Confidence, astConfidence) + } + if e.ConfidenceLabel == "" { + t.Errorf("confidence_label empty after reload (lost on disk write-back)") + } + if e.Meta == nil || e.Meta["semantic_source"] != "java-types" { + t.Errorf("semantic_source = %v after reload, want java-types (lost on disk write-back)", metaVal(e)) + } +} + +func outEdgeTo(s *store_sqlite.Store, fromID, toID string) *graph.Edge { + for _, e := range s.GetOutEdges(fromID) { + if e.Kind == graph.EdgeCalls && e.To == toID { + return e + } + } + return nil +} + +func metaVal(e *graph.Edge) any { + if e == nil || e.Meta == nil { + return nil + } + return e.Meta["semantic_source"] +} diff --git a/internal/semantic/tstypes/fixes_test.go b/internal/semantic/tstypes/fixes_test.go new file mode 100644 index 00000000..f31176a5 --- /dev/null +++ b/internal/semantic/tstypes/fixes_test.go @@ -0,0 +1,169 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +// When a declared-supertype edge must change KIND (the C# base list emits +// `extends` to a stub that resolves to an interface), the engine must not +// mutate the existing edge's Kind in place: ReindexEdge reconstructs the +// old logical key from the already-flipped Kind, so the stub's inEdges +// bucket is never cleaned and leaks a stale reference. The fix drops the +// old edge and adds a fresh one of the correct kind. +func TestCSharp_SupertypeKindFlipLeavesNoStaleAdjacency(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "A/Greeter.cs": `namespace A { + public interface Greeter { void Greet(); } +} +`, + "B/Impl.cs": `namespace B { + public class Impl : Greeter { public void Greet() {} } +} +`, + }) + impl := nodeByNameKind(t, g, "Impl", graph.KindType) + greeter := nodeByNameKind(t, g, "Greeter", graph.KindInterface) + + p := NewProvider(CSharpSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + + // Exactly one supertype edge from Impl, and it is implements→Greeter. + var supers []*graph.Edge + for _, e := range g.GetOutEdges(impl.ID) { + if e.Kind == graph.EdgeExtends || e.Kind == graph.EdgeImplements { + supers = append(supers, e) + } + } + if len(supers) != 1 { + t.Fatalf("want exactly 1 supertype edge, got %d: %v", len(supers), supers) + } + if supers[0].Kind != graph.EdgeImplements || supers[0].To != greeter.ID { + t.Fatalf("supertype edge = %s -> %s, want implements -> %s", supers[0].Kind, supers[0].To, greeter.ID) + } + + // The original stub target's inEdges bucket must hold no leftover + // reference to the (now-retargeted) edge. The in-place kind flip left + // one behind because removeEdgeFromBucket was handed the post-flip key. + if stale := g.GetInEdges("unresolved::Greeter"); len(stale) != 0 { + t.Fatalf("stub unresolved::Greeter retains %d stale in-edge(s): %v", len(stale), stale) + } + + // The resolved interface must carry exactly the one implements in-edge. + implCount := 0 + for _, e := range g.GetInEdges(greeter.ID) { + if e.Kind == graph.EdgeImplements { + implCount++ + } + } + if implCount != 1 { + t.Fatalf("Greeter has %d implements in-edges, want 1", implCount) + } + + if err := g.VerifyEdgeIdentities(); err != nil { + t.Fatalf("graph edge identities inconsistent after kind flip: %v", err) + } +} + +// An import hint that refutes every repo-local candidate means the real +// target is an external / stdlib dependency the graph doesn't hold. The +// engine must NOT fall back to a lone same-named repo type — that mints a +// false edge shadowing the dependency. +func TestJava_ImportHintRefutingAllCandidatesResolvesNothing(t *testing.T) { + // The repo has exactly one type named Logger, in pkg `a`. App imports a + // DIFFERENT Logger (an external `org.slf4j.Logger`), so the hint points + // at org/slf4j, which the repo-local a/Logger.java does not satisfy. + g, dir := buildFixture(t, map[string]string{ + "a/Logger.java": `package a; + +public class Logger { + public void info() {} +} +`, + "b/App.java": `package b; + +import org.slf4j.Logger; + +public class App { + public void run() { + Logger log = makeLogger(); + log.info(); + } + + private Logger makeLogger() { return null; } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + run := nodeByNameKind(t, g, "run", graph.KindMethod) + // info() must NOT resolve to the repo-local a/Logger.java::Logger.info: + // the import hint refuted that candidate. + localInfo := "a/Logger.java::Logger.info" + if e := callEdgeTo(g, run.ID, localInfo); e != nil { + t.Fatalf("info() falsely resolved to repo-local %s despite import hint pointing at org.slf4j", localInfo) + } +} + +// confirmAST must never DOWNGRADE an edge whose effective provenance is +// already stronger than AST-grade, even when that strength lives only in +// Meta["semantic_source"] with Origin unset (a legacy compiler-confirmed +// edge). The old guard required Origin != "" and so clobbered those edges' +// tier and semantic_source. +func TestConfirmAST_DoesNotDowngradeLegacyCompilerConfirmedEdge(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void handle(Svc s) { + s.run(); + } +} +`, + }) + caller := nodeByNameKind(t, g, "handle", graph.KindMethod) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + + // Simulate a legacy compiler-grade edge: a resolved calls edge whose + // strength is recorded only in semantic_source, with Origin unset + // (edges minted before Origin stamping). Its effective origin is + // lsp_resolved (rank 6). + g.AddEdge(&graph.Edge{ + From: caller.ID, + To: target.ID, + Kind: graph.EdgeCalls, + FilePath: "b/App.java", + Line: 7, + Confidence: 1.0, + ConfidenceLabel: "EXTRACTED", + Origin: "", // legacy: tier lives in Meta only + Meta: map[string]any{"semantic_source": "java-lsp"}, + }) + + a := newApplier(g, JavaSpec(), "java-types") + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatal("seed edge missing") + } + if a.confirmAST(e) { + t.Fatalf("confirmAST reported a change on a stronger-than-AST edge") + } + if got := effectiveOrigin(e); graph.OriginRank(got) < graph.OriginRank(graph.OriginLSPResolved) { + t.Fatalf("effective origin downgraded to %q (rank %d), want >= lsp_resolved", got, graph.OriginRank(got)) + } + if e.Meta["semantic_source"] != "java-lsp" { + t.Fatalf("semantic_source clobbered to %v, want java-lsp", e.Meta["semantic_source"]) + } + + _ = dir +} diff --git a/internal/semantic/tstypes/harness_test.go b/internal/semantic/tstypes/harness_test.go new file mode 100644 index 00000000..77e1de24 --- /dev/null +++ b/internal/semantic/tstypes/harness_test.go @@ -0,0 +1,131 @@ +package tstypes + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/parser" + "github.com/zzet/gortex/internal/parser/languages" +) + +// buildFixture writes the fixture files under a temp dir, indexes them +// with the real per-language extractors (so the graph carries the +// exact node-ID and unresolved-edge conventions the daemon's index +// produces), and returns the graph plus the repo root. +func buildFixture(t *testing.T, files map[string]string) (*graph.Graph, string) { + t.Helper() + dir := t.TempDir() + g := graph.New() + reg := parser.NewRegistry() + languages.RegisterAll(reg) + + for rel, content := range files { + abs := filepath.Join(dir, filepath.FromSlash(rel)) + if err := os.MkdirAll(filepath.Dir(abs), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(abs, []byte(content), 0o644); err != nil { + t.Fatalf("write fixture: %v", err) + } + lang, ok := reg.DetectLanguage(rel) + if !ok { + t.Fatalf("no language for %s", rel) + } + ext, ok := reg.GetByLanguage(lang) + if !ok { + t.Fatalf("no extractor for %s", lang) + } + res, err := ext.Extract(rel, []byte(content)) + if err != nil { + t.Fatalf("extract %s: %v", rel, err) + } + if res.Tree != nil { + res.Tree.Close() + } + g.AddBatch(res.Nodes, res.Edges) + } + return g, dir +} + +// nodeByNameKind returns the unique node with the given name and kind. +func nodeByNameKind(t *testing.T, g *graph.Graph, name string, kind graph.NodeKind) *graph.Node { + t.Helper() + var found *graph.Node + for _, n := range g.FindNodesByName(name) { + if n.Kind != kind { + continue + } + if found != nil { + t.Fatalf("multiple %s nodes named %q", kind, name) + } + found = n + } + if found == nil { + t.Fatalf("no %s node named %q", kind, name) + } + return found +} + +// callEdgeTo returns the first calls-edge from the caller whose target +// id is exactly to. +func callEdgeTo(g *graph.Graph, fromID, to string) *graph.Edge { + for _, e := range g.GetOutEdges(fromID) { + if e.Kind == graph.EdgeCalls && e.To == to { + return e + } + } + return nil +} + +// callEdgesNamed returns every calls-edge from the caller whose target +// trailing name matches. +func callEdgesNamed(g *graph.Graph, fromID, name string) []*graph.Edge { + var out []*graph.Edge + for _, e := range g.GetOutEdges(fromID) { + if e.Kind == graph.EdgeCalls && trailingNameMatches(e.To, name) { + out = append(out, e) + } + } + return out +} + +// edgeBetween returns the edge of the given kind between two node ids. +func edgeBetween(g *graph.Graph, fromID string, kind graph.EdgeKind, toID string) *graph.Edge { + for _, e := range g.GetOutEdges(fromID) { + if e.Kind == kind && e.To == toID { + return e + } + } + return nil +} + +// assertASTProvenance checks the edge carries this engine's stamp. +func assertASTProvenance(t *testing.T, e *graph.Edge, provider string) { + t.Helper() + if e.Origin != graph.OriginASTResolved { + t.Errorf("origin = %q, want %q", e.Origin, graph.OriginASTResolved) + } + if e.Meta == nil || e.Meta["semantic_source"] != provider { + t.Errorf("semantic_source = %v, want %q", e.Meta["semantic_source"], provider) + } + if e.Confidence < astConfidence { + t.Errorf("confidence = %v, want >= %v", e.Confidence, astConfidence) + } +} + +// assertUntouched checks no engine stamp landed on any calls-edge of +// the caller matching the method name — the negative-case contract. +func assertUntouched(t *testing.T, g *graph.Graph, fromID, method, provider string) { + t.Helper() + for _, e := range callEdgesNamed(g, fromID, method) { + if e.Meta != nil && e.Meta["semantic_source"] == provider { + t.Errorf("edge %s -> %s was touched by %s; want untouched", e.From, e.To, provider) + } + if !graph.IsUnresolvedTarget(e.To) && !strings.Contains(e.To, "::") { + t.Errorf("edge target %q unexpectedly resolved", e.To) + } + } +} diff --git a/internal/semantic/tstypes/java.go b/internal/semantic/tstypes/java.go new file mode 100644 index 00000000..9c337979 --- /dev/null +++ b/internal/semantic/tstypes/java.go @@ -0,0 +1,242 @@ +package tstypes + +import ( + "strings" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" + "github.com/zzet/gortex/internal/parser/tsitter/java" +) + +// JavaSpec adapts the engine to tree-sitter-java. Types are explicit +// everywhere, so the binder grounds receivers from parameter / local / +// field annotations and `new` expressions; `implements` / `extends` +// clauses come straight off the declaration. +func JavaSpec() *LangSpec { + grammar := java.GetLanguage() + return &LangSpec{ + ProviderName: "java-types", + Languages: []string{"java"}, + GrammarFor: func(string) *sitter.Language { return grammar }, + TypeDeclTypes: map[string]bool{ + "class_declaration": true, + "interface_declaration": true, + "enum_declaration": true, + }, + FuncDeclTypes: map[string]bool{ + "method_declaration": true, + "constructor_declaration": true, + }, + SelfName: "this", + TypeDeclName: nameField, + Supertypes: javaSupertypes, + Fields: javaFields, + Params: javaParams, + ReturnType: func(fn *sitter.Node, src []byte) string { + if fn.Type() != "method_declaration" { + return "" + } + return fieldText(fn, "type", src) + }, + LocalBinding: javaLocalBinding, + Call: javaCall, + NewExprType: func(n *sitter.Node, src []byte) string { + if n.Type() != "object_creation_expression" { + return "" + } + return fieldText(n, "type", src) + }, + FieldRef: func(n *sitter.Node, src []byte) (string, bool) { + if n.Type() != "field_access" { + return "", false + } + obj := n.ChildByFieldName("object") + if obj == nil || obj.Type() != "this" { + return "", false + } + return fieldText(n, "field", src), true + }, + Imports: javaImports, + } +} + +func javaSupertypes(n *sitter.Node, src []byte) []SuperRef { + var out []SuperRef + switch n.Type() { + case "class_declaration": + if sup := n.ChildByFieldName("superclass"); sup != nil { + for i := 0; i < int(sup.NamedChildCount()); i++ { + c := sup.NamedChild(i) + switch c.Type() { + case "type_identifier", "generic_type", "scoped_type_identifier": + out = append(out, SuperRef{Name: c.Content(src), Kind: graph.EdgeExtends, Line: nodeLine(c)}) + } + } + } + if ifaces := n.ChildByFieldName("interfaces"); ifaces != nil { + out = append(out, javaTypeList(ifaces, src, graph.EdgeImplements)...) + } + case "interface_declaration": + // `interface A extends B, C` — extends_interfaces is an unnamed + // field in the grammar; scan direct children. + for i := 0; i < int(n.ChildCount()); i++ { + if c := n.Child(i); c != nil && c.Type() == "extends_interfaces" { + out = append(out, javaTypeList(c, src, graph.EdgeExtends)...) + } + } + case "enum_declaration": + if ifaces := n.ChildByFieldName("interfaces"); ifaces != nil { + out = append(out, javaTypeList(ifaces, src, graph.EdgeImplements)...) + } + } + return out +} + +// javaTypeList flattens a super_interfaces / extends_interfaces node's +// type_list into SuperRefs. +func javaTypeList(n *sitter.Node, src []byte, kind graph.EdgeKind) []SuperRef { + var out []SuperRef + var visit func(c *sitter.Node) + visit = func(c *sitter.Node) { + if c == nil { + return + } + switch c.Type() { + case "type_list": + for i := 0; i < int(c.NamedChildCount()); i++ { + visit(c.NamedChild(i)) + } + case "type_identifier", "generic_type", "scoped_type_identifier": + out = append(out, SuperRef{Name: c.Content(src), Kind: kind, Line: nodeLine(c)}) + } + } + for i := 0; i < int(n.NamedChildCount()); i++ { + visit(n.NamedChild(i)) + } + return out +} + +func javaFields(n *sitter.Node, src []byte) []Binding { + body := n.ChildByFieldName("body") + if body == nil { + return nil + } + var out []Binding + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + if c.Type() != "field_declaration" { + continue + } + typ := fieldText(c, "type", src) + for j := 0; j < int(c.NamedChildCount()); j++ { + d := c.NamedChild(j) + if d.Type() != "variable_declarator" { + continue + } + name := fieldText(d, "name", src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: typ, Line: nodeLine(d)}) + } + } + return out +} + +func javaParams(fn *sitter.Node, src []byte) []Binding { + params := fn.ChildByFieldName("parameters") + if params == nil { + return nil + } + var out []Binding + for i := 0; i < int(params.NamedChildCount()); i++ { + p := params.NamedChild(i) + switch p.Type() { + case "formal_parameter", "spread_parameter": + name := fieldText(p, "name", src) + if name == "" { + // spread_parameter puts the variable_declarator last. + for j := int(p.NamedChildCount()) - 1; j >= 0; j-- { + if c := p.NamedChild(j); c.Type() == "variable_declarator" { + name = fieldText(c, "name", src) + break + } else if c.Type() == "identifier" { + name = c.Content(src) + break + } + } + } + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: fieldText(p, "type", src), Line: nodeLine(p)}) + } + } + return out +} + +func javaLocalBinding(n *sitter.Node, src []byte) (LocalBind, bool) { + switch n.Type() { + case "local_variable_declaration": + decl := firstChildOfType(n, "variable_declarator") + if decl == nil { + return LocalBind{}, false + } + return LocalBind{ + Name: fieldText(decl, "name", src), + DeclType: fieldText(n, "type", src), + Init: decl.ChildByFieldName("value"), + }, true + case "assignment_expression": + left := n.ChildByFieldName("left") + if left == nil || left.Type() != "identifier" { + return LocalBind{}, false + } + return LocalBind{Name: left.Content(src), Init: n.ChildByFieldName("right")}, true + } + return LocalBind{}, false +} + +func javaCall(n *sitter.Node, src []byte) (*sitter.Node, string, bool) { + if n.Type() != "method_invocation" { + return nil, "", false + } + obj := n.ChildByFieldName("object") + if obj == nil { + return nil, "", false + } + return obj, fieldText(n, "name", src), true +} + +func javaImports(root *sitter.Node, src []byte) []Import { + var out []Import + for i := 0; i < int(root.NamedChildCount()); i++ { + c := root.NamedChild(i) + if c.Type() != "import_declaration" { + continue + } + path := "" + isWildcard := false + for j := 0; j < int(c.ChildCount()); j++ { + ch := c.Child(j) + if ch == nil { + continue + } + switch ch.Type() { + case "scoped_identifier", "identifier": + path = ch.Content(src) + case "asterisk": + isWildcard = true + } + } + if path == "" || isWildcard { + continue + } + local := path + if idx := strings.LastIndex(local, "."); idx >= 0 { + local = local[idx+1:] + } + out = append(out, Import{Local: local, Path: strings.ReplaceAll(path, ".", "/")}) + } + return out +} diff --git a/internal/semantic/tstypes/java_test.go b/internal/semantic/tstypes/java_test.go new file mode 100644 index 00000000..6b8125f7 --- /dev/null +++ b/internal/semantic/tstypes/java_test.go @@ -0,0 +1,303 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +const javaSvc = `package a; + +public class Svc { + public void run() { + } + + public void stop() { + } +} +` + +const javaIface = `package a; + +public interface Greeter { + void greet(); +} +` + +func TestJava_DeclaredParamTypeResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void handle(Svc s) { + s.run(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + res, err := p.Enrich(g, dir) + if err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "handle", graph.KindMethod) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatalf("call edge %s -> %s not resolved; edges: %v", caller.ID, target.ID, g.GetOutEdges(caller.ID)) + } + assertASTProvenance(t, e, "java-types") + if res.EdgesConfirmed+res.EdgesAdded == 0 { + t.Errorf("result reported no edge work: %+v", res) + } +} + +func TestJava_ConstructorInferenceResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void main() { + Svc s = new Svc(); + s.run(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, caller.ID, target.ID) == nil { + t.Fatalf("constructor-inferred call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } +} + +// Cross-file resolution must follow the import hint when several types +// share a name. +func TestJava_ImportHintDisambiguatesCrossFile(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "other/Svc.java": `package other; + +public class Svc { + public void run() { + } +} +`, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void main() { + Svc s = new Svc(); + s.run(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + want := "a/Svc.java::Svc.run" + if callEdgeTo(g, caller.ID, want) == nil { + t.Fatalf("import-hinted call did not land on %s; edges: %v", want, g.GetOutEdges(caller.ID)) + } + wrong := "other/Svc.java::Svc.run" + if callEdgeTo(g, caller.ID, wrong) != nil { + t.Fatalf("call landed on the wrong package's type %s", wrong) + } +} + +func TestJava_ImplementsAndExtendsSynthesis(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "a/Greeter.java": javaIface, + "b/Impl.java": `package b; + +import a.Greeter; +import a.Svc; + +public class Impl extends Svc implements Greeter { + public void greet() { + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + impl := nodeByNameKind(t, g, "Impl", graph.KindType) + iface := nodeByNameKind(t, g, "Greeter", graph.KindInterface) + svc := nodeByNameKind(t, g, "Svc", graph.KindType) + + ie := edgeBetween(g, impl.ID, graph.EdgeImplements, iface.ID) + if ie == nil { + t.Fatalf("implements edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } + assertASTProvenance(t, ie, "java-types") + + ee := edgeBetween(g, impl.ID, graph.EdgeExtends, svc.ID) + if ee == nil { + t.Fatalf("extends edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } + assertASTProvenance(t, ee, "java-types") +} + +// Inherited methods resolve through the synthesized extends chain. +func TestJava_InheritedMethodResolvesThroughExtends(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "b/Sub.java": `package b; + +import a.Svc; + +public class Sub extends Svc { +} +`, + "c/App.java": `package c; + +import b.Sub; + +public class App { + public void main() { + Sub s = new Sub(); + s.run(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + want := "a/Svc.java::Svc.run" + if callEdgeTo(g, caller.ID, want) == nil { + t.Fatalf("inherited method call did not resolve to %s; edges: %v", want, g.GetOutEdges(caller.ID)) + } +} + +// this-qualified and field-typed receivers resolve inside the class. +func TestJava_SelfAndFieldReceivers(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + private Svc worker; + + public void direct() { + this.helper(); + } + + public void helper() { + this.worker.run(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + direct := nodeByNameKind(t, g, "direct", graph.KindMethod) + helper := nodeByNameKind(t, g, "helper", graph.KindMethod) + if callEdgeTo(g, direct.ID, helper.ID) == nil { + t.Fatalf("this.helper() not resolved; edges: %v", g.GetOutEdges(direct.ID)) + } + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, helper.ID, run.ID) == nil { + t.Fatalf("this.worker.run() not resolved through field type; edges: %v", g.GetOutEdges(helper.ID)) + } +} + +// A receiver rebound to a different type degrades to unknown — the +// engine must leave its calls untouched rather than guess. +func TestJava_AmbiguousReceiverStaysUntouched(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "a/Alt.java": `package a; + +public class Alt { + public void run() { + } +} +`, + "b/App.java": `package b; + +import a.Alt; +import a.Svc; + +public class App { + public void main() { + Object s; + s = new Svc(); + s = new Alt(); + s.run(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + assertUntouched(t, g, caller.ID, "run", "java-types") +} + +func TestJava_EnrichFileScopesToOneFile(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void main() { + Svc s = new Svc(); + s.run(); + } +} +`, + "c/Other.java": `package c; + +import a.Svc; + +public class Other { + public void go() { + Svc s = new Svc(); + s.stop(); + } +} +`, + }) + p := NewProvider(JavaSpec(), zap.NewNop()) + if _, err := p.EnrichFile(g, dir, "b/App.java"); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, caller.ID, run.ID) == nil { + t.Fatalf("EnrichFile did not resolve the target file's call") + } + other := nodeByNameKind(t, g, "go", graph.KindMethod) + assertUntouched(t, g, other.ID, "stop", "java-types") +} diff --git a/internal/semantic/tstypes/manager_test.go b/internal/semantic/tstypes/manager_test.go new file mode 100644 index 00000000..5a669166 --- /dev/null +++ b/internal/semantic/tstypes/manager_test.go @@ -0,0 +1,196 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/semantic" +) + +// mixedFixture is one small caller/callee pair per supported language. +func mixedFixture() map[string]string { + return map[string]string{ + "a/Svc.java": javaSvc, + "b/App.java": `package b; + +import a.Svc; + +public class App { + public void main() { + Svc s = new Svc(); + s.run(); + } +} +`, + "app/svc.py": pySvc, + "app/main.py": `from app.svc import Svc + + +def main(): + s = Svc() + s.run() +`, + "src/svc.ts": tsSvc, + "src/app.ts": `import { Svc } from "./svc"; + +export function main(): void { + const s = new Svc(); + s.run(); +} +`, + "lib/svc.rb": rubySvc, + "lib/app.rb": `class App + def main + s = Svc.new + s.run + end +end +`, + "rs/engine.rs": rustSvc, + "rs/app.rs": `use crate::engine::Svc; + +pub fn main() { + let s = Svc::new(); + s.run(); +} +`, + "A/Svc.cs": csSvc, + "B/App.cs": `namespace B { + public class App { + public void Main() { + var s = new Svc(); + s.Run(); + } + } +} +`, + } +} + +// All six in-process providers register on a plain manager and resolve +// the mixed fixture without any LSP router or external binary. +func TestManager_SupplementalProvidersEnrichWithoutLSP(t *testing.T) { + g, dir := buildFixture(t, mixedFixture()) + mgr := semantic.NewManager(semantic.Config{Enabled: true}, zap.NewNop()) + defer func() { _ = mgr.Close() }() + for _, p := range DefaultProviders(zap.NewNop()) { + mgr.RegisterProvider(p) + } + if !mgr.HasProviders() { + t.Fatal("manager reports no available providers") + } + + results, err := mgr.EnrichAll(g, map[string]string{"": dir}) + if err != nil { + t.Fatal(err) + } + got := make(map[string]*semantic.EnrichResult, len(results)) + for _, r := range results { + got[r.Provider] = r + } + for _, want := range []string{"java-types", "python-types", "ruby-types", "rust-types", "typescript-types", "csharp-types"} { + r, ok := got[want] + if !ok { + t.Errorf("no enrich result from %s", want) + continue + } + if r.EdgesConfirmed+r.EdgesAdded == 0 { + t.Errorf("%s did no edge work: %+v", want, r) + } + } + + // Spot-check one resolution per language family actually landed. + checks := []struct { + caller, callerKind, target string + }{ + {"main", "java", "a/Svc.java::Svc.run"}, + {"main", "python", "app/svc.py::Svc.run"}, + {"main", "typescript", "src/svc.ts::Svc.run"}, + {"main", "ruby", "lib/svc.rb::Svc.run"}, + {"main", "rust", "rs/engine.rs::Svc.run"}, + {"Main", "csharp", "A/Svc.cs::Svc.Run"}, + } + for _, c := range checks { + var caller *graph.Node + for _, n := range g.FindNodesByName(c.caller) { + if n.Language == c.callerKind && (n.Kind == graph.KindFunction || n.Kind == graph.KindMethod) { + caller = n + break + } + } + if caller == nil { + t.Errorf("no %s caller named %s", c.callerKind, c.caller) + continue + } + if callEdgeTo(g, caller.ID, c.target) == nil { + t.Errorf("%s: call %s -> %s not resolved; edges: %v", c.callerKind, caller.ID, c.target, g.GetOutEdges(caller.ID)) + } + } + + // Stats must report every in-process provider ready. + ready := make(map[string]bool) + for _, st := range mgr.Stats() { + if st.Status == "ready" { + ready[st.Name] = true + } + } + for _, want := range []string{"java-types", "python-types", "ruby-types", "rust-types", "typescript-types", "csharp-types"} { + if !ready[want] { + t.Errorf("provider %s not reported ready in Stats", want) + } + } +} + +// An `enabled: false` config entry switches one provider off while the +// others keep running. +func TestManager_ConfigDisablesOneProvider(t *testing.T) { + g, dir := buildFixture(t, mixedFixture()) + cfg := semantic.Config{ + Enabled: true, + Providers: []semantic.ProviderConfig{ + {Name: "java-types", Languages: []string{"java"}, Enabled: false}, + }, + } + mgr := semantic.NewManager(cfg, zap.NewNop()) + defer func() { _ = mgr.Close() }() + for _, p := range DefaultProviders(zap.NewNop()) { + mgr.RegisterProvider(p) + } + results, err := mgr.EnrichAll(g, map[string]string{"": dir}) + if err != nil { + t.Fatal(err) + } + for _, r := range results { + if r.Provider == "java-types" { + t.Fatalf("disabled provider still ran: %+v", r) + } + } + assertUntouched(t, g, "b/App.java::App.main", "run", "java-types") + if callEdgeTo(g, "app/main.py::main", "app/svc.py::Svc.run") == nil { + t.Errorf("python provider should keep running when java is disabled") + } +} + +// The manager's incremental path runs supplemental providers for the +// file's language even when no arbitration winner exists. +func TestManager_EnrichFileRunsSupplemental(t *testing.T) { + g, dir := buildFixture(t, mixedFixture()) + mgr := semantic.NewManager(semantic.Config{Enabled: true, EnrichOnWatch: true}, zap.NewNop()) + defer func() { _ = mgr.Close() }() + for _, p := range DefaultProviders(zap.NewNop()) { + mgr.RegisterProvider(p) + } + res, err := mgr.EnrichFile(g, dir, "b/App.java") + if err != nil { + t.Fatal(err) + } + if res == nil { + t.Fatal("EnrichFile returned no result") + } + caller := "b/App.java::App.main" + if callEdgeTo(g, caller, "a/Svc.java::Svc.run") == nil { + t.Fatalf("incremental enrichment did not resolve the call; edges: %v", g.GetOutEdges(caller)) + } +} diff --git a/internal/semantic/tstypes/provider.go b/internal/semantic/tstypes/provider.go new file mode 100644 index 00000000..c556b020 --- /dev/null +++ b/internal/semantic/tstypes/provider.go @@ -0,0 +1,211 @@ +package tstypes + +import ( + "runtime" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/semantic" +) + +// Provider is the semantic.Provider over one LangSpec. Pure in-process +// — Available is unconditionally true, no subprocess is ever spawned, +// Close is a no-op. It is supplemental: it augments whichever provider +// wins the per-language arbitration (LSP / SCIP) instead of competing +// with it, and only ever stamps AST-grade provenance, so a +// compiler-grade pass running before or after never gets downgraded. +type Provider struct { + spec *LangSpec + logger *zap.Logger +} + +// NewProvider wraps a LangSpec as a semantic provider. +func NewProvider(spec *LangSpec, logger *zap.Logger) *Provider { + if logger == nil { + logger = zap.NewNop() + } + return &Provider{spec: spec, logger: logger} +} + +// DefaultProviders returns the in-process type resolvers for every +// supported language. Registered unconditionally at daemon boot — +// disable one via a `semantic.providers` config entry with +// `enabled: false` under its name. +func DefaultProviders(logger *zap.Logger) []*Provider { + return []*Provider{ + NewProvider(JavaSpec(), logger), + NewProvider(PythonSpec(), logger), + NewProvider(RubySpec(), logger), + NewProvider(RustSpec(), logger), + NewProvider(TypeScriptSpec(), logger), + NewProvider(CSharpSpec(), logger), + } +} + +func (p *Provider) Name() string { return p.spec.ProviderName } +func (p *Provider) Languages() []string { return p.spec.Languages } +func (p *Provider) Available() bool { return true } +func (p *Provider) Close() error { return nil } + +// Supplemental marks this provider as augmenting (see +// semantic.SupplementalProvider): the manager runs it for its +// languages in addition to the arbitration winner. +func (p *Provider) Supplemental() bool { return true } + +// Enrich runs the full-repo pass for a single-repo (un-prefixed) graph. +// It delegates to EnrichRepo with an empty prefix — the in-memory single +// repo case where every real node carries RepoPrefix "". +func (p *Provider) Enrich(g graph.Store, repoRoot string) (*semantic.EnrichResult, error) { + return p.EnrichRepo(g, "", repoRoot) +} + +// EnrichRepo runs the full-repo pass: parse every file of the provider's +// languages that belong to repoPrefix under repoRoot in a bounded worker +// pool, then apply the per-file facts to the graph from a single +// goroutine. repoPrefix scopes file selection so a multi-repo graph with +// a colliding relative path never reads the wrong repo's bytes. +func (p *Provider) EnrichRepo(g graph.Store, repoPrefix, repoRoot string) (*semantic.EnrichResult, error) { + start := time.Now() + res := &semantic.EnrichResult{ + Provider: p.Name(), + Language: p.spec.Languages[0], + } + + files := languageFiles(g, p.spec, repoPrefix, repoRoot) + if len(files) > 0 { + workers := runtime.GOMAXPROCS(0) + if workers > 8 { + workers = 8 + } + if workers > len(files) { + workers = len(files) + } + jobs := make(chan fileRef) + factsCh := make(chan *fileFacts, workers) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for ref := range jobs { + facts, err := analyzeFile(p.spec, ref) + if err != nil { + p.logger.Debug("tstypes: file analysis failed", + zap.String("provider", p.Name()), + zap.String("file", ref.node.FilePath), + zap.Error(err)) + continue + } + if facts != nil { + factsCh <- facts + } + } + }() + } + go func() { + for _, ref := range files { + jobs <- ref + } + close(jobs) + wg.Wait() + close(factsCh) + }() + + var all []*fileFacts + for facts := range factsCh { + all = append(all, facts) + } + // Parsing above is pure and fans out across workers; the apply + // phase mutates the shared graph (retargets edges, reindexes, + // stamps provenance) and MUST run under the graph-wide resolve + // mutex so it serialises against concurrent resolver / cross-repo + // passes — the same lock every other edge-mutating pass holds. + mu := g.ResolveMutex() + mu.Lock() + ap := newApplier(g, p.spec, p.Name()) + ap.applyAll(all, res) + ap.flush() + analyzed := make(map[string]bool, len(all)) + for _, facts := range all { + analyzed[facts.file] = true + } + p.countCoverage(g, analyzed, res) + mu.Unlock() + } + + res.DurationMs = time.Since(start).Milliseconds() + return res, nil +} + +// EnrichFile runs the single-file incremental pass. filePath is the graph +// file key (prefixed in multi-repo mode), which is globally unique — the +// file's own node names the repo, so this is inherently scoped to the +// right repo without a separate prefix argument. +func (p *Provider) EnrichFile(g graph.Store, repoRoot, filePath string) (*semantic.EnrichResult, error) { + start := time.Now() + res := &semantic.EnrichResult{ + Provider: p.Name(), + Language: p.spec.Languages[0], + } + // Find the file's own node by its exact graph key. It carries the + // RepoPrefix that maps the prefixed path back to the on-disk file. + var fileNode *graph.Node + for _, n := range g.GetFileNodes(filePath) { + if n.Kind == graph.KindFile { + fileNode = n + break + } + } + if fileNode == nil || !p.spec.handles(fileNode.Language) { + res.DurationMs = time.Since(start).Milliseconds() + return res, nil + } + ref, ok := fileRefFor(fileNode, repoRoot) + if !ok { + res.DurationMs = time.Since(start).Milliseconds() + return res, nil + } + facts, err := analyzeFile(p.spec, ref) + if err != nil { + return nil, err + } + if facts != nil { + // Same contract as the full pass: the apply phase mutates the + // shared graph and runs under the resolve mutex so it does not + // race a concurrent watcher / resolver pass on another file. + mu := g.ResolveMutex() + mu.Lock() + ap := newApplier(g, p.spec, p.Name()) + ap.applyAll([]*fileFacts{facts}, res) + ap.flush() + p.countCoverage(g, map[string]bool{facts.file: true}, res) + mu.Unlock() + } + res.DurationMs = time.Since(start).Milliseconds() + return res, nil +} + +// countCoverage fills the symbols-covered counters: total is every +// symbol of the provider's languages, covered is the subset living in +// files the pass analyzed. +func (p *Provider) countCoverage(g graph.Store, analyzed map[string]bool, res *semantic.EnrichResult) { + langs := make(map[string]bool, len(p.spec.Languages)) + for _, l := range p.spec.Languages { + langs[l] = true + } + for _, n := range g.AllNodes() { + if !langs[n.Language] || n.Kind == graph.KindFile || n.Kind == graph.KindImport { + continue + } + res.SymbolsTotal++ + if analyzed[n.FilePath] { + res.SymbolsCovered++ + } + } + if res.SymbolsTotal > 0 { + res.CoveragePercent = float64(res.SymbolsCovered) / float64(res.SymbolsTotal) * 100 + } +} diff --git a/internal/semantic/tstypes/python.go b/internal/semantic/tstypes/python.go new file mode 100644 index 00000000..37c0ad75 --- /dev/null +++ b/internal/semantic/tstypes/python.go @@ -0,0 +1,283 @@ +package tstypes + +import ( + "strings" + "unicode" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" + "github.com/zzet/gortex/internal/parser/tsitter/python" +) + +// PythonSpec adapts the engine to tree-sitter-python. Typing evidence +// comes from PEP-484 annotations (params, locals, returns) and from +// CapWords constructor calls (`x = Foo()`); the latter are +// convention-based, so the engine's apply phase only acts when the +// name resolves to a real graph class node. `self.x` attributes bind +// in the class scope; explicit base classes synthesize extends edges. +func PythonSpec() *LangSpec { + grammar := python.GetLanguage() + return &LangSpec{ + ProviderName: "python-types", + Languages: []string{"python"}, + GrammarFor: func(string) *sitter.Language { return grammar }, + TypeDeclTypes: map[string]bool{ + "class_definition": true, + }, + FuncDeclTypes: map[string]bool{ + "function_definition": true, + }, + SelfName: "self", + TypeDeclName: nameField, + Supertypes: pySupertypes, + Fields: pyFields, + Params: pyParams, + ReturnType: func(fn *sitter.Node, src []byte) string { + return fieldText(fn, "return_type", src) + }, + LocalBinding: pyLocalBinding, + Call: pyCall, + NewExprType: pyNewExprType, + FieldRef: func(n *sitter.Node, src []byte) (string, bool) { + if n.Type() != "attribute" { + return "", false + } + obj := n.ChildByFieldName("object") + if obj == nil || obj.Type() != "identifier" || obj.Content(src) != "self" { + return "", false + } + return fieldText(n, "attribute", src), true + }, + Imports: pyImports, + } +} + +func pySupertypes(n *sitter.Node, src []byte) []SuperRef { + supers := n.ChildByFieldName("superclasses") + if supers == nil { + return nil + } + var out []SuperRef + for i := 0; i < int(supers.NamedChildCount()); i++ { + c := supers.NamedChild(i) + switch c.Type() { + case "identifier", "attribute": + name := c.Content(src) + // Drop the typing protocol scaffolding bases — they carry + // no resolvable repo-local definition. + bare := name + if i := strings.LastIndex(bare, "."); i >= 0 { + bare = bare[i+1:] + } + if bare == "object" { + continue + } + out = append(out, SuperRef{Name: name, Kind: graph.EdgeExtends, Line: nodeLine(c)}) + } + } + return out +} + +// pyFields collects class-level annotated assignments and `self.x` +// initialisations inside method bodies. +func pyFields(n *sitter.Node, src []byte) []Binding { + body := n.ChildByFieldName("body") + if body == nil { + return nil + } + var out []Binding + var visit func(node *sitter.Node, selfOnly bool) + visit = func(node *sitter.Node, selfOnly bool) { + if node == nil { + return + } + if node.Type() == "class_definition" { + return // nested classes own their fields + } + if node.Type() == "assignment" { + left := node.ChildByFieldName("left") + typ := fieldText(node, "type", src) + if typ == "" { + if right := node.ChildByFieldName("right"); right != nil { + typ = pyNewExprType(right, src) + } + } + if left != nil && typ != "" { + switch left.Type() { + case "identifier": + // Class-level statement: a class attribute. Inside a + // method body the same shape is a local — skip there. + if !selfOnly { + out = append(out, Binding{Name: left.Content(src), Type: typ, Line: nodeLine(left)}) + } + case "attribute": + obj := left.ChildByFieldName("object") + if obj != nil && obj.Type() == "identifier" && obj.Content(src) == "self" { + out = append(out, Binding{Name: fieldText(left, "attribute", src), Type: typ, Line: nodeLine(left)}) + } + } + } + } + for i := 0; i < int(node.NamedChildCount()); i++ { + visit(node.NamedChild(i), selfOnly) + } + } + // Class-level: only direct statements; self.x: walk method bodies. + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + switch c.Type() { + case "expression_statement": + visit(c, false) + case "function_definition": + if fb := c.ChildByFieldName("body"); fb != nil { + visit(fb, true) + } + } + } + return out +} + +func pyParams(fn *sitter.Node, src []byte) []Binding { + params := fn.ChildByFieldName("parameters") + if params == nil { + return nil + } + var out []Binding + for i := 0; i < int(params.NamedChildCount()); i++ { + p := params.NamedChild(i) + switch p.Type() { + case "identifier": + out = append(out, Binding{Name: p.Content(src), Line: nodeLine(p)}) + case "typed_parameter": + var name string + if id := firstChildOfType(p, "identifier"); id != nil { + name = id.Content(src) + } + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: fieldText(p, "type", src), Line: nodeLine(p)}) + case "default_parameter": + name := fieldText(p, "name", src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Line: nodeLine(p)}) + case "typed_default_parameter": + name := fieldText(p, "name", src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: fieldText(p, "type", src), Line: nodeLine(p)}) + } + } + return out +} + +func pyLocalBinding(n *sitter.Node, src []byte) (LocalBind, bool) { + if n.Type() != "assignment" { + return LocalBind{}, false + } + left := n.ChildByFieldName("left") + if left == nil { + return LocalBind{}, false + } + declType := fieldText(n, "type", src) + init := n.ChildByFieldName("right") + switch left.Type() { + case "identifier": + return LocalBind{Name: left.Content(src), DeclType: declType, Init: init}, true + case "attribute": + obj := left.ChildByFieldName("object") + if obj != nil && obj.Type() == "identifier" && obj.Content(src) == "self" { + return LocalBind{Name: fieldText(left, "attribute", src), DeclType: declType, Init: init, Field: true}, true + } + } + return LocalBind{}, false +} + +func pyCall(n *sitter.Node, src []byte) (*sitter.Node, string, bool) { + if n.Type() != "call" { + return nil, "", false + } + fn := n.ChildByFieldName("function") + if fn == nil || fn.Type() != "attribute" { + return nil, "", false + } + obj := fn.ChildByFieldName("object") + if obj == nil { + return nil, "", false + } + return obj, fieldText(fn, "attribute", src), true +} + +// pyNewExprType treats `Foo(...)` as a constructor candidate when the +// callee follows the CapWords class convention. The apply phase still +// verifies the name against a real graph type node before resolving +// anything through it, so a capitalized factory function never grounds +// a false receiver. +func pyNewExprType(n *sitter.Node, src []byte) string { + if n.Type() != "call" { + return "" + } + fn := n.ChildByFieldName("function") + if fn == nil || fn.Type() != "identifier" { + return "" + } + name := fn.Content(src) + if name == "" { + return "" + } + r := []rune(name) + if !unicode.IsUpper(r[0]) { + return "" + } + return name +} + +func pyImports(root *sitter.Node, src []byte) []Import { + var out []Import + var visit func(n *sitter.Node) + visit = func(n *sitter.Node) { + if n == nil { + return + } + switch n.Type() { + case "import_from_statement": + module := fieldText(n, "module_name", src) + if module == "" { + return + } + path := strings.ReplaceAll(module, ".", "/") + for i := 0; i < int(n.NamedChildCount()); i++ { + c := n.NamedChild(i) + switch c.Type() { + case "dotted_name": + if c.Content(src) == module { + continue // the module_name child itself + } + out = append(out, Import{Local: c.Content(src), Path: path}) + case "aliased_import": + alias := fieldText(c, "alias", src) + name := fieldText(c, "name", src) + if alias == "" { + alias = name + } + if alias != "" { + out = append(out, Import{Local: alias, Path: path}) + } + } + } + case "import_statement": + // `import a.b` binds the package root, not a class name — + // only `import x as y` introduces a flat local binding, + // and that's a module, not a type. Skip. + default: + for i := 0; i < int(n.NamedChildCount()); i++ { + visit(n.NamedChild(i)) + } + } + } + visit(root) + return out +} diff --git a/internal/semantic/tstypes/python_test.go b/internal/semantic/tstypes/python_test.go new file mode 100644 index 00000000..4d5f8ef4 --- /dev/null +++ b/internal/semantic/tstypes/python_test.go @@ -0,0 +1,194 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +const pySvc = `class Svc: + def run(self): + pass + + def stop(self): + pass +` + +func TestPython_AnnotatedParamResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/svc.py": pySvc, + "app/main.py": `from app.svc import Svc + + +def handle(s: Svc): + s.run() +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "handle", graph.KindFunction) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatalf("annotated-param call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } + assertASTProvenance(t, e, "python-types") +} + +// `s = Svc()` is convention-based constructor inference — it must only +// fire because Svc resolves to a real class node. +func TestPython_ConstructorInferenceResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/svc.py": pySvc, + "app/main.py": `from app.svc import Svc + + +def main(): + s = Svc() + s.run() +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, caller.ID, target.ID) == nil { + t.Fatalf("constructor-inferred call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } +} + +// A capitalized factory that is NOT a class must not ground a receiver. +func TestPython_CapitalizedFactoryDoesNotGround(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/util.py": `def Build(): + return object() +`, + "app/main.py": `from app.util import Build + + +def main(): + s = Build() + s.run() +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + assertUntouched(t, g, caller.ID, "run", "python-types") +} + +func TestPython_ImportHintDisambiguates(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/svc.py": pySvc, + "other/svc.py": pySvc, + "app/main.py": `from app.svc import Svc + + +def main(): + s = Svc() + s.run() +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + want := "app/svc.py::Svc.run" + if callEdgeTo(g, caller.ID, want) == nil { + t.Fatalf("import-hinted call did not land on %s; edges: %v", want, g.GetOutEdges(caller.ID)) + } + if callEdgeTo(g, caller.ID, "other/svc.py::Svc.run") != nil { + t.Fatal("call landed on the wrong module's class") + } +} + +func TestPython_BaseClassExtendsSynthesis(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/svc.py": pySvc, + "app/sub.py": `from app.svc import Svc + + +class Sub(Svc): + def extra(self): + pass +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + sub := nodeByNameKind(t, g, "Sub", graph.KindType) + svc := nodeByNameKind(t, g, "Svc", graph.KindType) + e := edgeBetween(g, sub.ID, graph.EdgeExtends, svc.ID) + if e == nil { + t.Fatalf("extends edge missing; edges: %v", g.GetOutEdges(sub.ID)) + } + assertASTProvenance(t, e, "python-types") +} + +// self-qualified calls and `self.x = Svc()` fields resolve in-class. +func TestPython_SelfAndFieldReceivers(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/svc.py": pySvc, + "app/app.py": `from app.svc import Svc + + +class App: + def __init__(self): + self.worker = Svc() + + def direct(self): + self.helper() + + def helper(self): + self.worker.run() +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + direct := nodeByNameKind(t, g, "direct", graph.KindMethod) + helper := nodeByNameKind(t, g, "helper", graph.KindMethod) + if callEdgeTo(g, direct.ID, helper.ID) == nil { + t.Fatalf("self.helper() not resolved; edges: %v", g.GetOutEdges(direct.ID)) + } + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, helper.ID, run.ID) == nil { + t.Fatalf("self.worker.run() not resolved; edges: %v", g.GetOutEdges(helper.ID)) + } +} + +func TestPython_AmbiguousReceiverStaysUntouched(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "app/svc.py": pySvc, + "app/alt.py": `class Alt: + def run(self): + pass +`, + "app/main.py": `from app.alt import Alt +from app.svc import Svc + + +def main(): + s = Svc() + s = Alt() + s.run() +`, + }) + p := NewProvider(PythonSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + assertUntouched(t, g, caller.ID, "run", "python-types") +} diff --git a/internal/semantic/tstypes/ruby.go b/internal/semantic/tstypes/ruby.go new file mode 100644 index 00000000..da952864 --- /dev/null +++ b/internal/semantic/tstypes/ruby.go @@ -0,0 +1,237 @@ +package tstypes + +import ( + "strings" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" + "github.com/zzet/gortex/internal/parser/tsitter/ruby" +) + +// RubySpec adapts the engine to tree-sitter-ruby. Ruby has no type +// annotations, so every binding comes from `Const.new` constructor +// inference — params and bare locals stay unknown and their calls are +// honestly skipped. Mixins (`include` / `extend` / `prepend`) become +// implements edges; `class Foo < Bar` becomes extends. +func RubySpec() *LangSpec { + grammar := ruby.GetLanguage() + return &LangSpec{ + ProviderName: "ruby-types", + Languages: []string{"ruby"}, + GrammarFor: func(string) *sitter.Language { return grammar }, + TypeDeclTypes: map[string]bool{ + "class": true, + "module": true, + }, + FuncDeclTypes: map[string]bool{ + "method": true, + "singleton_method": true, + }, + SelfName: "self", + TypeDeclName: nameField, + Supertypes: rubySupertypes, + Fields: rubyFields, + Params: rubyParams, + ReturnType: nil, // no return annotations in Ruby + LocalBinding: rubyLocalBinding, + Call: rubyCall, + NewExprType: rubyNewExprType, + FieldRef: func(n *sitter.Node, src []byte) (string, bool) { + if n.Type() != "instance_variable" { + return "", false + } + return n.Content(src), true + }, + Imports: nil, // require paths don't bind constant names + // `include M` targets a module, which the ruby extractor + // indexes as a package node. + SupertypeKinds: map[graph.NodeKind]bool{ + graph.KindType: true, + graph.KindInterface: true, + graph.KindPackage: true, + }, + } +} + +func rubySupertypes(n *sitter.Node, src []byte) []SuperRef { + var out []SuperRef + if n.Type() == "class" { + if sup := n.ChildByFieldName("superclass"); sup != nil { + name := rubyConstantText(sup, src) + if name != "" { + out = append(out, SuperRef{Name: name, Kind: graph.EdgeExtends, Line: nodeLine(sup)}) + } + } + } + // `include M` / `extend M` / `prepend M` as direct body statements + // mix the module's methods in — the closest Ruby has to an + // implements relation. + body := rubyBody(n) + if body == nil { + return out + } + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + if c.Type() != "call" { + continue + } + if c.ChildByFieldName("receiver") != nil { + continue + } + method := fieldText(c, "method", src) + if method != "include" && method != "extend" && method != "prepend" { + continue + } + args := c.ChildByFieldName("arguments") + if args == nil { + continue + } + for j := 0; j < int(args.NamedChildCount()); j++ { + a := args.NamedChild(j) + if name := rubyConstantText(a, src); name != "" { + out = append(out, SuperRef{Name: name, Kind: graph.EdgeImplements, Line: nodeLine(a)}) + } + } + } + return out +} + +// rubyBody returns a class/module's body_statement node — a `body` +// field in newer grammar revisions, an anonymous child in older ones. +func rubyBody(n *sitter.Node) *sitter.Node { + if body := n.ChildByFieldName("body"); body != nil { + return body + } + return firstChildOfType(n, "body_statement") +} + +// rubyConstantText extracts a constant (or scope-resolved constant) +// name from a node, "" for anything else. +func rubyConstantText(n *sitter.Node, src []byte) string { + switch n.Type() { + case "constant": + return n.Content(src) + case "scope_resolution", "superclass": + // take the trailing constant + for i := int(n.NamedChildCount()) - 1; i >= 0; i-- { + if c := n.NamedChild(i); c.Type() == "constant" || c.Type() == "scope_resolution" { + return rubyConstantText(c, src) + } + } + } + return "" +} + +// rubyFields scans the class's method bodies for `@x = Const.new` +// initialisations — the only grounded instance-variable typing +// evidence an annotation-free language offers. +func rubyFields(n *sitter.Node, src []byte) []Binding { + body := rubyBody(n) + if body == nil { + return nil + } + var out []Binding + var visit func(node *sitter.Node) + visit = func(node *sitter.Node) { + if node == nil { + return + } + switch node.Type() { + case "class", "module": + return // nested types own their ivars + case "assignment": + left := node.ChildByFieldName("left") + right := node.ChildByFieldName("right") + if left != nil && right != nil && left.Type() == "instance_variable" { + if typ := rubyNewExprType(right, src); typ != "" { + out = append(out, Binding{Name: left.Content(src), Type: typ, Line: nodeLine(left)}) + } + } + } + for i := 0; i < int(node.NamedChildCount()); i++ { + visit(node.NamedChild(i)) + } + } + for i := 0; i < int(body.NamedChildCount()); i++ { + visit(body.NamedChild(i)) + } + return out +} + +func rubyParams(fn *sitter.Node, src []byte) []Binding { + params := fn.ChildByFieldName("parameters") + if params == nil { + return nil + } + var out []Binding + for i := 0; i < int(params.NamedChildCount()); i++ { + p := params.NamedChild(i) + name := "" + switch p.Type() { + case "identifier": + name = p.Content(src) + case "optional_parameter", "keyword_parameter": + name = fieldText(p, "name", src) + } + if name != "" { + out = append(out, Binding{Name: name, Line: nodeLine(p)}) + } + } + return out +} + +func rubyLocalBinding(n *sitter.Node, src []byte) (LocalBind, bool) { + if n.Type() != "assignment" { + return LocalBind{}, false + } + left := n.ChildByFieldName("left") + if left == nil { + return LocalBind{}, false + } + init := n.ChildByFieldName("right") + switch left.Type() { + case "identifier": + return LocalBind{Name: left.Content(src), Init: init}, true + case "instance_variable": + return LocalBind{Name: left.Content(src), Init: init, Field: true}, true + } + return LocalBind{}, false +} + +// rubyCall decodes `recv.method(...)`, excluding the `Const.new` +// constructor shape (that's NewExprType's job) and mixin keywords. +func rubyCall(n *sitter.Node, src []byte) (*sitter.Node, string, bool) { + if n.Type() != "call" { + return nil, "", false + } + recv := n.ChildByFieldName("receiver") + if recv == nil { + return nil, "", false + } + method := fieldText(n, "method", src) + if method == "" || method == "new" { + return nil, "", false + } + return recv, method, true +} + +// rubyNewExprType recognises `Const.new(...)` (and `A::B.new`). The +// receiving constant is the constructed type; the apply phase verifies +// it against a graph type node. +func rubyNewExprType(n *sitter.Node, src []byte) string { + if n.Type() != "call" { + return "" + } + if fieldText(n, "method", src) != "new" { + return "" + } + recv := n.ChildByFieldName("receiver") + if recv == nil { + return "" + } + name := rubyConstantText(recv, src) + if i := strings.LastIndex(name, "::"); i >= 0 { + name = name[i+2:] + } + return name +} diff --git a/internal/semantic/tstypes/ruby_test.go b/internal/semantic/tstypes/ruby_test.go new file mode 100644 index 00000000..8d0ca3b5 --- /dev/null +++ b/internal/semantic/tstypes/ruby_test.go @@ -0,0 +1,161 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +const rubySvc = `class Svc + def run + end + + def stop + end +end +` + +// Ruby has no annotations and no name-binding imports — constructor +// inference plus repo-unique name resolution carries the cross-file +// case. +func TestRuby_ConstructorInferenceResolvesCrossFile(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "lib/svc.rb": rubySvc, + "lib/app.rb": `class App + def main + s = Svc.new + s.run + end +end +`, + }) + p := NewProvider(RubySpec(), zap.NewNop()) + res, err := p.Enrich(g, dir) + if err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatalf("constructor-inferred call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } + assertASTProvenance(t, e, "ruby-types") + if res.EdgesConfirmed+res.EdgesAdded == 0 { + t.Errorf("result reported no edge work: %+v", res) + } +} + +// self-qualified calls and `@ivar = Const.new` receivers resolve +// in-class. +func TestRuby_SelfAndIvarReceivers(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "lib/svc.rb": rubySvc, + "lib/app.rb": `class App + def initialize + @worker = Svc.new + end + + def direct + self.helper + end + + def helper + @worker.run + end +end +`, + }) + p := NewProvider(RubySpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + direct := nodeByNameKind(t, g, "direct", graph.KindMethod) + helper := nodeByNameKind(t, g, "helper", graph.KindMethod) + if callEdgeTo(g, direct.ID, helper.ID) == nil { + t.Fatalf("self.helper not resolved; edges: %v", g.GetOutEdges(direct.ID)) + } + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, helper.ID, run.ID) == nil { + t.Fatalf("@worker.run not resolved through ivar type; edges: %v", g.GetOutEdges(helper.ID)) + } +} + +func TestRuby_SuperclassExtendsSynthesis(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "lib/svc.rb": rubySvc, + "lib/sub.rb": `class Sub < Svc + def extra + end +end +`, + }) + p := NewProvider(RubySpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + sub := nodeByNameKind(t, g, "Sub", graph.KindType) + svc := nodeByNameKind(t, g, "Svc", graph.KindType) + e := edgeBetween(g, sub.ID, graph.EdgeExtends, svc.ID) + if e == nil { + t.Fatalf("extends edge missing; edges: %v", g.GetOutEdges(sub.ID)) + } + assertASTProvenance(t, e, "ruby-types") +} + +// `include M` mixes a module in — the module indexes as a package +// node, and the engine still grounds the implements edge against it. +func TestRuby_IncludeModuleImplementsSynthesis(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "lib/greeter.rb": `module Greeter + def greet + end +end +`, + "lib/impl.rb": `class Impl + include Greeter + + def extra + end +end +`, + }) + p := NewProvider(RubySpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + impl := nodeByNameKind(t, g, "Impl", graph.KindType) + mod := nodeByNameKind(t, g, "Greeter", graph.KindPackage) + e := edgeBetween(g, impl.ID, graph.EdgeImplements, mod.ID) + if e == nil { + t.Fatalf("implements edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } + assertASTProvenance(t, e, "ruby-types") +} + +func TestRuby_AmbiguousReceiverStaysUntouched(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "lib/svc.rb": rubySvc, + "lib/alt.rb": `class Alt + def run + end +end +`, + "lib/app.rb": `class App + def main + s = Svc.new + s = Alt.new + s.run + end +end +`, + }) + p := NewProvider(RubySpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindMethod) + assertUntouched(t, g, caller.ID, "run", "ruby-types") +} diff --git a/internal/semantic/tstypes/rust.go b/internal/semantic/tstypes/rust.go new file mode 100644 index 00000000..86b2ce12 --- /dev/null +++ b/internal/semantic/tstypes/rust.go @@ -0,0 +1,285 @@ +package tstypes + +import ( + "strings" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" + "github.com/zzet/gortex/internal/parser/tsitter/rust" +) + +// RustSpec adapts the engine to tree-sitter-rust. An `impl T` block is +// treated as a type scope named T so methods see the struct's fields +// (the field pre-pass keys both the struct_item and the impl blocks on +// the same name); `impl Trait for T` synthesizes an implements edge. +// Bindings come from let annotations, struct expressions (`T { .. }`), +// and the `T::new()` convention — verified against a graph type node +// before any resolution happens through them. +func RustSpec() *LangSpec { + grammar := rust.GetLanguage() + return &LangSpec{ + ProviderName: "rust-types", + Languages: []string{"rust"}, + GrammarFor: func(string) *sitter.Language { return grammar }, + TypeDeclTypes: map[string]bool{ + "struct_item": true, + "enum_item": true, + "trait_item": true, + "union_item": true, + "impl_item": true, + }, + FuncDeclTypes: map[string]bool{ + "function_item": true, + }, + SelfName: "self", + TypeDeclName: rustTypeDeclName, + Supertypes: rustSupertypes, + Fields: rustFields, + Params: rustParams, + ReturnType: func(fn *sitter.Node, src []byte) string { + return fieldText(fn, "return_type", src) + }, + LocalBinding: rustLocalBinding, + Call: rustCall, + NewExprType: rustNewExprType, + FieldRef: func(n *sitter.Node, src []byte) (string, bool) { + if n.Type() != "field_expression" { + return "", false + } + val := n.ChildByFieldName("value") + if val == nil || val.Type() != "self" { + return "", false + } + return fieldText(n, "field", src), true + }, + Imports: rustImports, + } +} + +func rustTypeDeclName(n *sitter.Node, src []byte) string { + if n.Type() == "impl_item" { + t := n.ChildByFieldName("type") + if t == nil { + return "" + } + return NormalizeTypeName(t.Content(src)) + } + return fieldText(n, "name", src) +} + +func rustSupertypes(n *sitter.Node, src []byte) []SuperRef { + if n.Type() != "impl_item" { + return nil + } + tr := n.ChildByFieldName("trait") + if tr == nil { + return nil + } + return []SuperRef{{Name: tr.Content(src), Kind: graph.EdgeImplements, Line: nodeLine(tr)}} +} + +func rustFields(n *sitter.Node, src []byte) []Binding { + if n.Type() != "struct_item" && n.Type() != "union_item" { + return nil + } + body := n.ChildByFieldName("body") + if body == nil || body.Type() != "field_declaration_list" { + return nil + } + var out []Binding + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + if c.Type() != "field_declaration" { + continue + } + name := fieldText(c, "name", src) + if name == "" { + continue + } + out = append(out, Binding{Name: name, Type: fieldText(c, "type", src), Line: nodeLine(c)}) + } + return out +} + +func rustParams(fn *sitter.Node, src []byte) []Binding { + params := fn.ChildByFieldName("parameters") + if params == nil { + return nil + } + var out []Binding + for i := 0; i < int(params.NamedChildCount()); i++ { + p := params.NamedChild(i) + if p.Type() != "parameter" { + continue + } + pattern := p.ChildByFieldName("pattern") + if pattern == nil || pattern.Type() != "identifier" { + continue + } + out = append(out, Binding{Name: pattern.Content(src), Type: fieldText(p, "type", src), Line: nodeLine(p)}) + } + return out +} + +func rustLocalBinding(n *sitter.Node, src []byte) (LocalBind, bool) { + switch n.Type() { + case "let_declaration": + pattern := n.ChildByFieldName("pattern") + if pattern == nil || pattern.Type() != "identifier" { + return LocalBind{}, false + } + return LocalBind{ + Name: pattern.Content(src), + DeclType: fieldText(n, "type", src), + Init: n.ChildByFieldName("value"), + }, true + case "assignment_expression": + left := n.ChildByFieldName("left") + if left == nil || left.Type() != "identifier" { + return LocalBind{}, false + } + return LocalBind{Name: left.Content(src), Init: n.ChildByFieldName("right")}, true + } + return LocalBind{}, false +} + +func rustCall(n *sitter.Node, src []byte) (*sitter.Node, string, bool) { + if n.Type() != "call_expression" { + return nil, "", false + } + fn := n.ChildByFieldName("function") + if fn == nil || fn.Type() != "field_expression" { + return nil, "", false + } + val := fn.ChildByFieldName("value") + field := fn.ChildByFieldName("field") + if val == nil || field == nil || field.Type() != "field_identifier" { + return nil, "", false + } + return val, field.Content(src), true +} + +// rustNewExprType recognises the two constructor shapes: a struct +// expression `Foo { .. }` and the `Foo::new(..)` convention. +func rustNewExprType(n *sitter.Node, src []byte) string { + switch n.Type() { + case "struct_expression": + if name := n.ChildByFieldName("name"); name != nil { + return NormalizeTypeName(name.Content(src)) + } + case "call_expression": + fn := n.ChildByFieldName("function") + if fn == nil { + return "" + } + // Foo::new / module::Foo::new / Foo::::new + if fn.Type() == "scoped_identifier" || fn.Type() == "generic_function" { + text := fn.Content(src) + segs := strings.Split(text, "::") + if len(segs) < 2 || strings.TrimSpace(segs[len(segs)-1]) != "new" { + return "" + } + owner := strings.TrimSpace(segs[len(segs)-2]) + if owner == "" || owner == "Self" { + return "" + } + return NormalizeTypeName(owner) + } + } + return "" +} + +func rustImports(root *sitter.Node, src []byte) []Import { + var out []Import + var visit func(n *sitter.Node) + visit = func(n *sitter.Node) { + if n == nil { + return + } + if n.Type() == "use_declaration" { + if arg := n.ChildByFieldName("argument"); arg != nil { + out = append(out, rustUseImports(arg, src, "")...) + } + return + } + for i := 0; i < int(n.NamedChildCount()); i++ { + visit(n.NamedChild(i)) + } + } + visit(root) + return out +} + +// rustUseImports flattens a use tree (scoped paths, `as` clauses, and +// brace lists) into local-name bindings. The Path hint is the MODULE +// path — the imported name's parent — because that is what maps onto a +// definition file (`use crate::engine::Svc` lives in engine.rs / +// engine/mod.rs, not in a file named after the type). +func rustUseImports(n *sitter.Node, src []byte, prefix string) []Import { + switch n.Type() { + case "identifier", "type_identifier": + return []Import{{Local: n.Content(src), Path: prefix}} + case "scoped_identifier": + full := strings.ReplaceAll(n.Content(src), "::", "/") + name := full + if i := strings.LastIndex(name, "/"); i >= 0 { + name = name[i+1:] + } + return []Import{{Local: name, Path: joinUsePath(prefix, useParent(full))}} + case "use_as_clause": + alias := fieldText(n, "alias", src) + path := n.ChildByFieldName("path") + if alias == "" || path == nil { + return nil + } + full := strings.ReplaceAll(path.Content(src), "::", "/") + return []Import{{Local: alias, Path: joinUsePath(prefix, useParent(full))}} + case "scoped_use_list": + path := n.ChildByFieldName("path") + list := n.ChildByFieldName("list") + if list == nil { + return nil + } + base := prefix + if path != nil { + base = joinUsePath(prefix, strings.ReplaceAll(path.Content(src), "::", "/")) + } + var out []Import + for i := 0; i < int(list.NamedChildCount()); i++ { + out = append(out, rustUseImports(list.NamedChild(i), src, base)...) + } + return out + case "use_list": + var out []Import + for i := 0; i < int(n.NamedChildCount()); i++ { + out = append(out, rustUseImports(n.NamedChild(i), src, prefix)...) + } + return out + } + return nil +} + +func joinUsePath(prefix, rest string) string { + for _, tok := range []string{"crate", "self", "super"} { + rest = strings.TrimPrefix(rest, tok+"/") + if rest == tok { + rest = "" + } + } + switch { + case prefix == "": + return rest + case rest == "": + return prefix + } + return prefix + "/" + rest +} + +// useParent strips a use path's final segment — the imported name — +// leaving the module path. +func useParent(full string) string { + if i := strings.LastIndex(full, "/"); i >= 0 { + return full[:i] + } + return "" +} diff --git a/internal/semantic/tstypes/rust_test.go b/internal/semantic/tstypes/rust_test.go new file mode 100644 index 00000000..32759804 --- /dev/null +++ b/internal/semantic/tstypes/rust_test.go @@ -0,0 +1,232 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +const rustSvc = `pub struct Svc { + count: u32, +} + +impl Svc { + pub fn run(&self) {} + pub fn stop(&self) {} +} +` + +func TestRust_DeclaredParamTypeResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/engine.rs": rustSvc, + "src/app.rs": `use crate::engine::Svc; + +pub fn handle(s: &Svc) { + s.run(); +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + res, err := p.Enrich(g, dir) + if err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "handle", graph.KindFunction) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatalf("annotated-param call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } + assertASTProvenance(t, e, "rust-types") + if res.EdgesConfirmed+res.EdgesAdded == 0 { + t.Errorf("result reported no edge work: %+v", res) + } +} + +// Both constructor shapes ground a receiver: the struct expression and +// the `T::new()` convention. +func TestRust_ConstructorInferenceResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/engine.rs": rustSvc, + "src/app.rs": `use crate::engine::Svc; + +pub fn literal() { + let s = Svc { count: 0 }; + s.run(); +} + +pub fn convention() { + let s = Svc::new(); + s.stop(); +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + literal := nodeByNameKind(t, g, "literal", graph.KindFunction) + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, literal.ID, run.ID) == nil { + t.Fatalf("struct-expression call not resolved; edges: %v", g.GetOutEdges(literal.ID)) + } + convention := nodeByNameKind(t, g, "convention", graph.KindFunction) + stop := nodeByNameKind(t, g, "stop", graph.KindMethod) + if callEdgeTo(g, convention.ID, stop.ID) == nil { + t.Fatalf("Svc::new() call not resolved; edges: %v", g.GetOutEdges(convention.ID)) + } +} + +// Two same-named structs: the use-path hint picks the matching module +// file. +func TestRust_ImportHintDisambiguates(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/engine.rs": rustSvc, + "src/legacy.rs": rustSvc, + "src/app.rs": `use crate::engine::Svc; + +pub fn main() { + let s = Svc::new(); + s.run(); +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + want := "src/engine.rs::Svc.run" + if callEdgeTo(g, caller.ID, want) == nil { + t.Fatalf("use-hinted call did not land on %s; edges: %v", want, g.GetOutEdges(caller.ID)) + } + if callEdgeTo(g, caller.ID, "src/legacy.rs::Svc.run") != nil { + t.Fatal("call landed on the wrong module's struct") + } +} + +func TestRust_ImplTraitImplementsSynthesis(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/greeter.rs": `pub trait Greeter { + fn greet(&self); +} +`, + "src/widget.rs": `use crate::greeter::Greeter; + +pub struct Widget { + id: u32, +} + +impl Greeter for Widget { + fn greet(&self) {} +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + widget := nodeByNameKind(t, g, "Widget", graph.KindType) + trait := nodeByNameKind(t, g, "Greeter", graph.KindInterface) + e := edgeBetween(g, widget.ID, graph.EdgeImplements, trait.ID) + if e == nil { + t.Fatalf("implements edge missing; edges: %v", g.GetOutEdges(widget.ID)) + } + assertASTProvenance(t, e, "rust-types") +} + +// self-qualified calls and struct-field receivers resolve through the +// impl block, with the field declared on the struct. +func TestRust_SelfAndFieldReceivers(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/engine.rs": rustSvc, + "src/app.rs": `use crate::engine::Svc; + +pub struct App { + worker: Svc, +} + +impl App { + pub fn direct(&self) { + self.helper(); + } + + pub fn helper(&self) { + self.worker.run(); + } +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + direct := nodeByNameKind(t, g, "direct", graph.KindMethod) + helper := nodeByNameKind(t, g, "helper", graph.KindMethod) + if callEdgeTo(g, direct.ID, helper.ID) == nil { + t.Fatalf("self.helper() not resolved; edges: %v", g.GetOutEdges(direct.ID)) + } + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, helper.ID, run.ID) == nil { + t.Fatalf("self.worker.run() not resolved through field type; edges: %v", g.GetOutEdges(helper.ID)) + } +} + +// A local initialised from a bare call takes the callee's declared +// return type. +func TestRust_FunctionReturnTypePropagates(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/engine.rs": rustSvc, + "src/app.rs": `use crate::engine::Svc; + +pub fn build() -> Svc { + Svc { count: 0 } +} + +pub fn main() { + let s = build(); + s.run(); +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, caller.ID, run.ID) == nil { + t.Fatalf("return-type-propagated call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } +} + +func TestRust_AmbiguousReceiverStaysUntouched(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/engine.rs": rustSvc, + "src/alt.rs": `pub struct Alt { + id: u32, +} + +impl Alt { + pub fn run(&self) {} +} +`, + "src/app.rs": `use crate::alt::Alt; +use crate::engine::Svc; + +pub fn main() { + let mut s = Svc::new(); + s = Alt::new(); + s.run(); +} +`, + }) + p := NewProvider(RustSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + assertUntouched(t, g, caller.ID, "run", "rust-types") +} diff --git a/internal/semantic/tstypes/scope.go b/internal/semantic/tstypes/scope.go new file mode 100644 index 00000000..dcfc8115 --- /dev/null +++ b/internal/semantic/tstypes/scope.go @@ -0,0 +1,394 @@ +package tstypes + +import ( + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" +) + +// fileFacts is the pure-syntax output of one file's binder walk. It +// carries no graph references so the parse/walk phase can run on +// worker goroutines while the apply phase owns every store interaction. +type fileFacts struct { + file string // graph FilePath of the analyzed file + repoPrefix string + imports []Import + calls []callFact + supers []superFact + metas []metaFact +} + +// callFact is one receiver-qualified call site with whatever receiver +// evidence the binder could ground. Exactly one of recvType / +// recvPendingCallee / recvIdent is usually set; the apply phase +// resolves them in that priority order and skips the call when none +// lands on a verified graph type node. +type callFact struct { + line int // 1-based call line + method string + // recvType is the receiver's bound type name (annotation, + // constructor inference, or propagation through locals). + recvType string + // recvPendingCallee is set when the receiver local was initialised + // from a bare call (`u = build_user()`); the apply phase resolves + // the callee's graph return_type. + recvPendingCallee string + // recvIdent is set when the receiver is an identifier with no + // binding in scope — a type-qualified (static) call candidate. Only + // used when it resolves to a real graph type node. + recvIdent string +} + +// superFact is one declared supertype relation, pending graph +// resolution of both endpoints. +type superFact struct { + typeName string + superName string + kind graph.EdgeKind // empty: decide by resolved target kind + line int +} + +// metaFact is one Node.Meta fill: stamp key=value on the symbol node +// matched by (owner, name) or by declaration line. +type metaFact struct { + key string + value string + owner string // receiver type for field stamps; "" for line-matched + name string // field name; "" for line-matched + line int // declaration line for line-matched stamps +} + +// bindingState tracks one name's type through the +// single-assignment-lite discipline: the first typed binding wins, a +// later conflicting (or unknowable) rebind poisons the binding so the +// engine never resolves through a type it cannot defend. +type bindingState struct { + typ string + pendingCallee string + poisoned bool +} + +type scopeKind int + +const ( + scopeFile scopeKind = iota + scopeType + scopeFunc +) + +type scopeEnv struct { + parent *scopeEnv + kind scopeKind + typeName string // set on scopeType + vars map[string]*bindingState +} + +func newScope(parent *scopeEnv, kind scopeKind) *scopeEnv { + return &scopeEnv{parent: parent, kind: kind, vars: make(map[string]*bindingState)} +} + +// lookup walks the scope chain for name. +func (s *scopeEnv) lookup(name string) *bindingState { + for e := s; e != nil; e = e.parent { + if st, ok := e.vars[name]; ok { + return st + } + } + return nil +} + +// enclosingTypeName returns the nearest type scope's name. +func (s *scopeEnv) enclosingTypeName() string { + for e := s; e != nil; e = e.parent { + if e.kind == scopeType { + return e.typeName + } + } + return "" +} + +// nearestTypeScope returns the nearest enclosing type scope. +func (s *scopeEnv) nearestTypeScope() *scopeEnv { + for e := s; e != nil; e = e.parent { + if e.kind == scopeType { + return e + } + } + return nil +} + +// bind applies the single-assignment-lite rule: first binding wins; a +// rebind that does not provably preserve the type degrades the binding +// to unknown (poisoned), permanently for this scope chain. +func (s *scopeEnv) bind(name string, typ, pendingCallee string) { + if name == "" { + return + } + if st := s.lookup(name); st != nil { + if st.poisoned { + return + } + if typ != st.typ || pendingCallee != st.pendingCallee { + st.typ = "" + st.pendingCallee = "" + st.poisoned = true + } + return + } + s.vars[name] = &bindingState{typ: typ, pendingCallee: pendingCallee} +} + +// binder runs the scope-graph walk over one parsed file. +type binder struct { + spec *LangSpec + src []byte + facts *fileFacts + // fieldsByType is the file-level pre-pass result: declared (and + // conventionally initialised) field types per type name. Seeding + // every type scope from it lets a method body resolve fields + // declared after it — and, for Rust, fields declared on the struct + // while the method lives in a separate impl block. + fieldsByType map[string]map[string]string +} + +func newBinder(spec *LangSpec, src []byte, facts *fileFacts) *binder { + return &binder{spec: spec, src: src, facts: facts, fieldsByType: make(map[string]map[string]string)} +} + +func (b *binder) run(root *sitter.Node) { + if root == nil { + return + } + b.prepassFields(root) + fileScope := newScope(nil, scopeFile) + if b.spec.Imports != nil { + b.facts.imports = b.spec.Imports(root, b.src) + } + b.walk(root, fileScope) +} + +// prepassFields collects field types for every type declaration in the +// file before the main walk. +func (b *binder) prepassFields(n *sitter.Node) { + if n == nil { + return + } + if b.spec.TypeDeclTypes[n.Type()] && b.spec.TypeDeclName != nil { + if name := b.spec.TypeDeclName(n, b.src); name != "" && b.spec.Fields != nil { + fields := b.fieldsByType[name] + if fields == nil { + fields = make(map[string]string) + b.fieldsByType[name] = fields + } + for _, f := range b.spec.Fields(n, b.src) { + typ := b.spec.normalize(f.Type) + if prev, ok := fields[f.Name]; ok && prev != typ { + // Conflicting declarations degrade to unknown — + // same rule as local rebinds. + fields[f.Name] = "" + continue + } + fields[f.Name] = typ + if typ != "" { + b.facts.metas = append(b.facts.metas, metaFact{ + key: "semantic_type", value: typ, owner: name, name: f.Name, + }) + } + } + } + } + for i := 0; i < int(n.NamedChildCount()); i++ { + b.prepassFields(n.NamedChild(i)) + } +} + +func (b *binder) walk(n *sitter.Node, env *scopeEnv) { + if n == nil { + return + } + t := n.Type() + + if b.spec.TypeDeclTypes[t] && b.spec.TypeDeclName != nil { + name := b.spec.TypeDeclName(n, b.src) + if name != "" { + if b.spec.Supertypes != nil { + for _, s := range b.spec.Supertypes(n, b.src) { + super := b.spec.normalize(s.Name) + if super == "" || super == name { + continue + } + b.facts.supers = append(b.facts.supers, superFact{ + typeName: name, superName: super, kind: s.Kind, line: s.Line, + }) + } + } + tEnv := newScope(env, scopeType) + tEnv.typeName = name + for fname, ftyp := range b.fieldsByType[name] { + tEnv.vars[fname] = &bindingState{typ: ftyp} + } + b.walkChildren(n, tEnv) + return + } + } + + if b.spec.FuncDeclTypes[t] { + fEnv := newScope(env, scopeFunc) + if b.spec.Params != nil { + for _, p := range b.spec.Params(n, b.src) { + fEnv.vars[p.Name] = &bindingState{typ: b.spec.normalize(p.Type)} + } + } + if b.spec.ReturnType != nil { + if rt := b.spec.normalize(b.spec.ReturnType(n, b.src)); rt != "" { + b.facts.metas = append(b.facts.metas, metaFact{ + key: "return_type", value: rt, line: nodeLine(n), + }) + } + } + b.walkChildren(n, fEnv) + return + } + + if b.spec.LocalBinding != nil { + if lb, ok := b.spec.LocalBinding(n, b.src); ok { + typ := b.spec.normalize(lb.DeclType) + pending := "" + if typ == "" || isInferenceKeyword(typ) { + typ, pending = b.exprType(lb.Init, env) + } + if lb.Field { + if ts := env.nearestTypeScope(); ts != nil { + ts.bind(lb.Name, typ, pending) + } + } else { + env.bind(lb.Name, typ, pending) + } + // Fall through: the initializer may contain calls worth + // recording. + } + } + + if b.spec.Call != nil { + if recv, method, ok := b.spec.Call(n, b.src); ok && method != "" { + if cf, grounded := b.receiverFact(recv, env); grounded { + cf.line = nodeLine(n) + cf.method = method + b.facts.calls = append(b.facts.calls, cf) + } + } + } + + b.walkChildren(n, env) +} + +func (b *binder) walkChildren(n *sitter.Node, env *scopeEnv) { + for i := 0; i < int(n.NamedChildCount()); i++ { + b.walk(n.NamedChild(i), env) + } +} + +// exprType evaluates an initializer expression to (type name, pending +// bare callee). Both empty means unknown. +func (b *binder) exprType(init *sitter.Node, env *scopeEnv) (string, string) { + if init == nil { + return "", "" + } + if b.spec.NewExprType != nil { + if t := b.spec.normalize(b.spec.NewExprType(init, b.src)); t != "" { + return t, "" + } + } + if identifierLike(init.Type()) { + if st := env.lookup(init.Content(b.src)); st != nil && !st.poisoned { + return st.typ, st.pendingCallee + } + return "", "" + } + if b.spec.FieldRef != nil { + if fname, ok := b.spec.FieldRef(init, b.src); ok { + if ts := env.nearestTypeScope(); ts != nil { + if st, found := ts.vars[fname]; found && !st.poisoned { + return st.typ, st.pendingCallee + } + } + return "", "" + } + } + if callee := bareCallee(init, b.src); callee != "" { + return "", callee + } + return "", "" +} + +// receiverFact grounds a call's receiver expression. Returns ok=false +// when the receiver is structurally outside what the engine can defend +// (chained expressions, poisoned bindings, unknown shapes). +func (b *binder) receiverFact(recv *sitter.Node, env *scopeEnv) (callFact, bool) { + if recv == nil { + return callFact{}, false + } + text := recv.Content(b.src) + if b.spec.SelfName != "" && text == b.spec.SelfName { + if tn := env.enclosingTypeName(); tn != "" { + return callFact{recvType: tn}, true + } + return callFact{}, false + } + if b.spec.FieldRef != nil { + if fname, ok := b.spec.FieldRef(recv, b.src); ok { + if ts := env.nearestTypeScope(); ts != nil { + if st, found := ts.vars[fname]; found && !st.poisoned && st.typ != "" { + return callFact{recvType: st.typ}, true + } + } + return callFact{}, false + } + } + if identifierLike(recv.Type()) { + if st := env.lookup(text); st != nil { + if st.poisoned { + return callFact{}, false + } + if st.typ != "" { + return callFact{recvType: st.typ}, true + } + if st.pendingCallee != "" { + return callFact{recvPendingCallee: st.pendingCallee}, true + } + return callFact{}, false + } + // Unbound identifier: a static / type-qualified call candidate. + // The apply phase only acts when it resolves to a type node. + return callFact{recvIdent: text}, true + } + if b.spec.NewExprType != nil { + if t := b.spec.normalize(b.spec.NewExprType(recv, b.src)); t != "" { + return callFact{recvType: t}, true + } + } + return callFact{}, false +} + +// bareCallee returns the callee name when n is a call expression whose +// function is a bare identifier; "" otherwise. Handles the grammars' +// two common shapes (call / call_expression / invocation_expression +// with a `function` field). +func bareCallee(n *sitter.Node, src []byte) string { + switch n.Type() { + case "call", "call_expression", "invocation_expression": + if fn := n.ChildByFieldName("function"); fn != nil && fn.Type() == "identifier" { + return fn.Content(src) + } + } + return "" +} + +// isInferenceKeyword reports whether a written "type" is actually the +// language's inference keyword and should defer to the initializer. +func isInferenceKeyword(t string) bool { + switch t { + case "var", "let", "auto": + return true + } + return false +} diff --git a/internal/semantic/tstypes/spec.go b/internal/semantic/tstypes/spec.go new file mode 100644 index 00000000..a77ecf4b --- /dev/null +++ b/internal/semantic/tstypes/spec.go @@ -0,0 +1,235 @@ +// Package tstypes implements in-process, LSP-free semantic providers +// over the shared tree-sitter ASTs. One shared engine builds a per-file +// scope graph (params, locals, fields, imports), binds declared and +// constructor-inferred types, propagates them through local assignments +// (single-assignment-lite: a rebind to a different type degrades the +// binding to unknown), and resolves receiver-qualified calls plus +// declared supertype relations against the symbol nodes the graph +// already holds. Per-language LangSpec tables adapt the engine to each +// grammar's node vocabulary. +// +// Provenance: everything this package touches is tree-sitter-derived, +// not compiler-verified, so edges are stamped OriginASTResolved (never +// the lsp_* tiers ConfirmEdge uses) with Meta["semantic_source"] set to +// the provider name ("java-types", "python-types", ...). A resolution +// the engine cannot ground in graph evidence — ambiguous receiver, +// unresolvable type name, overloaded method set — is skipped rather +// than guessed: a false edge is worse than a missing one. +package tstypes + +import ( + "strings" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" +) + +// Binding is one named, optionally typed binding (param or field). +type Binding struct { + Name string + Type string // declared type name as written; "" when unannotated + Line int // 1-based declaration line +} + +// LocalBind is one local-variable declaration or assignment the engine +// folds into the scope's type environment. +type LocalBind struct { + Name string + DeclType string // explicit annotation; "" when absent + Init *sitter.Node // initializer expression; nil when absent + Field bool // binds in the enclosing type scope (e.g. Ruby @ivar) +} + +// SuperRef is one declared supertype relation of a type declaration. +// Kind is EdgeExtends or EdgeImplements when the syntax declares it; +// an empty Kind defers the choice to the apply phase, which picks by +// the resolved target's node kind (used by C#, whose base list does +// not distinguish the base class from interfaces syntactically). +type SuperRef struct { + Name string + Kind graph.EdgeKind + Line int // 1-based +} + +// Import is one name-binding import: Local is the identifier the file +// sees; Path is a slash-separated location hint used to prefer the +// matching definition file when several nodes share the name. +type Import struct { + Local string + Path string +} + +// LangSpec adapts the shared engine to one language's tree-sitter +// grammar. The node-type sets drive the generic walk; the hooks decode +// the handful of shapes that differ per grammar. Hooks may be nil when +// the language has no equivalent construct (e.g. Ruby has no type +// annotations, C# has no name-binding imports). +type LangSpec struct { + ProviderName string + Languages []string + + // GrammarFor returns the grammar for a file path. Per-path because + // one provider can span sibling grammars (typescript / tsx / + // javascript). + GrammarFor func(filePath string) *sitter.Language + + // TypeDeclTypes / FuncDeclTypes are the node types that open a type + // or callable scope. + TypeDeclTypes map[string]bool + FuncDeclTypes map[string]bool + + // SelfName is the receiver keyword ("this", "self"); "" when the + // language has none. + SelfName string + + // TypeDeclName extracts the declared type name ("" skips the node). + TypeDeclName func(n *sitter.Node, src []byte) string + + // Supertypes lists the declared supertype relations of a type decl. + Supertypes func(n *sitter.Node, src []byte) []SuperRef + + // Fields lists the field bindings of a type decl (declared fields + // plus whatever conventional initialisations the language grounds, + // e.g. Python's `self.x = Foo()` or Ruby's `@x = Foo.new`). + Fields func(n *sitter.Node, src []byte) []Binding + + // Params lists a callable's declared parameters. + Params func(fn *sitter.Node, src []byte) []Binding + + // ReturnType extracts an explicit return-type annotation ("" when + // absent or unsupported). + ReturnType func(fn *sitter.Node, src []byte) string + + // LocalBinding decodes a local declaration / assignment node. + LocalBinding func(n *sitter.Node, src []byte) (LocalBind, bool) + + // Call decodes a receiver-qualified call: the receiver expression + // and the method name. ok=false for anything else (including + // receiverless calls — those are the resolver's job already). + Call func(n *sitter.Node, src []byte) (recv *sitter.Node, method string, ok bool) + + // NewExprType returns the constructed type name when n is a + // constructor expression ("" otherwise). Conventional constructors + // (Python `Foo()`, Ruby `Foo.new`, Rust `Foo::new`) may be + // returned too — the apply phase verifies every receiver type + // against a real graph type node before resolving through it. + NewExprType func(n *sitter.Node, src []byte) string + + // FieldRef reports that n is a reference to an instance field of + // the current receiver (`this.x`, `self.x`, `@x`) and returns the + // field's binding name. + FieldRef func(n *sitter.Node, src []byte) (string, bool) + + // Imports lists the file's name-binding imports. + Imports func(root *sitter.Node, src []byte) []Import + + // SupertypeKinds widens the node kinds a declared supertype name + // may resolve to. nil keeps the receiver default (type / + // interface). Ruby adds packages: tree-sitter modules index as + // KindPackage and `include M` targets them. + SupertypeKinds map[graph.NodeKind]bool + + // NormalizeType reduces a written type to the bare name the graph + // indexes (strip generics / pointers / qualifiers). nil uses the + // shared default. + NormalizeType func(t string) string +} + +func (s *LangSpec) normalize(t string) string { + if s.NormalizeType != nil { + return s.NormalizeType(t) + } + return NormalizeTypeName(t) +} + +// handles reports whether the spec serves the given language code. +func (s *LangSpec) handles(lang string) bool { + for _, l := range s.Languages { + if l == lang { + return true + } + } + return false +} + +// NormalizeTypeName is the shared written-type → bare-name reduction: +// strips generic arguments, array suffixes, nullability markers, +// reference sigils, and namespace qualifiers, leaving the identifier +// the graph indexes type nodes under. +func NormalizeTypeName(t string) string { + t = strings.TrimSpace(t) + if t == "" { + return "" + } + // Reference / pointer / ownership sigils and prefix keywords. + for { + switch { + case strings.HasPrefix(t, "&"), strings.HasPrefix(t, "*"): + t = strings.TrimSpace(t[1:]) + continue + case strings.HasPrefix(t, "mut "): + t = strings.TrimSpace(t[4:]) + continue + case strings.HasPrefix(t, "dyn "): + t = strings.TrimSpace(t[4:]) + continue + case strings.HasPrefix(t, "impl "): + t = strings.TrimSpace(t[5:]) + continue + } + break + } + // Generic arguments and array / nullability suffixes. + if i := strings.IndexAny(t, "<(["); i >= 0 { + t = t[:i] + } + t = strings.TrimSuffix(strings.TrimSuffix(t, "?"), "!") + // Namespace / module qualifiers — keep the last segment. + if i := strings.LastIndex(t, "::"); i >= 0 { + t = t[i+2:] + } + if i := strings.LastIndex(t, "."); i >= 0 { + t = t[i+1:] + } + return strings.TrimSpace(t) +} + +// nodeLine returns the 1-based start line of n. +func nodeLine(n *sitter.Node) int { + return int(n.StartPoint().Row) + 1 +} + +// fieldText returns the text of a named field child, "" when absent. +func fieldText(n *sitter.Node, field string, src []byte) string { + c := n.ChildByFieldName(field) + if c == nil { + return "" + } + return c.Content(src) +} + +// nameField extracts the `name` field's text — the TypeDeclName shape +// every grammar here shares. +func nameField(n *sitter.Node, src []byte) string { + return fieldText(n, "name", src) +} + +// firstChildOfType returns the first named child with the given type. +func firstChildOfType(n *sitter.Node, t string) *sitter.Node { + for i := 0; i < int(n.NamedChildCount()); i++ { + if c := n.NamedChild(i); c.Type() == t { + return c + } + } + return nil +} + +// identifierLike reports whether the node is a bare single-token name +// usable for scope lookup. +func identifierLike(t string) bool { + switch t { + case "identifier", "constant", "type_identifier", "variable_name", "local_variable": + return true + } + return false +} diff --git a/internal/semantic/tstypes/typescript.go b/internal/semantic/tstypes/typescript.go new file mode 100644 index 00000000..2e55c7e8 --- /dev/null +++ b/internal/semantic/tstypes/typescript.go @@ -0,0 +1,286 @@ +package tstypes + +import ( + "strings" + + "github.com/zzet/gortex/internal/graph" + sitter "github.com/zzet/gortex/internal/parser/tsitter" + "github.com/zzet/gortex/internal/parser/tsitter/javascript" + "github.com/zzet/gortex/internal/parser/tsitter/tsx" + "github.com/zzet/gortex/internal/parser/tsitter/typescript" +) + +// TypeScriptSpec adapts the engine to the TS / TSX / JS grammar +// family. One provider serves both graph languages: .tsx picks the TSX +// grammar (JSX nodes), .js/.jsx/.mjs/.cjs the JavaScript grammar +// (where annotations don't exist and the binder leans on `new` +// inference), everything else plain TypeScript. +func TypeScriptSpec() *LangSpec { + tsGrammar := typescript.GetLanguage() + tsxGrammar := tsx.GetLanguage() + jsGrammar := javascript.GetLanguage() + return &LangSpec{ + ProviderName: "typescript-types", + Languages: []string{"typescript", "javascript"}, + GrammarFor: func(filePath string) *sitter.Language { + lower := strings.ToLower(filePath) + switch { + case strings.HasSuffix(lower, ".tsx"): + return tsxGrammar + case strings.HasSuffix(lower, ".js"), strings.HasSuffix(lower, ".jsx"), + strings.HasSuffix(lower, ".mjs"), strings.HasSuffix(lower, ".cjs"): + return jsGrammar + default: + return tsGrammar + } + }, + TypeDeclTypes: map[string]bool{ + "class_declaration": true, + "abstract_class_declaration": true, + "interface_declaration": true, + }, + FuncDeclTypes: map[string]bool{ + "function_declaration": true, + "generator_function_declaration": true, + "method_definition": true, + "arrow_function": true, + "function_expression": true, + }, + SelfName: "this", + TypeDeclName: nameField, + Supertypes: tsSupertypes, + Fields: tsFields, + Params: tsParams, + ReturnType: func(fn *sitter.Node, src []byte) string { + return typeAnnotationText(fn.ChildByFieldName("return_type"), src) + }, + LocalBinding: tsLocalBinding, + Call: tsCall, + NewExprType: func(n *sitter.Node, src []byte) string { + if n.Type() != "new_expression" { + return "" + } + ctor := n.ChildByFieldName("constructor") + if ctor == nil || ctor.Type() != "identifier" { + return "" + } + return ctor.Content(src) + }, + FieldRef: func(n *sitter.Node, src []byte) (string, bool) { + if n.Type() != "member_expression" { + return "", false + } + obj := n.ChildByFieldName("object") + if obj == nil || obj.Type() != "this" { + return "", false + } + return fieldText(n, "property", src), true + }, + Imports: tsImports, + } +} + +func tsSupertypes(n *sitter.Node, src []byte) []SuperRef { + var out []SuperRef + collect := func(c *sitter.Node, kind graph.EdgeKind) { + for i := 0; i < int(c.NamedChildCount()); i++ { + t := c.NamedChild(i) + switch t.Type() { + case "identifier", "type_identifier", "generic_type", "nested_type_identifier", "member_expression": + out = append(out, SuperRef{Name: t.Content(src), Kind: kind, Line: nodeLine(t)}) + } + } + } + switch n.Type() { + case "class_declaration", "abstract_class_declaration": + for i := 0; i < int(n.ChildCount()); i++ { + h := n.Child(i) + if h == nil || h.Type() != "class_heritage" { + continue + } + sawClause := false + for j := 0; j < int(h.NamedChildCount()); j++ { + c := h.NamedChild(j) + switch c.Type() { + case "extends_clause": + sawClause = true + collect(c, graph.EdgeExtends) + case "implements_clause": + sawClause = true + collect(c, graph.EdgeImplements) + } + } + if !sawClause { + // JavaScript grammar: class_heritage is `extends ` + // with the expression as a direct child. + collect(h, graph.EdgeExtends) + } + } + case "interface_declaration": + for i := 0; i < int(n.ChildCount()); i++ { + c := n.Child(i) + if c != nil && (c.Type() == "extends_type_clause" || c.Type() == "extends_clause") { + collect(c, graph.EdgeExtends) + } + } + } + return out +} + +func tsFields(n *sitter.Node, src []byte) []Binding { + body := n.ChildByFieldName("body") + if body == nil { + return nil + } + var out []Binding + for i := 0; i < int(body.NamedChildCount()); i++ { + c := body.NamedChild(i) + if c.Type() != "public_field_definition" && c.Type() != "field_definition" { + continue + } + name := fieldText(c, "name", src) + if name == "" { + continue + } + typ := typeAnnotationText(c.ChildByFieldName("type"), src) + if typ == "" { + // `count = new Counter()` — infer from the initializer. + if v := c.ChildByFieldName("value"); v != nil && v.Type() == "new_expression" { + if ctor := v.ChildByFieldName("constructor"); ctor != nil && ctor.Type() == "identifier" { + typ = ctor.Content(src) + } + } + } + out = append(out, Binding{Name: name, Type: typ, Line: nodeLine(c)}) + } + return out +} + +func tsParams(fn *sitter.Node, src []byte) []Binding { + params := fn.ChildByFieldName("parameters") + if params == nil { + return nil + } + var out []Binding + for i := 0; i < int(params.NamedChildCount()); i++ { + p := params.NamedChild(i) + switch p.Type() { + case "required_parameter", "optional_parameter": + pattern := p.ChildByFieldName("pattern") + if pattern == nil || pattern.Type() != "identifier" { + continue + } + out = append(out, Binding{ + Name: pattern.Content(src), + Type: typeAnnotationText(p.ChildByFieldName("type"), src), + Line: nodeLine(p), + }) + case "identifier": + // JavaScript grammar: parameters are bare identifiers. + out = append(out, Binding{Name: p.Content(src), Line: nodeLine(p)}) + } + } + return out +} + +func tsLocalBinding(n *sitter.Node, src []byte) (LocalBind, bool) { + switch n.Type() { + case "variable_declarator": + name := n.ChildByFieldName("name") + if name == nil || name.Type() != "identifier" { + return LocalBind{}, false + } + init := n.ChildByFieldName("value") + // An arrow function initializer is a callable, not a typed + // local; FuncDeclTypes handles its scope. + if init != nil && (init.Type() == "arrow_function" || init.Type() == "function_expression") { + return LocalBind{}, false + } + return LocalBind{ + Name: name.Content(src), + DeclType: typeAnnotationText(n.ChildByFieldName("type"), src), + Init: init, + }, true + case "assignment_expression": + left := n.ChildByFieldName("left") + if left == nil || left.Type() != "identifier" { + return LocalBind{}, false + } + return LocalBind{Name: left.Content(src), Init: n.ChildByFieldName("right")}, true + } + return LocalBind{}, false +} + +func tsCall(n *sitter.Node, src []byte) (*sitter.Node, string, bool) { + if n.Type() != "call_expression" { + return nil, "", false + } + fn := n.ChildByFieldName("function") + if fn == nil || fn.Type() != "member_expression" { + return nil, "", false + } + obj := fn.ChildByFieldName("object") + if obj == nil { + return nil, "", false + } + prop := fn.ChildByFieldName("property") + if prop == nil || prop.Type() != "property_identifier" { + return nil, "", false + } + return obj, prop.Content(src), true +} + +func tsImports(root *sitter.Node, src []byte) []Import { + var out []Import + for i := 0; i < int(root.NamedChildCount()); i++ { + stmt := root.NamedChild(i) + if stmt.Type() != "import_statement" { + continue + } + source := fieldText(stmt, "source", src) + source = strings.Trim(source, "\"'`") + if source == "" { + continue + } + clause := firstChildOfType(stmt, "import_clause") + if clause == nil { + continue + } + for j := 0; j < int(clause.NamedChildCount()); j++ { + c := clause.NamedChild(j) + switch c.Type() { + case "identifier": // default import + out = append(out, Import{Local: c.Content(src), Path: source}) + case "named_imports": + for k := 0; k < int(c.NamedChildCount()); k++ { + spec := c.NamedChild(k) + if spec.Type() != "import_specifier" { + continue + } + local := fieldText(spec, "alias", src) + if local == "" { + local = fieldText(spec, "name", src) + } + if local != "" { + out = append(out, Import{Local: local, Path: source}) + } + } + } + } + } + return out +} + +// typeAnnotationText unwraps a `: T` type_annotation node to T's text. +func typeAnnotationText(annot *sitter.Node, src []byte) string { + if annot == nil { + return "" + } + if annot.Type() == "type_annotation" { + if annot.NamedChildCount() == 0 { + return "" + } + return annot.NamedChild(0).Content(src) + } + return annot.Content(src) +} diff --git a/internal/semantic/tstypes/typescript_test.go b/internal/semantic/tstypes/typescript_test.go new file mode 100644 index 00000000..9b3405b7 --- /dev/null +++ b/internal/semantic/tstypes/typescript_test.go @@ -0,0 +1,216 @@ +package tstypes + +import ( + "testing" + + "go.uber.org/zap" + + "github.com/zzet/gortex/internal/graph" +) + +const tsSvc = `export class Svc { + run(): void { + } + + stop(): void { + } +} +` + +func TestTypeScript_DeclaredTypeResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.ts": tsSvc, + "src/app.ts": `import { Svc } from "./svc"; + +export class App { + handle(s: Svc): void { + s.run(); + } +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "handle", graph.KindMethod) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + e := callEdgeTo(g, caller.ID, target.ID) + if e == nil { + t.Fatalf("annotated-param call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } + assertASTProvenance(t, e, "typescript-types") +} + +func TestTypeScript_ConstructorInferenceResolvesCall(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.ts": tsSvc, + "src/app.ts": `import { Svc } from "./svc"; + +export function main(): void { + const s = new Svc(); + s.run(); +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, caller.ID, target.ID) == nil { + t.Fatalf("constructor-inferred call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } +} + +// Two same-named classes: the relative-import hint must pick the right +// one. +func TestTypeScript_ImportHintDisambiguates(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.ts": tsSvc, + "other/svc.ts": tsSvc, + "src/app.ts": `import { Svc } from "./svc"; + +export function main(): void { + const s = new Svc(); + s.run(); +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + want := "src/svc.ts::Svc.run" + if callEdgeTo(g, caller.ID, want) == nil { + t.Fatalf("import-hinted call did not land on %s; edges: %v", want, g.GetOutEdges(caller.ID)) + } + if callEdgeTo(g, caller.ID, "other/svc.ts::Svc.run") != nil { + t.Fatal("call landed on the wrong module's class") + } +} + +func TestTypeScript_ImplementsAndExtendsSynthesis(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.ts": tsSvc, + "src/iface.ts": `export interface Greeter { + greet(): void; +} +`, + "src/impl.ts": `import { Greeter } from "./iface"; +import { Svc } from "./svc"; + +export class Impl extends Svc implements Greeter { + greet(): void { + } +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + impl := nodeByNameKind(t, g, "Impl", graph.KindType) + iface := nodeByNameKind(t, g, "Greeter", graph.KindInterface) + svc := nodeByNameKind(t, g, "Svc", graph.KindType) + if e := edgeBetween(g, impl.ID, graph.EdgeImplements, iface.ID); e == nil { + t.Fatalf("implements edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } else { + assertASTProvenance(t, e, "typescript-types") + } + if e := edgeBetween(g, impl.ID, graph.EdgeExtends, svc.ID); e == nil { + t.Fatalf("extends edge missing; edges: %v", g.GetOutEdges(impl.ID)) + } else { + assertASTProvenance(t, e, "typescript-types") + } +} + +// this-qualified calls and typed class fields resolve inside a class. +func TestTypeScript_SelfAndFieldReceivers(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.ts": tsSvc, + "src/app.ts": `import { Svc } from "./svc"; + +export class App { + private worker: Svc = new Svc(); + + direct(): void { + this.helper(); + } + + helper(): void { + this.worker.run(); + } +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + direct := nodeByNameKind(t, g, "direct", graph.KindMethod) + helper := nodeByNameKind(t, g, "helper", graph.KindMethod) + if callEdgeTo(g, direct.ID, helper.ID) == nil { + t.Fatalf("this.helper() not resolved; edges: %v", g.GetOutEdges(direct.ID)) + } + run := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, helper.ID, run.ID) == nil { + t.Fatalf("this.worker.run() not resolved; edges: %v", g.GetOutEdges(helper.ID)) + } +} + +// JavaScript files (no annotations) still ground constructor-inferred +// receivers through the JS grammar. +func TestTypeScript_JavaScriptConstructorInference(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.js": `export class Svc { + run() { + } +} +`, + "src/app.js": `import { Svc } from "./svc"; + +export function main() { + const s = new Svc(); + s.run(); +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + target := nodeByNameKind(t, g, "run", graph.KindMethod) + if callEdgeTo(g, caller.ID, target.ID) == nil { + t.Fatalf("JS constructor-inferred call not resolved; edges: %v", g.GetOutEdges(caller.ID)) + } +} + +func TestTypeScript_AmbiguousReceiverStaysUntouched(t *testing.T) { + g, dir := buildFixture(t, map[string]string{ + "src/svc.ts": tsSvc, + "src/alt.ts": `export class Alt { + run(): void { + } +} +`, + "src/app.ts": `import { Alt } from "./alt"; +import { Svc } from "./svc"; + +export function main(): void { + let s = new Svc(); + s = new Alt(); + s.run(); +} +`, + }) + p := NewProvider(TypeScriptSpec(), zap.NewNop()) + if _, err := p.Enrich(g, dir); err != nil { + t.Fatal(err) + } + caller := nodeByNameKind(t, g, "main", graph.KindFunction) + assertUntouched(t, g, caller.ID, "run", "typescript-types") +} diff --git a/internal/serverstack/shared_server.go b/internal/serverstack/shared_server.go index ac586378..6675ea85 100644 --- a/internal/serverstack/shared_server.go +++ b/internal/serverstack/shared_server.go @@ -25,6 +25,7 @@ import ( "github.com/zzet/gortex/internal/semantic/goanalysis" "github.com/zzet/gortex/internal/semantic/lsp" "github.com/zzet/gortex/internal/semantic/scip" + "github.com/zzet/gortex/internal/semantic/tstypes" ) // Lifecycle selects the backend default, whether warm-restart/snapshot @@ -284,6 +285,15 @@ func NewSharedServer(cfg SharedServerConfig) (*SharedServer, error) { semMgr.RegisterProvider(goProvider) contracts.SetBindingResolver(goProvider) + // In-process tree-sitter type resolvers — always-available + // (no external binary), supplemental (they coexist with LSP / + // SCIP providers instead of competing for the language slot). + // Disable one via a `semantic.providers` entry with + // `enabled: false` under its name (e.g. "java-types"). + for _, tp := range tstypes.DefaultProviders(logger) { + semMgr.RegisterProvider(tp) + } + lspWorkspace := cfg.Index if lspWorkspace == "" { lspWorkspace, _ = os.Getwd()