diff --git a/EFFECT_INFERENCE_PLAN.md b/EFFECT_INFERENCE_PLAN.md new file mode 100644 index 00000000..163e8861 --- /dev/null +++ b/EFFECT_INFERENCE_PLAN.md @@ -0,0 +1,136 @@ +# Effect Inference & Enforcement Plan + +## Goal + +Infer ownership effects (Borrow/Store/Send/Freeze) for user-defined functions and enforce them at call sites. This is the foundation for arena-based allocation and actor-safe value transfer. + +## Current State + +- Effect vocabulary: COMPLETE (Borrow, Store, Send, Freeze, BorrowAll, PassThrough, FlowInto) +- Builtin annotations: COMPLETE (rawset→Store, table.insert→Store, print→BorrowAll, etc.) +- Propagation engine: COMPLETE (fixpoint loop unions callee effects into caller) +- Serialization: COMPLETE (all labels round-trip through module manifests) +- Inference for user code: MISSING +- Enforcement: MISSING +- Diagnostics: MISSING + +## Escape Sites (Where inference must happen) + +### Table store: `t[k] = v` / `t.field = v` +- **File**: `compiler/check/flowbuild/assign/emit.go` +- **Location**: `TargetField` at line 515, `TargetIndex` at line 562 +- **Effect**: If `v` is a parameter → `Store{Param: paramIdx(v), Into: paramIdx(t)}` +- **Note**: Only when target base symbol is a parameter or upvalue + +### Upvalue capture +- **File**: `compiler/check/infer/captured/captured.go`, `FromParentFacts` line 11 +- **Binding**: `compiler/bind/table.go`, `CapturedSymbols` line 428 +- **Effect**: If captured symbol is a parameter AND the closure escapes → `Store{Param: paramIdx}` +- **Note**: Connect to `CaptureInfo.Escapes` (types/typ/capture.go line 28) + +### Return escape +- **File**: `compiler/cfg/graph.go`, `EachReturn` line 647 +- **Data**: `ReturnInfo.Symbols` in `compiler/cfg/types.go` line 349 +- **Effect**: If returned symbol is a parameter → parameter escapes via return. Model as PassThrough (already exists) or Store{Into: -1} + +### Yield escape +- **File**: `compiler/stdlib/coroutine.go` line 20 +- **Fix**: Annotate `coroutine.yield` with `Send{FromParam: 0}` +- **Note**: No new label needed, Send covers this + +### Global store +- **File**: `compiler/check/flowbuild/assign/emit.go`, `TargetIdent` at line 280 +- **Condition**: `bindings.Kind(sym) == cfg.SymbolGlobal` +- **Effect**: If RHS is a parameter → `Store{Param: paramIdx, Into: -1}` + +## No New Labels Needed + +| Pattern | Label | Rationale | +|---------|-------|-----------| +| Table store | Store{Param, Into} | Value stored into structure | +| Closure capture + escape | Store{Param, Into: -1} | Capture is opaque store | +| Return | PassThrough / FlowInto | Already exist | +| Yield | Send{FromParam} | Cross-boundary transfer | +| Global store | Store{Param, Into: -1} | Long-lived destination | +| Freeze | Freeze{Param} | Already exists | + +## Implementation + +### Step 1: Annotate stdlib gaps +- File: `compiler/stdlib/coroutine.go` +- Add Send to yield, BorrowAll to pure coroutine functions +- Review channel/process stubs for missing Send annotations + +### Step 2: Create `InferOwnershipEffects` +- New file: `compiler/check/effects/ownership.go` +- Signature: `func InferOwnershipEffects(graph *cfg.Graph, bindings *bind.Table, params []ParamInfo) effect.Row` +- Walks: EachAssign (field/index targets), EachReturn, captured symbols +- Produces: Store/Borrow labels relative to function parameters +- For parameters not observed to escape → Borrow{Param: i} +- For parameters that escape → Store{Param: i, Into: ...} + +### Step 3: Integrate into effects.Propagate +- File: `compiler/check/effects/propagate.go` line 20 +- After computing fnEffect.Row, call InferOwnershipEffects and union result +- Ownership labels enter the fixpoint alongside Throw/IO/Diverge + +### Step 4: Translate callee ownership at call sites +- The critical piece: callee Store labels reference callee params, not caller params +- New function: `TranslateCalleeOwnership(calleeRow, callInfo, callerParams) effect.Row` +- For each callee Store{Param: i}, find which caller expression maps to callee param i +- If that expression is a caller param → Store{Param: callerParamIdx} on caller +- If expression is not a param (literal, local) → no ownership effect on caller +- Integrate into the EachCallSite callback in Propagate + +### Step 5: Enforcement hooks +- File: `compiler/check/hooks/call_check.go`, `checkSingleCall` line 64 +- Check: callee has Store + arg is frozen → diagnostic +- Check: callee has Send + arg is mutable → diagnostic or auto-freeze +- Check: callee has Freeze → mark arg as frozen in subsequent flow + +### Step 6: Testing +- Unit tests: `compiler/check/effects/ownership_test.go` +- Fixtures: `testdata/fixtures/ownership/{table-store,closure-capture,return-escape,global-store,frozen-write}/` +- Cross-module: multi-file fixtures verifying exported function effects +- Effect assertions: extend fixture harness with `-- expect-effect: store(param[0])` or check via `sess.Store.InterprocPrev.Refinements` + +## Performance + +- InferOwnershipEffects runs once per function per fixpoint iteration (cheap) +- Body scan is O(assignments + returns + closures) — already iterated +- Effect rows use deduplication — adding same label twice is no-op +- Union is O(n*m) where n,m are label counts — typically 1-5 per function +- No new flow domain needed — ownership is function-level summary, not per-point lattice + +## CaptureInfo relationship + +- `CaptureInfo.Escapes` (types/typ/capture.go) → runtime codegen concern (heap vs stack upvalue) +- Effect `Store` label → type-level ownership concern (lifetime extends beyond call) +- Complementary, not redundant +- When inferring ownership: consult CapturedSymbols to identify captures, use CaptureInfo.Escapes to determine if closure escapes + +## Dependency Order + +``` +Step 1 (stdlib annotations) → independent +Step 2 (InferOwnershipEffects) → needs graph API understanding +Step 3 (integrate Propagate) → depends on Step 2 +Step 4 (callee translation) → depends on Step 3, most complex +Step 5 (enforcement) → depends on Steps 3-4 +Step 6 (testing) → parallel with each step +``` + +## Critical Files + +| Purpose | File | +|---------|------| +| Inference integration | `compiler/check/effects/propagate.go` | +| Escape site detection | `compiler/check/flowbuild/assign/emit.go` | +| Enforcement | `compiler/check/hooks/call_check.go` | +| Label definitions | `types/effect/label.go` | +| Fixpoint loop | `compiler/check/pipeline/driver.go` | +| Capture binding | `compiler/bind/table.go` | +| Capture types | `compiler/check/infer/captured/captured.go` | +| Return info | `compiler/cfg/graph.go` (EachReturn) | +| Effect export | `compiler/check/effects/export.go` | +| Stdlib annotations | `compiler/stdlib/coroutine.go` | diff --git a/PERF_FIX_PLAN.md b/PERF_FIX_PLAN.md new file mode 100644 index 00000000..fe594b81 --- /dev/null +++ b/PERF_FIX_PLAN.md @@ -0,0 +1,139 @@ +# Lint Deadlock Fix: Route Through Salsa Query Layer + +## Problem + +`wippy lint` deadlocks on large projects (300-900 files). Root cause: expensive type operations called thousands of times without caching, bypassing the existing Salsa-style query system (`types/db/query.go`). + +### Hot paths (from profiling keeper, docker-demo, be-common-components) + +1. **`PruneSoftUnionMembers`** — 49% CPU. Called from `isSubtype` on every subtype check. Walks entire type tree looking for soft union members. Per-call memo discarded after each invocation. + +2. **`ExpandInstantiated`** — was 42% CPU before per-call memo fix. 16 direct calls bypass `Engine.ExpandInstantiated` query. + +3. **`applyInferSubst`** — was 75% CPU in constraint solver. No Salsa analog; internal memo added. + +4. **`walkType` (`occursIn`, `containsTypeVar`)** — was 68% CPU. No Salsa analog; seen-set added. + +5. **`typeContainsNever`** — caused 148GB allocations in docker-demo. No Salsa analog; seen-set added. + +## Architecture + +The Salsa query layer already exists: + +``` +types/query/core/Engine + .IsSubtype(ctx, sub, super) -> db.Query "IsSubtype" + .ExpandInstantiated(ctx, t) -> db.Query "ExpandInstantiated" + .Widen(ctx, t) -> db.Query "Widen" + .Field/Method/Index(ctx, ...) -> db.Query with memoization +``` + +Interface: `types/query/core/TypeOps` — accepted by synthesizer, hooks, etc. + +Problem: 80+ call sites use raw `subtype.IsSubtype()` instead of `Engine.IsSubtype()` because they lack `TypeOps` or `*db.QueryContext`. + +## Plan + +### Phase 1: Thread TypeOps through hot subsystems + +The 80 direct `subtype.IsSubtype` calls cluster in these files: + +| File | Count | Has TypeOps? | +|------|-------|-------------| +| returns/join.go | 10 | No — needs threading from pipeline | +| synth/ops/check.go | 6 | Yes — has `synth.Engine` with TypeOps | +| returns/widen.go | 5 | No — needs threading from pipeline | +| flow/query.go | 4 | Has QueryContext | +| constraint/solver.go | 4 | No context — leaf-level | +| flowbuild/assign/error_return_policy.go | 4 | Has scope with TypeOps | +| narrow/narrow.go | 3 | No context — leaf-level | +| flow/transfer.go | 3 | Has QueryContext | +| constraint/infer.go | 3 | No context — leaf-level | +| hooks/assign_check.go | 3 | Yes — hooks receive TypeOps | +| Others (2 each) | ~20 | Mixed | + +Strategy per group: + +**A. Already have TypeOps (synth/ops, hooks, flowbuild):** Change `subtype.IsSubtype(a, b)` to `ops.IsSubtype(ctx, a, b)`. Mechanical replacement. + +**B. Pipeline subsystems (returns/join.go, returns/widen.go):** Thread `TypeOps` from `pipeline.Runner` into return inference. The `Runner` already has `types core.TypeOps`. Pass it down to return join/widen functions. + +**C. Flow system (flow/query.go, flow/transfer.go):** Already have `*db.QueryContext`. Need access to `Engine` or expose `IsSubtype` as a query they can call. + +**D. Leaf-level (constraint/solver.go, narrow/narrow.go, constraint/infer.go):** These are deep in the type system with no infrastructure access. Options: + - Accept: these calls are on small types (constraint variables, narrowed results) and less hot + - Or: pass a subtype checker function as a parameter + +### Phase 2: Make PruneSoftUnionMembers a derived query + +Current: `isSubtype()` calls `PruneSoftUnionMembers(sub)` and `PruneSoftUnionMembers(super)` every time. + +Option A — **Prune-on-construction**: When creating Union types (`typ.NewUnion`), prune soft members immediately. Then `PruneSoftUnionMembers` becomes a no-op. Risk: changes Union construction semantics, may break inference that needs soft members temporarily. + +Option B — **Lazy flag**: Add a `hasSoft` bit to Union type. `PruneSoftUnionMembers` checks the flag and returns immediately if false. Set the flag during `NewUnion` construction by checking members. Cost: one bool per Union, one check per member at construction. + +Option C — **Query wrapper**: Add `PruneSoft` to `Engine` as a memoized query. Callers going through `Engine.IsSubtype` already benefit. Direct `subtype.IsSubtype` calls still pay the cost. + +**Recommendation: Option B (lazy flag)**. It's the most canonical — the type knows about itself, no external caching needed. + +### Phase 3: Remove ad-hoc memos that duplicate Salsa + +After Phase 1-2, review and remove: + +- `subst.go` expand memo pool — **keep**: prevents exponential blowup *within* a single expansion. Salsa caches the final result but not intermediate recursion. This is algorithmic, not caching. +- `constraint/infer.go` `applyInferSubst` memo — **keep**: internal to constraint solver, no Salsa analog. Converts O(n^2) cycle detection to O(n) with result memo. +- `constraint/infer.go` `walkTypeMemo` seen set — **keep**: prevents re-walking same pointer in `occursIn`/`containsTypeVar`. No Salsa analog. +- `returns/join.go` `typeContainsNeverMemo` seen set — **keep**: prevents unbounded recursion. No Salsa analog. + +These are not "non-canonical memoization" — they're algorithmic fixes within compute functions. Salsa memoizes query results; these prevent pathological behavior during a single query computation. + +### Phase 4: Audit new code for direct bypasses + +Add a lint/review rule: new code should use `TypeOps.IsSubtype` not `subtype.IsSubtype` when `TypeOps` is available. The raw function is for the engine's compute function and leaf-level code without infrastructure. + +## Files to modify + +### Phase 1A — Mechanical (TypeOps already available) +- `compiler/check/synth/ops/check.go` (6 calls) +- `compiler/check/hooks/assign_check.go` (3 calls) +- `compiler/check/hooks/return_check.go` (2 calls) +- `compiler/check/hooks/field_check.go` (2 calls) +- `compiler/check/synth/ops/call.go` (2 calls) +- `compiler/check/synth/ops/typecheck.go` (2 calls) +- `compiler/check/synth/ops/generic.go` (1 call) +- `compiler/check/synth/phase/extract/*.go` (3 calls) +- `compiler/check/synth/phase/resolve/resolver.go` (2 calls) +- `compiler/check/flowbuild/assign/error_return_policy.go` (4 calls) +- `compiler/check/flowbuild/assign/precision.go` (2 calls) + +### Phase 1B — Thread TypeOps into return inference +- `compiler/check/returns/join.go` (10 calls) +- `compiler/check/returns/widen.go` (5 calls) +- `compiler/check/pipeline/runner.go` (pass TypeOps to return inference) + +### Phase 1C — Flow system +- `types/flow/query.go` (4 calls) +- `types/flow/transfer.go` (3 calls) +- `types/flow/type_facts.go` (1 call) + +### Phase 1D — Leaf-level (defer or pass function) +- `types/constraint/solver.go` (4 calls) +- `types/constraint/infer.go` (3 calls) +- `types/narrow/narrow.go` (3 calls) +- `types/query/core/field.go` (2 calls) +- `types/query/core/index.go` (2 calls) +- `types/query/core/instantiate.go` (1 call) +- `ltype.go` (1 call) + +### Phase 2 — Union soft flag +- `types/typ/union.go` — add `hasSoft bool` field, set in constructor +- `types/typ/soft.go` — check flag in `PruneSoftUnionMembers` +- `types/typ/rebuild.go` — ensure flag propagation in NewUnion + +## Execution order + +1. Phase 2 first (Union soft flag) — biggest impact, self-contained +2. Phase 1A — mechanical, low risk +3. Phase 1B — moderate refactor, high impact (returns/join is hottest) +4. Phase 1C — flow system routing +5. Phase 1D — evaluate if needed after 1-3 diff --git a/compiler/bind/table.go b/compiler/bind/table.go index 497b4b47..cae7b488 100644 --- a/compiler/bind/table.go +++ b/compiler/bind/table.go @@ -57,6 +57,9 @@ type BindingTable struct { // names stores the original source name for each symbol names map[cfg.SymbolID]string + // symbolsByName lazily indexes symbols by source name in ascending symbol order. + symbolsByName map[string][]cfg.SymbolID + // paramSymbols maps functions to their parameter symbol list paramSymbols map[*ast.FunctionExpr][]cfg.SymbolID @@ -167,7 +170,16 @@ func (t *BindingTable) Kind(sym cfg.SymbolID) (cfg.SymbolKind, bool) { // SetName records the source name associated with a symbol. func (t *BindingTable) SetName(sym cfg.SymbolID, name string) { + if t == nil || sym == 0 { + return + } + if prev, ok := t.names[sym]; ok { + if prev == name { + return + } + } t.names[sym] = name + t.symbolsByName = nil } // Name returns the source name of a symbol, or empty if unknown. @@ -179,21 +191,48 @@ func (t *BindingTable) Name(sym cfg.SymbolID) string { // // Results are sorted by symbol ID for deterministic iteration. func (t *BindingTable) SymbolsByName(name string) []cfg.SymbolID { + syms := t.SymbolsByNameReadOnly(name) + if len(syms) == 0 { + return nil + } + out := make([]cfg.SymbolID, len(syms)) + copy(out, syms) + return out +} + +// SymbolsByNameReadOnly returns the stored symbols for a source name. +// +// The returned slice is sorted by symbol ID and must be treated as read-only. +func (t *BindingTable) SymbolsByNameReadOnly(name string) []cfg.SymbolID { if t == nil || name == "" { return nil } - result := make([]cfg.SymbolID, 0, 2) - for sym, n := range t.names { - if n == name && sym != 0 { - result = append(result, sym) + if t.symbolsByName == nil { + t.symbolsByName = t.buildSymbolsByNameIndex() + } + return t.symbolsByName[name] +} + +func (t *BindingTable) buildSymbolsByNameIndex() map[string][]cfg.SymbolID { + if t == nil || len(t.names) == 0 { + return nil + } + index := make(map[string][]cfg.SymbolID) + for sym, name := range t.names { + if sym == 0 || name == "" { + continue } + index[name] = append(index[name], sym) } - if len(result) > 1 { - sort.Slice(result, func(i, j int) bool { - return result[i] < result[j] - }) + for name, syms := range index { + if len(syms) > 1 { + sort.Slice(syms, func(i, j int) bool { + return syms[i] < syms[j] + }) + } + index[name] = syms } - return result + return index } // SetParamSymbols records the ordered parameter symbols for a function. diff --git a/compiler/bind/table_test.go b/compiler/bind/table_test.go index 9664ba77..196d433e 100644 --- a/compiler/bind/table_test.go +++ b/compiler/bind/table_test.go @@ -315,6 +315,37 @@ func TestBindingTable_SymbolsByName_UnknownOrEmpty(t *testing.T) { } } +func TestBindingTable_SymbolsByNameReadOnly_TracksUpdatesAndCopyIsolation(t *testing.T) { + table := NewBindingTable() + alpha := cfg.NextSymbolID() + beta := cfg.NextSymbolID() + + table.SetName(alpha, "collect") + table.SetName(beta, "collect") + + got := table.SymbolsByNameReadOnly("collect") + if len(got) != 2 || got[0] != alpha || got[1] != beta { + t.Fatalf("SymbolsByNameReadOnly(\"collect\") = %v, want [%d %d]", got, alpha, beta) + } + + copied := table.SymbolsByName("collect") + copied[0] = beta + got = table.SymbolsByNameReadOnly("collect") + if len(got) != 2 || got[0] != alpha || got[1] != beta { + t.Fatalf("SymbolsByName copy should not mutate stored index, got %v", got) + } + + table.SetName(beta, "other") + got = table.SymbolsByNameReadOnly("collect") + if len(got) != 1 || got[0] != alpha { + t.Fatalf("SymbolsByNameReadOnly(\"collect\") after rename = %v, want [%d]", got, alpha) + } + other := table.SymbolsByNameReadOnly("other") + if len(other) != 1 || other[0] != beta { + t.Fatalf("SymbolsByNameReadOnly(\"other\") after rename = %v, want [%d]", other, beta) + } +} + func TestBindingTable_Globals(t *testing.T) { table := NewBindingTable() diff --git a/compiler/cfg/graph.go b/compiler/cfg/graph.go index 5adda018..c00a10a0 100644 --- a/compiler/cfg/graph.go +++ b/compiler/cfg/graph.go @@ -548,7 +548,7 @@ func (g *Graph) EachAssign(fn func(Point, *AssignInfo)) { return } points := g.orderedAssignPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*AssignInfo) return ok @@ -565,7 +565,7 @@ func (g *Graph) AssignPoints() []Point { return nil } points := g.orderedAssignPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*AssignInfo) return ok @@ -587,7 +587,7 @@ func (g *Graph) EachStmtCall(fn func(Point, *CallInfo)) { return } points := g.orderedStmtCallPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*CallInfo) return ok @@ -619,7 +619,7 @@ func (g *Graph) EachCallSite(fn func(Point, *CallInfo)) { return } points := g.orderedPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(nil) } for _, p := range points { @@ -650,7 +650,7 @@ func (g *Graph) EachReturn(fn func(Point, *ReturnInfo)) { return } points := g.orderedReturnPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*ReturnInfo) return ok @@ -667,7 +667,7 @@ func (g *Graph) EachBranch(fn func(Point, *BranchInfo)) { return } points := g.orderedBranchPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*BranchInfo) return ok @@ -684,7 +684,7 @@ func (g *Graph) EachFuncDef(fn func(Point, *FuncDefInfo)) { return } points := g.orderedFuncDefPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*FuncDefInfo) return ok @@ -701,7 +701,7 @@ func (g *Graph) EachTypeDef(fn func(Point, *TypeDefInfo)) { return } points := g.orderedTypeDefPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(func(info NodeInfo) bool { _, ok := info.(*TypeDefInfo) return ok @@ -718,7 +718,7 @@ func (g *Graph) EachNode(fn func(Point, NodeInfo)) { return } points := g.orderedPoints - if len(points) == 0 && len(g.infoByPoint) > 0 { + if points == nil && len(g.infoByPoint) > 0 { points = g.sortedPoints(nil) } for _, p := range points { @@ -966,7 +966,7 @@ func (g *Graph) PopulateSymbols(resolve SymbolResolver) { } points := g.orderedPoints - if len(points) == 0 { + if points == nil { points = g.sortedPoints(nil) } diff --git a/compiler/check/api/effects.go b/compiler/check/api/effects.go index cd64ae8b..2e234590 100644 --- a/compiler/check/api/effects.go +++ b/compiler/check/api/effects.go @@ -5,37 +5,37 @@ import ( "github.com/wippyai/go-lua/types/constraint" ) -// EffectFacts provides function effect lookup. -type EffectFacts interface { - LookupBySym(sym cfg.SymbolID) *constraint.FunctionEffect +// RefinementFacts provides function refinement lookup. +type RefinementFacts interface { + LookupBySym(sym cfg.SymbolID) *constraint.FunctionRefinement } -// EffectStore provides methods for storing and retrieving function effects. -type EffectStore interface { - LookupEffectBySym(sym cfg.SymbolID) *constraint.FunctionEffect +// RefinementStore provides methods for storing and retrieving function refinements. +type RefinementStore interface { + LookupRefinementBySym(sym cfg.SymbolID) *constraint.FunctionRefinement } -// storeEffectFacts implements EffectFacts backed by an EffectStore. -type storeEffectFacts struct { - store EffectStore +// storeRefinementFacts implements RefinementFacts backed by a RefinementStore. +type storeRefinementFacts struct { + store RefinementStore } -// NewEffectFacts creates an EffectFacts backed by an EffectStore. -func NewEffectFacts(store EffectStore) EffectFacts { +// NewRefinementFacts creates RefinementFacts backed by a RefinementStore. +func NewRefinementFacts(store RefinementStore) RefinementFacts { if store == nil { - return nilEffectFacts{} + return nilRefinementFacts{} } - return &storeEffectFacts{store: store} + return &storeRefinementFacts{store: store} } -func (f *storeEffectFacts) LookupBySym(sym cfg.SymbolID) *constraint.FunctionEffect { +func (f *storeRefinementFacts) LookupBySym(sym cfg.SymbolID) *constraint.FunctionRefinement { if f.store == nil || sym == 0 { return nil } - return f.store.LookupEffectBySym(sym) + return f.store.LookupRefinementBySym(sym) } -// nilEffectFacts is a no-op EffectFacts implementation. -type nilEffectFacts struct{} +// nilRefinementFacts is a no-op RefinementFacts implementation. +type nilRefinementFacts struct{} -func (nilEffectFacts) LookupBySym(cfg.SymbolID) *constraint.FunctionEffect { return nil } +func (nilRefinementFacts) LookupBySym(cfg.SymbolID) *constraint.FunctionRefinement { return nil } diff --git a/compiler/check/api/env.go b/compiler/check/api/env.go index a11a1d94..4ff2c471 100644 --- a/compiler/check/api/env.go +++ b/compiler/check/api/env.go @@ -43,7 +43,7 @@ type BaseEnv interface { Graph() cfg.VersionedGraph Types() flow.TypeFacts Consts() *flow.Solution - Effects() EffectFacts + Refinements() RefinementFacts TypeNames() *scope.State Bindings() *bind.BindingTable ModuleAliases() map[cfg.SymbolID]string @@ -71,7 +71,7 @@ type envBase struct { bindings *bind.BindingTable types flow.TypeFacts solution *flow.Solution - effects EffectFacts + refinements RefinementFacts typeNames *scope.State moduleAliases map[cfg.SymbolID]string globalTypes map[string]typ.Type @@ -136,12 +136,12 @@ func (c *envCommon) Consts() *flow.Solution { return c.base.solution } -// Effects returns the effect facts provider. -func (c *envCommon) Effects() EffectFacts { +// Refinements returns the refinement facts provider. +func (c *envCommon) Refinements() RefinementFacts { if c == nil || c.base == nil { return nil } - return c.base.effects + return c.base.refinements } // TypeNames returns the scope state for type name resolution. @@ -285,20 +285,20 @@ func (e *NarrowEnvImpl) Consts() *flow.Solution { return e.envCommon.Consts() } -// Effects returns the effect facts provider. -func (e *DeclaredEnvImpl) Effects() EffectFacts { +// Refinements returns the refinement facts provider. +func (e *DeclaredEnvImpl) Refinements() RefinementFacts { if e == nil { return nil } - return e.envCommon.Effects() + return e.envCommon.Refinements() } -// Effects returns the effect facts provider. -func (e *NarrowEnvImpl) Effects() EffectFacts { +// Refinements returns the refinement facts provider. +func (e *NarrowEnvImpl) Refinements() RefinementFacts { if e == nil { return nil } - return e.envCommon.Effects() + return e.envCommon.Refinements() } // TypeNames returns the scope state for type name resolution. @@ -446,7 +446,7 @@ type DeclaredEnvConfig struct { DeclaredTypes flow.DeclaredTypes AnnotatedVars map[cfg.SymbolID]bool BaseScope *scope.State - EffectStore EffectStore + RefinementStore RefinementStore ModuleAliases map[cfg.SymbolID]string GlobalTypes map[string]typ.Type SiblingTypes map[cfg.SymbolID]typ.Type @@ -462,7 +462,7 @@ type NarrowEnvConfig struct { AnnotatedVars map[cfg.SymbolID]bool Solution *flow.Solution BaseScope *scope.State - EffectStore EffectStore + RefinementStore RefinementStore ModuleAliases map[cfg.SymbolID]string GlobalTypes map[string]typ.Type SiblingTypes map[cfg.SymbolID]typ.Type @@ -476,7 +476,7 @@ func newEnvBase( bindings *bind.BindingTable, types flow.TypeFacts, solution *flow.Solution, - effects EffectFacts, + refinements RefinementFacts, typeNames *scope.State, moduleAliases map[cfg.SymbolID]string, globalTypes map[string]typ.Type, @@ -487,7 +487,7 @@ func newEnvBase( bindings: bindings, types: types, solution: solution, - effects: effects, + refinements: refinements, typeNames: typeNames, moduleAliases: moduleAliases, globalTypes: globalTypes, @@ -505,7 +505,7 @@ func NewDeclaredEnv(cfg DeclaredEnvConfig) *DeclaredEnvImpl { cfg.Bindings, newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.SiblingTypes, cfg.LiteralTypes, cfg.AnnotatedVars, nil), nil, - NewEffectFacts(cfg.EffectStore), + NewRefinementFacts(cfg.RefinementStore), cfg.BaseScope, cfg.ModuleAliases, cfg.GlobalTypes, @@ -524,7 +524,7 @@ func NewNarrowEnv(cfg NarrowEnvConfig) *NarrowEnvImpl { cfg.Bindings, newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.SiblingTypes, cfg.LiteralTypes, cfg.AnnotatedVars, cfg.Solution), cfg.Solution, - NewEffectFacts(cfg.EffectStore), + NewRefinementFacts(cfg.RefinementStore), cfg.BaseScope, cfg.ModuleAliases, cfg.GlobalTypes, @@ -554,7 +554,7 @@ func NewReturnInferenceEnv(cfg ReturnInferenceEnvConfig) *DeclaredEnvImpl { cfg.Bindings, newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, nil, nil, nil, nil), nil, - NewEffectFacts(nil), + NewRefinementFacts(nil), cfg.BaseScope, cfg.ModuleAliases, cfg.GlobalTypes, diff --git a/compiler/check/api/env_test.go b/compiler/check/api/env_test.go index 10b3ebc3..c0f0eac2 100644 --- a/compiler/check/api/env_test.go +++ b/compiler/check/api/env_test.go @@ -122,8 +122,8 @@ func TestDeclaredEnv_NilSafety(t *testing.T) { if env.Consts() != nil { t.Error("nil.Consts() should be nil") } - if env.Effects() != nil { - t.Error("nil.Effects() should be nil") + if env.Refinements() != nil { + t.Error("nil.Refinements() should be nil") } if env.TypeNames() != nil { t.Error("nil.TypeNames() should be nil") diff --git a/compiler/check/api/graphs_attach.go b/compiler/check/api/graphs_attach.go new file mode 100644 index 00000000..301a18d4 --- /dev/null +++ b/compiler/check/api/graphs_attach.go @@ -0,0 +1,20 @@ +package api + +import "github.com/wippyai/go-lua/types/db" + +// GraphProviderKey is the typed attachment key for GraphProvider. +var GraphProviderKey = db.NewAttachmentKey[GraphProvider]("check.GraphProvider") + +// AttachGraphs attaches a graph provider to the query context for lookup. +func AttachGraphs(ctx *db.QueryContext, graphs GraphProvider) { + if ctx == nil || graphs == nil { + return + } + db.Attach(ctx, GraphProviderKey, graphs) +} + +// GraphsFrom retrieves the graph provider from a db.QueryContext. +func GraphsFrom(ctx *db.QueryContext) GraphProvider { + graphs, _ := db.Attached(ctx, GraphProviderKey) + return graphs +} diff --git a/compiler/check/api/result.go b/compiler/check/api/result.go index 1eb0d06c..72fd5905 100644 --- a/compiler/check/api/result.go +++ b/compiler/check/api/result.go @@ -18,7 +18,7 @@ import ( // - Phase A (Resolve): Type annotations resolved into Scopes // - Phase B (Scope/Extract): BaseScope, Scopes, FlowInputs // - Phase C (Solve): FlowSolution -// - Phase D (Narrow): Facts, FnEffect, NarrowSynth +// - Phase D (Narrow): Facts, FnRefinement, NarrowSynth type FuncResult struct { // Graph is the function's control flow graph containing CFG nodes, // binding information, and iteration metadata. @@ -48,9 +48,9 @@ type FuncResult struct { // Provides reachability conditions and exclusion facts for narrowing. FlowSolution *flow.Solution - // FnEffect captures the function's side effects (io, error, terminate). - // Propagated from callees and used for inter-function effect analysis. - FnEffect *constraint.FunctionEffect + // FnRefinement captures the function's inferred refinement summary. + // It includes propagated effect rows and branch-specific narrowing facts. + FnRefinement *constraint.FunctionRefinement // NarrowSynth is the narrowed-phase synthesis engine for this function. // Use TypeOf to query expression types with flow-based narrowing applied. @@ -82,6 +82,14 @@ func (r *FuncResult) EffectiveTypeAt(p cfg.Point, sym cfg.SymbolID) flow.TypedVa return r.Facts.EffectiveTypeAt(p, sym) } +// NarrowedTypeAt returns the precise path-sensitive narrowed type at a CFG point. +func (r *FuncResult) NarrowedTypeAt(p cfg.Point, path constraint.Path) typ.Type { + if r == nil || r.FlowSolution == nil { + return nil + } + return r.FlowSolution.NarrowedTypeAt(p, path) +} + // ExcludesTypeAt checks if the flow solution proves a type is excluded at a CFG point. // Used for type narrowing when control flow eliminates certain type possibilities. // diff --git a/compiler/check/api/store.go b/compiler/check/api/store.go index 2a1277f8..51343247 100644 --- a/compiler/check/api/store.go +++ b/compiler/check/api/store.go @@ -141,8 +141,8 @@ type IterationStore interface { Revision() uint64 BumpRevision() - EffectStore() EffectStore - StoreFunctionEffect(sym cfg.SymbolID, eff *constraint.FunctionEffect) + RefinementStore() RefinementStore + StoreFunctionRefinement(sym cfg.SymbolID, eff *constraint.FunctionRefinement) SetModuleBindings(bindings *bind.BindingTable) SetModuleAliases(aliases map[cfg.SymbolID]string) diff --git a/compiler/check/api/synth.go b/compiler/check/api/synth.go index b7b9136e..fe724007 100644 --- a/compiler/check/api/synth.go +++ b/compiler/check/api/synth.go @@ -151,6 +151,10 @@ type FlowQuery interface { // Combines declared type with flow refinements. EffectiveTypeAt(p cfg.Point, sym cfg.SymbolID) flow.TypedValue + // NarrowedTypeAt returns the exact narrowed type for a source path at a point. + // Used when diagnostics need the solved path-sensitive type, not just symbol-level facts. + NarrowedTypeAt(p cfg.Point, path constraint.Path) typ.Type + // ExcludesTypeAt checks if flow analysis excludes a type at a point. // Used for narrowing union types when branches eliminate possibilities. ExcludesTypeAt(p cfg.Point, path constraint.Path, declared typ.Type) bool diff --git a/compiler/check/api/synth_test.go b/compiler/check/api/synth_test.go index 9eecd26c..f3d10534 100644 --- a/compiler/check/api/synth_test.go +++ b/compiler/check/api/synth_test.go @@ -73,6 +73,7 @@ type mockFlowQuery struct{} func (m *mockFlowQuery) EffectiveTypeAt(cfg.Point, cfg.SymbolID) flow.TypedValue { return flow.TypedValue{} } +func (m *mockFlowQuery) NarrowedTypeAt(cfg.Point, constraint.Path) typ.Type { return nil } func (m *mockFlowQuery) ExcludesTypeAt(cfg.Point, constraint.Path, typ.Type) bool { return false } type mockFlowOps struct{} diff --git a/compiler/check/callsite/callee_symbols.go b/compiler/check/callsite/callee_symbols.go index 1323abfa..bd152843 100644 --- a/compiler/check/callsite/callee_symbols.go +++ b/compiler/check/callsite/callee_symbols.go @@ -32,12 +32,12 @@ func CalleeSymbolCandidates(info *cfg.CallInfo, primary, fallback *bind.BindingT } if info.CalleeName != "" { if primary != nil { - for _, sym := range primary.SymbolsByName(info.CalleeName) { + for _, sym := range primary.SymbolsByNameReadOnly(info.CalleeName) { set.Add(sym) } } if fallback != nil && fallback != primary { - for _, sym := range fallback.SymbolsByName(info.CalleeName) { + for _, sym := range fallback.SymbolsByNameReadOnly(info.CalleeName) { set.Add(sym) } } diff --git a/compiler/check/callsite/effect_resolution.go b/compiler/check/callsite/effect_resolution.go index 43b38610..90fa0158 100644 --- a/compiler/check/callsite/effect_resolution.go +++ b/compiler/check/callsite/effect_resolution.go @@ -62,11 +62,11 @@ func ResolveCalleeEffect( p cfg.Point, graph *cfg.Graph, primary, fallback *bind.BindingTable, - lookup func(sym cfg.SymbolID) *constraint.FunctionEffect, + lookup func(sym cfg.SymbolID) *constraint.FunctionRefinement, synth func(expr ast.Expr, p cfg.Point) typ.Type, resolveBySym func(p cfg.Point, sym cfg.SymbolID) (typ.Type, bool), - effectFromType func(t typ.Type) *constraint.FunctionEffect, -) *constraint.FunctionEffect { + effectFromType func(t typ.Type) *constraint.FunctionRefinement, +) *constraint.FunctionRefinement { if info == nil { return nil } diff --git a/compiler/check/callsite/effect_resolution_test.go b/compiler/check/callsite/effect_resolution_test.go index ad18e354..2a66a527 100644 --- a/compiler/check/callsite/effect_resolution_test.go +++ b/compiler/check/callsite/effect_resolution_test.go @@ -15,8 +15,8 @@ func TestResolveCalleeEffect_PrefersLookup(t *testing.T) { CalleeSymbol: 7, Callee: &ast.IdentExpr{Value: "f"}, } - lookupEff := &constraint.FunctionEffect{Terminates: true} - typeEff := &constraint.FunctionEffect{OnReturn: constraint.TrueCondition()} + lookupEff := &constraint.FunctionRefinement{Terminates: true} + typeEff := &constraint.FunctionRefinement{OnReturn: constraint.TrueCondition()} got := ResolveCalleeEffect( info, @@ -24,7 +24,7 @@ func TestResolveCalleeEffect_PrefersLookup(t *testing.T) { nil, nil, nil, - func(sym cfg.SymbolID) *constraint.FunctionEffect { + func(sym cfg.SymbolID) *constraint.FunctionRefinement { if sym == 7 { return lookupEff } @@ -34,10 +34,10 @@ func TestResolveCalleeEffect_PrefersLookup(t *testing.T) { return typ.Func().WithRefinement(typeEff).Build() }, nil, - func(t typ.Type) *constraint.FunctionEffect { + func(t typ.Type) *constraint.FunctionRefinement { fn := unwrap.Function(t) if fn != nil { - if eff, ok := fn.Refinement.(*constraint.FunctionEffect); ok { + if eff, ok := fn.Refinement.(*constraint.FunctionRefinement); ok { return eff } } @@ -54,7 +54,7 @@ func TestResolveCalleeEffect_FallsBackToSymbolTypeWhenSynthNoEffect(t *testing.T CalleeSymbol: 3, Callee: &ast.IdentExpr{Value: "g"}, } - want := &constraint.FunctionEffect{Terminates: true} + want := &constraint.FunctionRefinement{Terminates: true} got := ResolveCalleeEffect( info, @@ -72,12 +72,12 @@ func TestResolveCalleeEffect_FallsBackToSymbolTypeWhenSynthNoEffect(t *testing.T } return nil, false }, - func(t typ.Type) *constraint.FunctionEffect { + func(t typ.Type) *constraint.FunctionRefinement { fn, ok := t.(*typ.Function) if !ok || fn == nil { return nil } - eff, _ := fn.Refinement.(*constraint.FunctionEffect) + eff, _ := fn.Refinement.(*constraint.FunctionRefinement) return eff }, ) diff --git a/compiler/check/callsite/function_literal.go b/compiler/check/callsite/function_literal.go index fc1603e5..f3a1fddb 100644 --- a/compiler/check/callsite/function_literal.go +++ b/compiler/check/callsite/function_literal.go @@ -61,3 +61,47 @@ func FunctionLiteralForSymbol( return found } + +// FunctionLiteralForGraphSymbol resolves only graph-local stable function +// bindings for a symbol. +// +// Canonical boundary: +// - include graph-local/global function definitions +// - include local identifier assignments of function literals +// - exclude mutable field-path symbols, whose current callable type must come +// from value flow at the call site rather than binder symbol backtracking +func FunctionLiteralForGraphSymbol(graph *cfg.Graph, sym cfg.SymbolID) *ast.FunctionExpr { + if sym == 0 || graph == nil { + return nil + } + + var found *ast.FunctionExpr + graph.EachFuncDef(func(_ cfg.Point, info *cfg.FuncDefInfo) { + if found != nil || info == nil || info.Symbol != sym { + return + } + found = info.FuncExpr + }) + if found != nil { + return found + } + + graph.EachAssign(func(_ cfg.Point, info *cfg.AssignInfo) { + if found != nil || info == nil { + return + } + info.EachTargetSource(func(_ int, target cfg.AssignTarget, source ast.Expr) { + if found != nil { + return + } + if target.Kind != cfg.TargetIdent || target.Symbol != sym { + return + } + if fn, ok := source.(*ast.FunctionExpr); ok { + found = fn + } + }) + }) + + return found +} diff --git a/compiler/check/callsite/function_literal_test.go b/compiler/check/callsite/function_literal_test.go index 5c606904..cc8a9e7c 100644 --- a/compiler/check/callsite/function_literal_test.go +++ b/compiler/check/callsite/function_literal_test.go @@ -88,3 +88,83 @@ func TestFunctionLiteralForSymbol_AssignedFunctionLiteral(t *testing.T) { t.Fatal("expected function literal for assigned symbol") } } + +func TestFunctionLiteralForGraphSymbol_FuncDefSymbol(t *testing.T) { + stmts, err := parse.ParseString(` + local M = {} + function M.run() + return 1 + end + M.run() + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + graph := cfg.Build(&ast.FunctionExpr{ + ParList: &ast.ParList{}, + Stmts: stmts, + }) + if graph == nil { + t.Fatal("expected graph") + } + + var calleeSym cfg.SymbolID + graph.EachCallSite(func(_ cfg.Point, info *cfg.CallInfo) { + if info != nil && info.CalleeName == "run" { + calleeSym = info.CalleeSymbol + } + }) + if calleeSym == 0 { + t.Fatal("expected callsite callee symbol") + } + + fn := FunctionLiteralForGraphSymbol(graph, calleeSym) + if fn == nil { + t.Fatal("expected graph-local function literal for field definition") + } +} + +func TestFunctionLiteralForGraphSymbol_IgnoresMutableFieldPathBinding(t *testing.T) { + stmts, err := parse.ParseString(` + local M = { + dep = { + get = function() + return nil + end, + }, + } + M.dep = { + get = function() + return 1 + end, + } + M.dep.get() + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + graph := cfg.Build(&ast.FunctionExpr{ + ParList: &ast.ParList{}, + Stmts: stmts, + }) + if graph == nil { + t.Fatal("expected graph") + } + + var calleeSym cfg.SymbolID + graph.EachCallSite(func(_ cfg.Point, info *cfg.CallInfo) { + if info != nil && info.CalleeName == "get" { + calleeSym = info.CalleeSymbol + } + }) + if calleeSym == 0 { + t.Fatal("expected callsite callee symbol") + } + + if fn := FunctionLiteralForGraphSymbol(graph, calleeSym); fn != nil { + t.Fatalf("expected mutable field-path symbol to stay unresolved in graph-local resolver, got %v", fn) + } + if fn := FunctionLiteralForSymbol(graph, graph.Bindings(), calleeSym); fn == nil { + t.Fatal("expected binder-level symbol resolver to still find a literal for the shared field symbol") + } +} diff --git a/compiler/check/checker.go b/compiler/check/checker.go index f4c6aa7b..8f81acca 100644 --- a/compiler/check/checker.go +++ b/compiler/check/checker.go @@ -23,7 +23,7 @@ // reachability conditions and type narrowing facts at each CFG point. // // Phase D (Narrow): Applies flow solution to narrow declared types, computing -// final effective types for all expressions and generating function effects. +// final effective types for all expressions and generating function refinements. // // # INTERPROCEDURAL ANALYSIS // @@ -33,7 +33,7 @@ // - ParamHints: Inferred parameter types from call sites // - FuncTypes: Canonical local function types for sibling lookups // - LiteralSigs: Synthesized signatures for function literals -// - Effects: Function effects (side effects, termination), stored per symbol +// - Refinements: Function refinement summaries, stored per symbol // // # DETERMINISTIC ORDERING // diff --git a/compiler/check/effects/doc.go b/compiler/check/effects/doc.go index e2fb47bb..7435df76 100644 --- a/compiler/check/effects/doc.go +++ b/compiler/check/effects/doc.go @@ -24,7 +24,7 @@ // // # Effect Lookup // -// [LookupEffectBySym] resolves effects for called functions: +// [LookupRefinementBySym] resolves effects for called functions: // - First checks the effect store for analyzed functions // - Falls back to global type information for builtins // - Extracts effects from function type annotations diff --git a/compiler/check/effects/effectops_test.go b/compiler/check/effects/effectops_test.go index 84d62749..2012d7fe 100644 --- a/compiler/check/effects/effectops_test.go +++ b/compiler/check/effects/effectops_test.go @@ -28,10 +28,10 @@ func TestPropagate_EmptyResult(t *testing.T) { } func TestPropagate_WithLocalEffect(t *testing.T) { - fnEffect := &constraint.FunctionEffect{ + fnEffect := &constraint.FunctionRefinement{ Terminates: true, } - result := Propagate(&api.FuncResult{FnEffect: fnEffect}, nil) + result := Propagate(&api.FuncResult{FnRefinement: fnEffect}, nil) if result == nil { t.Fatal("expected non-nil effect") } @@ -40,15 +40,15 @@ func TestPropagate_WithLocalEffect(t *testing.T) { } } -func TestLookupEffectBySym_NilStore(t *testing.T) { - result := LookupEffectBySym(nil, nil, nil, 1) +func TestLookupRefinementBySym_NilStore(t *testing.T) { + result := LookupRefinementBySym(nil, nil, nil, 1) if result != nil { t.Errorf("expected nil for nil store, got %v", result) } } -func TestLookupEffectBySym_ZeroSym(t *testing.T) { - result := LookupEffectBySym(nil, nil, nil, 0) +func TestLookupRefinementBySym_ZeroSym(t *testing.T) { + result := LookupRefinementBySym(nil, nil, nil, 0) if result != nil { t.Errorf("expected nil for zero symbol, got %v", result) } @@ -107,7 +107,7 @@ func TestEffectFromType_NeverReturn(t *testing.T) { } func TestEffectFromType_WithRefinement(t *testing.T) { - eff := &constraint.FunctionEffect{Terminates: true} + eff := &constraint.FunctionRefinement{Terminates: true} fn := typ.Func().Returns(typ.String).WithRefinement(eff).Build() result := EffectFromType(fn) if result == nil { @@ -135,14 +135,14 @@ func TestEnrichExportWithEffects_NilGraph(t *testing.T) { func TestEnrichExportWithEffects_EmptyEffects(t *testing.T) { rec := typ.NewRecord().Build() - result := EnrichExportWithEffects(rec, "", map[cfg.SymbolID]*constraint.FunctionEffect{}, nil) + result := EnrichExportWithEffects(rec, "", map[cfg.SymbolID]*constraint.FunctionRefinement{}, nil) if result != rec { t.Error("expected original record returned when effects map is empty") } } func TestEnrichExportWithEffects_NonRecordNonInterface(t *testing.T) { - result := EnrichExportWithEffects(typ.String, "", map[cfg.SymbolID]*constraint.FunctionEffect{1: {}}, nil) + result := EnrichExportWithEffects(typ.String, "", map[cfg.SymbolID]*constraint.FunctionRefinement{1: {}}, nil) if result != typ.String { t.Error("expected original type returned for non-record/non-interface") } @@ -200,7 +200,7 @@ func TestEnrichExportWithEffects_PreservesRecordQualifiersAndMetatable(t *testin Metatable(meta). Build() - effectsBySym := map[cfg.SymbolID]*constraint.FunctionEffect{ + effectsBySym := map[cfg.SymbolID]*constraint.FunctionRefinement{ symValidate: {Terminates: true}, } @@ -255,11 +255,11 @@ func TestPropagate_CollectsEffectFromAssignmentCallSite(t *testing.T) { } result := Propagate(&api.FuncResult{ - Graph: graph, - FnEffect: &constraint.FunctionEffect{}, - }, func(sym cfg.SymbolID) *constraint.FunctionEffect { + Graph: graph, + FnRefinement: &constraint.FunctionRefinement{}, + }, func(sym cfg.SymbolID) *constraint.FunctionRefinement { if sym == symF { - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: effect.Row{Labels: []effect.Label{effect.IO{}}}, } } @@ -283,11 +283,11 @@ func TestPropagate_CollectsEffectFromReturnCallSite(t *testing.T) { } result := Propagate(&api.FuncResult{ - Graph: graph, - FnEffect: &constraint.FunctionEffect{}, - }, func(sym cfg.SymbolID) *constraint.FunctionEffect { + Graph: graph, + FnRefinement: &constraint.FunctionRefinement{}, + }, func(sym cfg.SymbolID) *constraint.FunctionRefinement { if sym == symF { - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: effect.Row{Labels: []effect.Label{effect.IO{}}}, } } @@ -320,11 +320,11 @@ func TestPropagate_UsesCanonicalCandidatesWhenRawSymbolMissing(t *testing.T) { }) result := Propagate(&api.FuncResult{ - Graph: graph, - FnEffect: &constraint.FunctionEffect{}, - }, func(sym cfg.SymbolID) *constraint.FunctionEffect { + Graph: graph, + FnRefinement: &constraint.FunctionRefinement{}, + }, func(sym cfg.SymbolID) *constraint.FunctionRefinement { if sym == symF { - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: effect.Row{Labels: []effect.Label{effect.IO{}}}, } } @@ -359,10 +359,10 @@ func TestPropagate_UsesModuleBindingNameFallback(t *testing.T) { result := Propagate(&api.FuncResult{ Graph: graph, ModuleBindings: moduleBindings, - FnEffect: &constraint.FunctionEffect{}, - }, func(sym cfg.SymbolID) *constraint.FunctionEffect { + FnRefinement: &constraint.FunctionRefinement{}, + }, func(sym cfg.SymbolID) *constraint.FunctionRefinement { if sym == fallbackSym { - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: effect.Row{Labels: []effect.Label{effect.IO{}}}, } } diff --git a/compiler/check/effects/export.go b/compiler/check/effects/export.go index f5282613..e4ac8d4d 100644 --- a/compiler/check/effects/export.go +++ b/compiler/check/effects/export.go @@ -24,12 +24,12 @@ import ( // EnrichExportWithEffects attaches known function refinements to exported values. // This is used when exporting module records or interfaces so method refinements // (like termination or guard effects) are preserved at module boundaries. -func EnrichExportWithEffects(export typ.Type, rootName string, effectsBySym map[cfg.SymbolID]*constraint.FunctionEffect, graph *cfg.Graph) typ.Type { +func EnrichExportWithEffects(export typ.Type, rootName string, effectsBySym map[cfg.SymbolID]*constraint.FunctionRefinement, graph *cfg.Graph) typ.Type { if export == nil || graph == nil || len(effectsBySym) == 0 { return export } - fieldEffects := make(map[string]*constraint.FunctionEffect) + fieldEffects := make(map[string]*constraint.FunctionRefinement) for _, sym := range cfg.SortedSymbolIDs(effectsBySym) { eff := effectsBySym[sym] if eff == nil { @@ -158,7 +158,7 @@ func appendRecordField(builder *typ.RecordBuilder, f typ.Field) *typ.RecordBuild return builder.Field(f.Name, f.Type) } -func applyFunctionRefinement(fn *typ.Function, eff *constraint.FunctionEffect) *typ.Function { +func applyFunctionRefinement(fn *typ.Function, eff *constraint.FunctionRefinement) *typ.Function { if fn == nil || eff == nil { return fn } diff --git a/compiler/check/effects/propagate.go b/compiler/check/effects/propagate.go index 832105db..8fe13da8 100644 --- a/compiler/check/effects/propagate.go +++ b/compiler/check/effects/propagate.go @@ -12,19 +12,19 @@ import ( "github.com/wippyai/go-lua/types/typ/unwrap" ) -// LookupFunc resolves a function effect by symbol. -type LookupFunc func(sym cfg.SymbolID) *constraint.FunctionEffect +// LookupFunc resolves a function refinement by symbol. +type LookupFunc func(sym cfg.SymbolID) *constraint.FunctionRefinement // Propagate computes the complete effect for a function by combining its // local effects with effects propagated from callees. -func Propagate(result *api.FuncResult, lookup LookupFunc) *constraint.FunctionEffect { +func Propagate(result *api.FuncResult, lookup LookupFunc) *constraint.FunctionRefinement { if result == nil { return nil } - fnEffect := result.FnEffect + fnEffect := result.FnRefinement if fnEffect == nil { - fnEffect = &constraint.FunctionEffect{} + fnEffect = &constraint.FunctionRefinement{} } if result.Graph == nil { @@ -82,7 +82,7 @@ func Propagate(result *api.FuncResult, lookup LookupFunc) *constraint.FunctionEf effectRow = row } - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: effectRow, OnReturn: fnEffect.OnReturn, OnTrue: fnEffect.OnTrue, @@ -91,17 +91,17 @@ func Propagate(result *api.FuncResult, lookup LookupFunc) *constraint.FunctionEf } } -// LookupEffectBySym resolves effects from the store or global type information. -func LookupEffectBySym( - store api.EffectStore, +// LookupRefinementBySym resolves effects from the store or global type information. +func LookupRefinementBySym( + store api.RefinementStore, bindings *bind.BindingTable, globalTypes map[string]typ.Type, sym cfg.SymbolID, -) *constraint.FunctionEffect { +) *constraint.FunctionRefinement { if store == nil || sym == 0 { return nil } - if eff := store.LookupEffectBySym(sym); eff != nil { + if eff := store.LookupRefinementBySym(sym); eff != nil { return eff } if bindings != nil && globalTypes != nil { @@ -140,8 +140,8 @@ func TerminatesFromReachability(result *api.FuncResult) bool { return exitCond.IsFalse() } -// EffectFromType extracts FunctionEffect from a function type's declared effect annotations. -func EffectFromType(t typ.Type) *constraint.FunctionEffect { +// EffectFromType extracts FunctionRefinement from a function type's declared effect annotations. +func EffectFromType(t typ.Type) *constraint.FunctionRefinement { if t == nil { return nil } @@ -150,7 +150,7 @@ func EffectFromType(t typ.Type) *constraint.FunctionEffect { return nil } if fn.Refinement != nil { - if eff, ok := fn.Refinement.(*constraint.FunctionEffect); ok { + if eff, ok := fn.Refinement.(*constraint.FunctionRefinement); ok { return eff } } @@ -176,7 +176,7 @@ func EffectFromType(t typ.Type) *constraint.FunctionEffect { if row.Pure() && !row.IsOpen() && !terminates { return nil } - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: row, Terminates: terminates, } diff --git a/compiler/check/flowbuild/assign/collect.go b/compiler/check/flowbuild/assign/collect.go index 211d4ca1..b5f3686b 100644 --- a/compiler/check/flowbuild/assign/collect.go +++ b/compiler/check/flowbuild/assign/collect.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" + "github.com/wippyai/go-lua/compiler/check/overlaymut" "github.com/wippyai/go-lua/types/typ" ) @@ -17,66 +18,7 @@ func CollectFieldAssignments( synth func(ast.Expr, cfg.Point) typ.Type, filterSyms map[cfg.SymbolID]bool, ) map[cfg.SymbolID]map[string]typ.Type { - result := make(map[cfg.SymbolID]map[string]typ.Type) - if graph == nil { - return result - } - - graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { - if info == nil { - return - } - sources := info.Sources - for i, target := range info.Targets { - var source ast.Expr - if i < len(sources) { - source = sources[i] - } - var sym cfg.SymbolID - var fieldName string - - switch target.Kind { - case cfg.TargetField: - if target.BaseSymbol != 0 && len(target.FieldPath) == 1 { - sym = target.BaseSymbol - fieldName = target.FieldPath[0] - } - case cfg.TargetIndex: - if target.BaseSymbol != 0 && target.Key != nil { - if strKey, ok := target.Key.(*ast.StringExpr); ok && strKey.Value != "" { - sym = target.BaseSymbol - fieldName = strKey.Value - } - } - } - - if sym == 0 || fieldName == "" { - return - } - if filterSyms != nil && !filterSyms[sym] { - return - } - - var fieldType typ.Type - if source != nil && synth != nil { - fieldType = synth(source, p) - } - if fieldType == nil { - fieldType = typ.Unknown - } - - if result[sym] == nil { - result[sym] = make(map[string]typ.Type) - } - if existing := result[sym][fieldName]; existing != nil { - result[sym][fieldName] = typ.JoinPreferNonSoft(existing, fieldType) - } else { - result[sym][fieldName] = fieldType - } - } - }) - - return result + return overlaymut.CollectFieldAssignments(graph, synth, filterSyms) } // CollectIndexerAssignments scans the graph for dynamic index assignments (t[k] = v where k is non-const). @@ -87,68 +29,5 @@ func CollectIndexerAssignments( bindings *bind.BindingTable, filterSyms map[cfg.SymbolID]bool, ) map[cfg.SymbolID][]mutator.IndexerInfo { - result := make(map[cfg.SymbolID][]mutator.IndexerInfo) - if graph == nil { - return result - } - - graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { - if info == nil { - return - } - sources := info.Sources - for i, target := range info.Targets { - var source ast.Expr - if i < len(sources) { - source = sources[i] - } - if target.Kind != cfg.TargetIndex { - continue - } - sym := target.BaseSymbol - if sym == 0 { - continue - } - if filterSyms != nil && !filterSyms[sym] { - continue - } - - // Skip string literal keys (handled by field assignments) - if _, ok := target.Key.(*ast.StringExpr); ok { - continue - } - - // Determine key type - var keyType typ.Type - switch k := target.Key.(type) { - case *ast.IdentExpr: - if synth != nil { - keyType = synth(k, p) - } - case *ast.NumberExpr: - keyType = typ.Integer - default: - if synth != nil && target.Key != nil { - keyType = synth(target.Key, p) - } - } - keyType = canonicalDynamicKeyType(keyType) - - // Determine value type - var valType typ.Type - if source != nil && synth != nil { - valType = synth(source, p) - } - if valType == nil { - valType = typ.Unknown - } - - result[sym] = append(result[sym], mutator.IndexerInfo{ - KeyType: keyType, - ValType: valType, - }) - } - }) - - return result + return overlaymut.CollectIndexerAssignments(graph, synth, bindings, filterSyms) } diff --git a/compiler/check/flowbuild/assign/emit.go b/compiler/check/flowbuild/assign/emit.go index 4fcd5e86..b7245bac 100644 --- a/compiler/check/flowbuild/assign/emit.go +++ b/compiler/check/flowbuild/assign/emit.go @@ -33,6 +33,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" + cfganalysis "github.com/wippyai/go-lua/compiler/cfg/analysis" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" "github.com/wippyai/go-lua/compiler/check/flowbuild/cond" @@ -85,7 +86,8 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect // Collects spec-narrowed types from contract specs and propagates through method calls. // Uses expandValues with SpecTypes overlay for method call synthesis. specNarrowed := CollectSpecNarrowedTypes(fc.Graph, fc.Scopes, synth, symResolver, fc.API, fc.ModuleBindings) - inferredTypes := collectInferredTypes(fc.Graph, fc.Scopes, synth, fc.API, symResolver, specNarrowed, inputs.AnnotatedVars, inputs, fc.ModuleBindings, fc.CallCtx, fc.TypeOps, fc.Services) + preflowBranchSolution := buildPreflowBranchSolution(fc, inputs) + inferredTypes := collectInferredTypes(fc.Graph, fc.Scopes, synth, fc.API, symResolver, specNarrowed, inputs.AnnotatedVars, inputs, fc.ModuleBindings, fc.CallCtx, fc.TypeOps, preflowBranchSolution, fc.Services) // Promote inferred parameter types into DeclaredTypes for unannotated params. // This enables bidirectional inference at call sites (e.g., custom assert helpers). if inputs.DeclaredTypes != nil { @@ -149,13 +151,14 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect overlayTypes = mergeSpecTypesInto(overlayTypes, inferredTypes) overlayTypes = mergeSpecTypesInto(overlayTypes, specNarrowed) overlayTypes = mergeSpecTypesInto(overlayTypes, loopVarTypes) - // Precompute truthy guards: map from CFG point to paths that are narrowed (non-nil) at that point. // Used during table literal synthesis to unwrap optional types. truthyGuards := guard.CollectTruthyGuards(fc.Graph, bindings) typeGuards := guard.CollectTypeGuards(fc.Graph, bindings) - baseSynth := resolve.SynthWithOverlay(overlayTypes, bindings, synth) + baseSynth := synthWithOverlayAndPreflow(overlayTypes, bindings, inputs, fc.CallCtx, fc.TypeOps, preflowBranchSolution, synth) + idom, _ := cfganalysis.ComputeDominators(fc.Graph.CFG()) + structuredWrites := indexStructuredWrites(fc.Graph) var wrappedSynth func(ast.Expr, cfg.Point) typ.Type wrappedSynth = func(expr ast.Expr, p cfg.Point) typ.Type { if table, ok := expr.(*ast.TableExpr); ok && !tblutil.TableHasFunctionField(table) { @@ -269,6 +272,7 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect // Use pre-assignment symbol overlays for assignment targets so RHS // synthesis follows Lua evaluation order (`x = f(x, ...)`). rhsOverlay := rhsSpecTypesAtAssignPoint(fc.Graph, info, p, overlayTypes, resolverWithSpec) + rhsOverlay = enrichStructuredOverlayAtPoint(fc.Graph, idom, structuredWrites, p, rhsOverlay, resolverWithSpec, wrappedSynth) values = expandedAssignValues(fc.API, info, p, rhsOverlay) valuesComputed = true } @@ -309,6 +313,7 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect ensureValues() if value := assignValueAt(values, i); value != nil { assignedType = value + assignedType = preferPreciseDirectSourceType(assignedType, source, p, sc, wrappedSynth, len(info.Targets) == 1) } else if wrappedSynth != nil && source != nil { assignedType = wrappedSynth(source, p) } @@ -380,7 +385,7 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect // Extract predicate link if RHS is a predicate call. if callInfo, retIndex := info.CallForTarget(i); callInfo != nil { - if link := cond.ExtractPredicateLinkFromCallInfo(callInfo, retIndex, p, sc, inputs, derived.TypeKeyRes, wrappedSynth, derived.EffectBySym, symResolver, fc.Graph, fc.ModuleBindings); link != nil { + if link := cond.ExtractPredicateLinkFromCallInfo(callInfo, retIndex, p, sc, inputs, derived.TypeKeyRes, wrappedSynth, derived.RefinementBySym, symResolver, fc.Graph, fc.ModuleBindings); link != nil { if retIndex == 1 && callInfo.IsTypeCheck && callInfo.Method == "is" && callInfo.Receiver != nil && derived.TypeKeyRes != nil { if typeKey, ok := derived.TypeKeyRes(callInfo.TypeCheckName, sc); ok && !typeKey.IsZero() { valuePath := constraint.Path{} @@ -442,10 +447,10 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect } } - // Fallback: check function effect for KeyOf-based keys collector - if tableSym == 0 && derived.EffectBySym != nil { + // Fallback: check function refinement for KeyOf-based keys collector. + if tableSym == 0 && derived.RefinementBySym != nil { for _, calleeSym := range calleeSymbols { - eff := derived.EffectBySym(calleeSym) + eff := derived.RefinementBySym(calleeSym) if eff == nil { continue } diff --git a/compiler/check/flowbuild/assign/emit_test.go b/compiler/check/flowbuild/assign/emit_test.go index b204918e..537c895f 100644 --- a/compiler/check/flowbuild/assign/emit_test.go +++ b/compiler/check/flowbuild/assign/emit_test.go @@ -6,6 +6,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/flowbuild/core" "github.com/wippyai/go-lua/compiler/check/flowbuild/keyscoll" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" @@ -17,6 +18,31 @@ import ( "github.com/wippyai/go-lua/types/typ" ) +type preciseSourceSynthStub struct { + preciseType typ.Type +} + +func (s *preciseSourceSynthStub) TypeOf(expr ast.Expr, _ cfg.Point) typ.Type { + switch expr.(type) { + case *ast.LogicalOpExpr: + return s.preciseType + default: + return typ.Unknown + } +} + +func (s *preciseSourceSynthStub) ExpandValues([]ast.Expr, int, cfg.Point) []typ.Type { return nil } + +func (s *preciseSourceSynthStub) InferIterVars([]ast.Expr, int, cfg.Point) []typ.Type { return nil } + +func (s *preciseSourceSynthStub) ExpandValuesWithSpecTypes([]ast.Expr, int, cfg.Point, api.SpecTypes) []typ.Type { + return []typ.Type{typ.Any} +} + +func (s *preciseSourceSynthStub) InferIterVarsWithSpecTypes([]ast.Expr, int, cfg.Point, api.SpecTypes) []typ.Type { + return nil +} + func TestExtractAssignments_NilConfig(t *testing.T) { inputs := &flow.Inputs{ Assignments: []flow.UnifiedAssignment{}, @@ -268,9 +294,9 @@ func TestExtractAssignments_KeysCollectorEffectFallbackIgnoresNonCollectorEffect Synth: func(ast.Expr, cfg.Point) typ.Type { return typ.Unknown }, - EffectBySym: func(cfg.SymbolID) *constraint.FunctionEffect { + RefinementBySym: func(cfg.SymbolID) *constraint.FunctionRefinement { // Non-collector effect (no KeyOf constraint). - return &constraint.FunctionEffect{} + return &constraint.FunctionRefinement{} }, }, }, inputs, nil) @@ -280,6 +306,54 @@ func TestExtractAssignments_KeysCollectorEffectFallbackIgnoresNonCollectorEffect } } +func TestExtractAssignments_PrefersPreciseDirectTypeOverExpandedAnyForLogicalOr(t *testing.T) { + code := ` + local left = nil + local right = nil + local ctx = left or right + ` + chunk, err := parse.ParseString(code, "emit_precise_or.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + graph := cfg.Build(&ast.FunctionExpr{Stmts: chunk}, "emit_precise_or") + exit := graph.Exit() + ctxSym, ok := graph.SymbolAt(exit, "ctx") + if !ok || ctxSym == 0 { + t.Fatal("expected symbol for ctx") + } + + contextAlias := typ.NewAlias("Context", typ.NewMap(typ.String, typ.Any)) + synthAPI := &preciseSourceSynthStub{preciseType: contextAlias} + inputs := &flow.Inputs{ + DeclaredTypes: make(map[cfg.SymbolID]typ.Type), + PredicateLinks: make(map[string]flow.PredicateLink), + SiblingAssignments: make(map[flow.SiblingKey]*flow.SiblingAssignment), + } + ExtractAssignments(&core.FlowContext{ + Graph: graph, + API: synthAPI, + Derived: &core.Derived{ + Synth: synthAPI.TypeOf, + SymResolver: func(cfg.Point, cfg.SymbolID) (typ.Type, bool) { + return nil, false + }, + }, + }, inputs, nil) + + for _, assign := range inputs.Assignments { + if assign.TargetPath.Symbol != ctxSym { + continue + } + if !typ.TypeEquals(assign.Type, contextAlias) { + t.Fatalf("ctx assignment type = %v, want %v", assign.Type, contextAlias) + } + return + } + + t.Fatal("expected assignment for ctx") +} + func TestExtractAssignments_KeysCollectorEffectFallbackRespectsReturnIndex(t *testing.T) { code := ` local function two_returns(tbl) @@ -318,8 +392,8 @@ func TestExtractAssignments_KeysCollectorEffectFallbackRespectsReturnIndex(t *te Synth: func(ast.Expr, cfg.Point) typ.Type { return typ.Unknown }, - EffectBySym: func(cfg.SymbolID) *constraint.FunctionEffect { - return &constraint.FunctionEffect{ + RefinementBySym: func(cfg.SymbolID) *constraint.FunctionRefinement { + return &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.KeyOf{ Table: constraint.ParamPath(0), Key: constraint.RetPath(1), @@ -376,17 +450,17 @@ func TestExtractAssignments_KeysCollectorEffectFallback_TriesAllNameCandidates(t Synth: func(ast.Expr, cfg.Point) typ.Type { return typ.Unknown }, - EffectBySym: func(sym cfg.SymbolID) *constraint.FunctionEffect { + RefinementBySym: func(sym cfg.SymbolID) *constraint.FunctionRefinement { switch sym { case mismatchSym: - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.KeyOf{ Table: constraint.ParamPath(0), Key: constraint.RetPath(1), }), } case matchSym: - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.KeyOf{ Table: constraint.ParamPath(0), Key: constraint.RetPath(0), @@ -575,11 +649,11 @@ func TestExtractAssignments_NestedDynamicIndex_LiftsToRootIndexer(t *testing.T) if assign.Symbol != subscribersSym || assign.KeySymbol != cidSym || len(assign.Segments) != 0 { continue } - if _, ok := assign.ValType.(*typ.Map); ok { - lifted = assign - break - } + if _, ok := assign.ValType.(*typ.Map); ok { + lifted = assign + break } + } if lifted == nil { t.Fatalf("expected lifted indexer assignment for nested dynamic write, got %#v", inputs.IndexerAssignments) } @@ -667,6 +741,25 @@ func TestCorrelationsFromFunctionType_ImplicitStructuredErrorConvention(t *testi } } +func TestCorrelationsFromFunctionType_ImplicitUnionErrorConvention(t *testing.T) { + errorType := typ.NewRecord(). + Field("message", typ.String). + Build() + fnType := typ.Func(). + Returns(typ.NewOptional(typ.String), typ.NewOptional(typ.NewUnion(typ.String, typ.LuaError, errorType))). + Build() + inverse, co := correlationsFromFunctionType(fnType) + if len(co) != 0 { + t.Fatalf("expected no co-correlations, got %v", co) + } + if len(inverse) != 1 { + t.Fatalf("expected one convention-based correlation, got %v", inverse) + } + if inverse[0] != (flow.ReturnCorrelation{ValueIndex: 0, ErrorIndex: 1}) { + t.Fatalf("unexpected convention correlation: %+v", inverse[0]) + } +} + func TestCorrelationsFromFunctionType_NoImplicitStructuredErrorWithoutMessage(t *testing.T) { auxType := typ.NewRecord(). Field("status_code", typ.Number). diff --git a/compiler/check/flowbuild/assign/error_return_policy.go b/compiler/check/flowbuild/assign/error_return_policy.go index e6985e0c..3d26cc74 100644 --- a/compiler/check/flowbuild/assign/error_return_policy.go +++ b/compiler/check/flowbuild/assign/error_return_policy.go @@ -1,6 +1,7 @@ package assign import ( + "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/subtype" @@ -48,13 +49,48 @@ func isOptionalErrorLike(t typ.Type) bool { if inner == nil { return false } - if subtype.IsSubtype(inner, typ.LuaError) || subtype.IsSubtype(inner, typ.String) { + return isErrorLikeType(inner) +} + +func isErrorLikeType(t typ.Type) bool { + if t == nil { + return false + } + t = unwrap.Alias(t) + if t == nil { + return false + } + + switch v := t.(type) { + case *typ.Union: + if len(v.Members) == 0 { + return false + } + for _, m := range v.Members { + if m == nil || m.Kind() == kind.Nil { + continue + } + if !isErrorLikeType(m) { + return false + } + } + return true + case *typ.Intersection: + for _, m := range v.Members { + if isErrorLikeType(m) { + return true + } + } + return false + } + + if subtype.IsSubtype(t, typ.LuaError) || subtype.IsSubtype(t, typ.String) { return true } // Structured error objects in Lua code often expose `message` as a field. // Treat Optional<{message: string}> as error-like for canonical (value, err) // correlation when explicit specs are absent. - messageType, ok := core.Field(inner, "message") + messageType, ok := core.Field(t, "message") if !ok || messageType == nil { return false } diff --git a/compiler/check/flowbuild/assign/infer.go b/compiler/check/flowbuild/assign/infer.go index c1d55951..48606dfb 100644 --- a/compiler/check/flowbuild/assign/infer.go +++ b/compiler/check/flowbuild/assign/infer.go @@ -49,6 +49,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" + cfganalysis "github.com/wippyai/go-lua/compiler/cfg/analysis" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" @@ -116,9 +117,10 @@ func CollectInferredTypes(fc *fbcore.FlowContext, specTypes api.SpecTypes, annot if fc.Derived != nil { symResolver = fc.Derived.SymResolver } + preflowBranchSolution := buildPreflowBranchSolution(fc, inputs) return collectInferredTypes( fc.Graph, fc.Scopes, synth, fc.API, symResolver, - specTypes, annotated, inputs, fc.ModuleBindings, fc.CallCtx, fc.TypeOps, fc.Services, + specTypes, annotated, inputs, fc.ModuleBindings, fc.CallCtx, fc.TypeOps, preflowBranchSolution, fc.Services, ) } @@ -141,12 +143,15 @@ func collectInferredTypes( moduleBindings *bind.BindingTable, callCtx *db.QueryContext, typeOps core.TypeOps, + preflowBranchSolution *flow.Solution, services fbcore.FlowServices, ) api.SpecTypes { inferred := make(api.SpecTypes) if graph == nil { return inferred } + idom, _ := cfganalysis.ComputeDominators(graph.CFG()) + structuredWrites := indexStructuredWrites(graph) bindings := graph.Bindings() if moduleBindings == nil { @@ -458,7 +463,7 @@ func collectInferredTypes( overlayScratch = mergeSpecTypesSoftInto(overlayScratch, inferred, specTypes) overlay := overlayScratch - wrappedSynth := synthWithInferenceOverlay(overlay, funcSigTypes, paramSet, annotated, bindings, synth) + wrappedSynth := synthWithInferenceOverlay(graph, overlay, funcSigTypes, paramSet, annotated, bindings, inputs, callCtx, typeOps, preflowBranchSolution, synth) callSynthFor := func(p cfg.Point, info *cfg.CallInfo) func(ast.Expr, cfg.Point) typ.Type { if info == nil { return wrappedSynth @@ -484,8 +489,9 @@ func collectInferredTypes( } return rhsResolver(point, sym) }) + callOverlay = enrichStructuredOverlayAtPoint(graph, idom, structuredWrites, p, callOverlay, rhsResolver, wrappedSynth) - return synthWithInferenceOverlay(callOverlay, funcSigTypes, paramSet, annotated, bindings, synth) + return synthWithInferenceOverlay(graph, callOverlay, funcSigTypes, paramSet, annotated, bindings, inputs, callCtx, typeOps, preflowBranchSolution, synth) } // Infer expected argument types for a call using the call inference pipeline. @@ -668,21 +674,13 @@ func collectInferredTypes( } return rhsResolver(point, sym) }) + rhsOverlay = enrichStructuredOverlayAtPoint(graph, idom, structuredWrites, p, rhsOverlay, rhsResolver, wrappedSynth) values = expandedAssignValues(synthAPI, info, p, rhsOverlay) valuesComputed = true } if value := assignValueAt(values, i); !typ.IsAbsentOrUnknown(value) { assignedType = value - // Prefer direct expression synthesis when slot expansion - // yields `any` for non-call RHS but the expression itself - // has a more precise type (e.g. indexed map reads). - if source != nil && wrappedSynth != nil && typ.IsAny(value) { - if _, isCall := source.(*ast.FuncCallExpr); !isCall { - if precise := resolve.Ref(wrappedSynth(source, p), sc); !typ.IsAbsentOrUnknown(precise) && !typ.IsAny(precise) { - assignedType = precise - } - } - } + assignedType = preferPreciseDirectSourceType(assignedType, source, p, sc, wrappedSynth, len(info.Targets) == 1) } else if wrappedSynth != nil && source != nil { assignedType = wrappedSynth(source, p) } @@ -935,22 +933,32 @@ func dedupeSymbolIDs(refs []cfg.SymbolID) []cfg.SymbolID { } func synthWithInferenceOverlay( + graph *cfg.Graph, overlay map[cfg.SymbolID]typ.Type, funcSigTypes map[cfg.SymbolID]typ.Type, paramSet map[cfg.SymbolID]bool, annotated map[cfg.SymbolID]bool, bindings *bind.BindingTable, + inputs *flow.Inputs, + callCtx *db.QueryContext, + typeOps core.TypeOps, + preflow *flow.Solution, base func(ast.Expr, cfg.Point) typ.Type, ) func(ast.Expr, cfg.Point) typ.Type { - return func(expr ast.Expr, p cfg.Point) typ.Type { + _ = graph + mergedOverlay := make(map[cfg.SymbolID]typ.Type, len(overlay)+len(funcSigTypes)) + for sym, t := range funcSigTypes { + if t != nil { + mergedOverlay[sym] = t + } + } + for sym, t := range overlay { + mergedOverlay[sym] = t + } + + wrappedBase := func(expr ast.Expr, p cfg.Point) typ.Type { if ident, ok := expr.(*ast.IdentExpr); ok && bindings != nil { if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { - if t, exists := overlay[sym]; exists { - return t - } - if t, exists := funcSigTypes[sym]; exists { - return t - } if paramSet[sym] && (annotated == nil || !annotated[sym]) { return typ.Unknown } @@ -961,6 +969,8 @@ func synthWithInferenceOverlay( } return base(expr, p) } + + return synthWithOverlayAndPreflow(mergedOverlay, bindings, inputs, callCtx, typeOps, preflow, wrappedBase) } func assignmentOwningSourceCall(assigns []*cfg.AssignInfo, call *cfg.CallInfo) *cfg.AssignInfo { diff --git a/compiler/check/flowbuild/assign/infer_test.go b/compiler/check/flowbuild/assign/infer_test.go index 299bcc68..1248dcee 100644 --- a/compiler/check/flowbuild/assign/infer_test.go +++ b/compiler/check/flowbuild/assign/infer_test.go @@ -302,6 +302,7 @@ func TestCollectInferredTypes_UsesModuleCalleeCandidatesForExpectedArgs(t *testi db.NewQueryContext(db.New()), querycore.NewEngine(), nil, + nil, ) got := inferred[xSym] @@ -378,11 +379,16 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { paramSet := map[cfg.SymbolID]bool{aSym: true} synth := synthWithInferenceOverlay( + nil, map[cfg.SymbolID]typ.Type{aSym: typ.String}, map[cfg.SymbolID]typ.Type{aSym: typ.Number}, paramSet, nil, bindings, + nil, + nil, + nil, + nil, base, ) if got := synth(ident, 0); !typ.TypeEquals(got, typ.String) { @@ -390,11 +396,16 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { } synth = synthWithInferenceOverlay( + nil, nil, map[cfg.SymbolID]typ.Type{aSym: typ.Number}, paramSet, nil, bindings, + nil, + nil, + nil, + nil, base, ) if got := synth(ident, 0); !typ.TypeEquals(got, typ.Number) { @@ -402,11 +413,16 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { } synth = synthWithInferenceOverlay( + nil, nil, nil, paramSet, nil, bindings, + nil, + nil, + nil, + nil, base, ) if got := synth(ident, 0); !typ.TypeEquals(got, typ.Unknown) { @@ -414,11 +430,16 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { } synth = synthWithInferenceOverlay( + nil, nil, nil, paramSet, map[cfg.SymbolID]bool{aSym: true}, bindings, + nil, + nil, + nil, + nil, base, ) if got := synth(ident, 0); !typ.TypeEquals(got, typ.Boolean) { @@ -434,11 +455,16 @@ func TestSynthWithInferenceOverlay_PreservesNilOverlayEntries(t *testing.T) { baseCalled := false synth := synthWithInferenceOverlay( + nil, map[cfg.SymbolID]typ.Type{aSym: nil}, nil, nil, nil, bindings, + nil, + nil, + nil, + nil, func(ast.Expr, cfg.Point) typ.Type { baseCalled = true return typ.Boolean diff --git a/compiler/check/flowbuild/assign/precision.go b/compiler/check/flowbuild/assign/precision.go new file mode 100644 index 00000000..f7ef8a60 --- /dev/null +++ b/compiler/check/flowbuild/assign/precision.go @@ -0,0 +1,71 @@ +package assign + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" + "github.com/wippyai/go-lua/compiler/check/scope" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" +) + +// preferPreciseDirectSourceType keeps assignment inference on the canonical +// expression-synthesis path. +// +// Direct synthesis is allowed to repair a slot only when it is strictly more +// informative than the expanded assignment value. This keeps tuple expansion as +// the primary source of truth for assignment slots while still allowing +// canonical single-expression synthesis to repair top-like degradation. +func preferPreciseDirectSourceType( + assignedType typ.Type, + source ast.Expr, + p cfg.Point, + sc *scope.State, + synth func(ast.Expr, cfg.Point) typ.Type, + singleTarget bool, +) typ.Type { + if source == nil || synth == nil { + return assignedType + } + switch source.(type) { + case *ast.Comma3Expr: + return assignedType + } + + precise := resolve.Ref(synth(source, p), sc) + if typ.IsAbsentOrUnknown(precise) { + return assignedType + } + if singleTarget { + if typ.IsAbsentOrUnknown(assignedType) || typ.IsAny(assignedType) { + return precise + } + if subtype.IsSubtype(precise, assignedType) && !subtype.IsSubtype(assignedType, precise) { + return precise + } + if preferNamedEquivalentDirectType(precise, assignedType) { + return precise + } + return assignedType + } + if typ.IsAny(assignedType) && !typ.IsAny(precise) { + return precise + } + return assignedType +} + +func preferNamedEquivalentDirectType(precise, assignedType typ.Type) bool { + if !isNamedIdentityType(precise) || isNamedIdentityType(assignedType) { + return false + } + return subtype.IsSubtype(precise, assignedType) && subtype.IsSubtype(assignedType, precise) +} + +func isNamedIdentityType(t typ.Type) bool { + switch t.(type) { + case *typ.Alias, *typ.Ref: + return true + default: + return false + } +} diff --git a/compiler/check/flowbuild/assign/precision_test.go b/compiler/check/flowbuild/assign/precision_test.go new file mode 100644 index 00000000..b4fdc423 --- /dev/null +++ b/compiler/check/flowbuild/assign/precision_test.go @@ -0,0 +1,49 @@ +package assign + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/types/typ" +) + +func TestPreferPreciseDirectSourceType_PrefersNamedEquivalentAliasForSingleTarget(t *testing.T) { + record := typ.NewRecord(). + Field("id", typ.String). + Field("count", typ.Integer). + Build() + alias := typ.NewAlias("Counter", record) + + got := preferPreciseDirectSourceType( + record, + &ast.IdentExpr{Value: "x"}, + 0, + nil, + func(ast.Expr, cfg.Point) typ.Type { return alias }, + true, + ) + if got != alias { + t.Fatalf("expected direct named alias to win, got %s", typ.FormatShort(got)) + } +} + +func TestPreferPreciseDirectSourceType_DoesNotReplaceNamedAssignedType(t *testing.T) { + record := typ.NewRecord(). + Field("id", typ.String). + Field("count", typ.Integer). + Build() + alias := typ.NewAlias("Counter", record) + + got := preferPreciseDirectSourceType( + alias, + &ast.IdentExpr{Value: "x"}, + 0, + nil, + func(ast.Expr, cfg.Point) typ.Type { return record }, + true, + ) + if got != alias { + t.Fatalf("expected existing named assigned type to remain, got %s", typ.FormatShort(got)) + } +} diff --git a/compiler/check/flowbuild/assign/preflow_synth.go b/compiler/check/flowbuild/assign/preflow_synth.go new file mode 100644 index 00000000..5f677926 --- /dev/null +++ b/compiler/check/flowbuild/assign/preflow_synth.go @@ -0,0 +1,127 @@ +package assign + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/flowbuild/cond" + fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" + fbpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" + "github.com/wippyai/go-lua/compiler/check/flowbuild/predicate" + "github.com/wippyai/go-lua/types/db" + "github.com/wippyai/go-lua/types/flow" + "github.com/wippyai/go-lua/types/narrow" + "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/typ" +) + +type narrowResolverAdapter struct { + ctx *db.QueryContext + ops core.TypeOps +} + +var _ narrow.Resolver = (*narrowResolverAdapter)(nil) + +func (r narrowResolverAdapter) Field(t typ.Type, name string) (typ.Type, bool) { + if r.ops == nil { + return nil, false + } + return r.ops.Field(r.ctx, t, name) +} + +func (r narrowResolverAdapter) Index(t typ.Type, key typ.Type) (typ.Type, bool) { + if r.ops == nil { + return nil, false + } + return r.ops.Index(r.ctx, t, key) +} + +// buildPreflowBranchSolution solves only branch/numeric edge facts that are +// already available before assignment extraction completes. +// +// This gives local inference access to canonical branch narrowing such as +// discriminant checks on parameters, without depending on later assignment- +// derived facts or full post-extraction solve. +func buildPreflowBranchSolution(fc *fbcore.FlowContext, inputs *flow.Inputs) *flow.Solution { + if fc == nil || inputs == nil || inputs.Graph == nil || fc.TypeOps == nil { + return nil + } + + temp := *inputs + temp.EdgeConditions = nil + temp.EdgeNumericConstraints = nil + + cond.ExtractEdgeConstraints(fc, &temp) + cond.ExtractNumericConstraints(fc, &temp) + + return flow.Solve(&temp, narrowResolverAdapter{ctx: fc.CallCtx, ops: fc.TypeOps}) +} + +// synthWithOverlayAndPreflow wraps base synthesis with overlay lookup and a +// preflow branch-narrowing view for identifiers and attribute/index reads. +// +// This keeps assignment inference on the canonical synthesis path while letting +// recursive field/index expressions observe already-provable branch facts. +func synthWithOverlayAndPreflow( + overlay map[cfg.SymbolID]typ.Type, + bindings *bind.BindingTable, + inputs *flow.Inputs, + callCtx *db.QueryContext, + typeOps core.TypeOps, + preflow *flow.Solution, + base func(ast.Expr, cfg.Point) typ.Type, +) func(ast.Expr, cfg.Point) typ.Type { + var synth func(ast.Expr, cfg.Point) typ.Type + + synth = func(expr ast.Expr, p cfg.Point) typ.Type { + if expr == nil { + return nil + } + + if ident, ok := expr.(*ast.IdentExpr); ok && bindings != nil { + if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { + if t, exists := overlay[sym]; exists { + return t + } + } + } + + if preflow != nil && bindings != nil && inputs != nil { + constResolver := predicate.BuildConstResolver(inputs, p) + if path := fbpath.FromExprWithBindings(expr, constResolver, bindings); !path.IsEmpty() { + if narrowed := preflow.NarrowedTypeAt(p, path); !typ.IsAbsentOrUnknown(narrowed) { + return narrowed + } + } + } + + if attr, ok := expr.(*ast.AttrGetExpr); ok && typeOps != nil { + objType := synth(attr.Object, p) + if !typ.IsAbsentOrUnknown(objType) { + switch key := attr.Key.(type) { + case *ast.StringExpr: + if ft, ok := typeOps.Field(callCtx, objType, key.Value); ok && !typ.IsAbsentOrUnknown(ft) { + return ft + } + if it, ok := typeOps.Index(callCtx, objType, typ.LiteralString(key.Value)); ok && !typ.IsAbsentOrUnknown(it) { + return it + } + default: + keyType := synth(attr.Key, p) + if !typ.IsAbsentOrUnknown(keyType) { + if it, ok := typeOps.Index(callCtx, objType, keyType); ok && !typ.IsAbsentOrUnknown(it) { + return it + } + } + } + } + } + + if base == nil { + return nil + } + return base(expr, p) + } + + return synth +} diff --git a/compiler/check/flowbuild/assign/structured_overlay.go b/compiler/check/flowbuild/assign/structured_overlay.go new file mode 100644 index 00000000..7c94be02 --- /dev/null +++ b/compiler/check/flowbuild/assign/structured_overlay.go @@ -0,0 +1,307 @@ +package assign + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + cfganalysis "github.com/wippyai/go-lua/compiler/cfg/analysis" + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/kind" + "github.com/wippyai/go-lua/types/typ" +) + +type structuredWrite struct { + point cfg.Point + versionID int + segments []constraint.Segment + source ast.Expr +} + +// indexStructuredWrites collects static field/index writes keyed by base symbol. +func indexStructuredWrites(graph *cfg.Graph) map[cfg.SymbolID][]structuredWrite { + result := make(map[cfg.SymbolID][]structuredWrite) + if graph == nil { + return result + } + + graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + if info == nil { + return + } + for i, target := range info.Targets { + write, sym, ok := structuredWriteForTarget(graph, p, info.SourceAt(i), target) + if !ok { + continue + } + result[sym] = append(result[sym], write) + } + }) + + return result +} + +// enrichStructuredOverlayAtPoint applies dominating visible field writes for the +// current symbol version into a point-specific identifier overlay. +func enrichStructuredOverlayAtPoint( + graph *cfg.Graph, + idom map[cfg.Point]cfg.Point, + writes map[cfg.SymbolID][]structuredWrite, + p cfg.Point, + overlay api.SpecTypes, + resolveSym func(cfg.Point, cfg.SymbolID) (typ.Type, bool), + synth func(ast.Expr, cfg.Point) typ.Type, +) api.SpecTypes { + if graph == nil || len(writes) == 0 { + return overlay + } + + out := overlay + copied := false + for sym, symWrites := range writes { + if sym == 0 || len(symWrites) == 0 { + continue + } + + baseType, ok := out[sym] + if !ok && resolveSym != nil { + baseType, ok = resolveSym(p, sym) + } + + merged := mergeVisibleStructuredWrites(graph, idom, symWrites, sym, p, baseType, synth) + if merged == nil || (ok && typ.TypeEquals(merged, baseType)) { + continue + } + + if !copied { + if len(overlay) == 0 { + out = make(api.SpecTypes, 1) + } else { + out = make(api.SpecTypes, len(overlay)+1) + for k, v := range overlay { + out[k] = v + } + } + copied = true + } + out[sym] = merged + } + + return out +} + +func structuredWriteForTarget(graph *cfg.Graph, p cfg.Point, source ast.Expr, target cfg.AssignTarget) (structuredWrite, cfg.SymbolID, bool) { + if graph == nil || target.BaseSymbol == 0 { + return structuredWrite{}, 0, false + } + + var segments []constraint.Segment + switch target.Kind { + case cfg.TargetField: + if len(target.FieldPath) == 0 { + return structuredWrite{}, 0, false + } + segments = make([]constraint.Segment, len(target.FieldPath)) + for i, field := range target.FieldPath { + if field == "" { + return structuredWrite{}, 0, false + } + segments[i] = constraint.Segment{Kind: constraint.SegmentField, Name: field} + } + case cfg.TargetIndex: + switch key := target.Key.(type) { + case *ast.StringExpr: + if key.Value == "" { + return structuredWrite{}, 0, false + } + segments = []constraint.Segment{{Kind: constraint.SegmentIndexString, Name: key.Value}} + case *ast.NumberExpr: + segments = []constraint.Segment{{Kind: constraint.SegmentIndexInt}} + default: + return structuredWrite{}, 0, false + } + default: + return structuredWrite{}, 0, false + } + + version := graph.VisibleVersion(p, target.BaseSymbol) + if version.ID == 0 { + return structuredWrite{}, 0, false + } + + return structuredWrite{ + point: p, + versionID: version.ID, + segments: segments, + source: source, + }, target.BaseSymbol, true +} + +func mergeVisibleStructuredWrites( + graph *cfg.Graph, + idom map[cfg.Point]cfg.Point, + writes []structuredWrite, + sym cfg.SymbolID, + at cfg.Point, + baseType typ.Type, + synth func(ast.Expr, cfg.Point) typ.Type, +) typ.Type { + if graph == nil || sym == 0 || len(writes) == 0 { + return baseType + } + + currentVersion := graph.VisibleVersion(at, sym) + if currentVersion.ID == 0 { + return baseType + } + + current := baseType + for _, write := range writes { + if write.versionID != currentVersion.ID { + continue + } + if write.point == at || !cfganalysis.StrictlyDominates(idom, write.point, at) { + continue + } + + valueType := typ.Unknown + if write.source != nil && synth != nil { + if resolved := synth(write.source, write.point); resolved != nil { + valueType = resolved + } + } + current = applyStructuredWrite(current, write.segments, valueType) + } + + return current +} + +func applyStructuredWrite(baseType typ.Type, segments []constraint.Segment, valueType typ.Type) typ.Type { + if len(segments) == 0 { + if valueType == nil { + return baseType + } + return valueType + } + + seg := segments[0] + child := childTypeForStructuredSegment(baseType, seg) + updatedChild := applyStructuredWrite(child, segments[1:], valueType) + + switch seg.Kind { + case constraint.SegmentField, constraint.SegmentIndexString: + return overwriteStructuredField(baseType, seg.Name, updatedChild) + case constraint.SegmentIndexInt: + return overwriteStructuredIntIndex(baseType, updatedChild) + default: + return baseType + } +} + +func childTypeForStructuredSegment(baseType typ.Type, seg constraint.Segment) typ.Type { + if baseType == nil { + return nil + } + + switch t := baseType.(type) { + case *typ.Alias: + return childTypeForStructuredSegment(t.Target, seg) + case *typ.Record: + switch seg.Kind { + case constraint.SegmentField, constraint.SegmentIndexString: + if field := t.GetField(seg.Name); field != nil { + return field.Type + } + if t.HasMapComponent() && (typ.IsAny(t.MapKey) || t.MapKey.Kind() == kind.String) { + return t.MapValue + } + case constraint.SegmentIndexInt: + if t.HasMapComponent() && (typ.IsAny(t.MapKey) || t.MapKey.Kind() == kind.Integer || t.MapKey.Kind() == kind.Number) { + return t.MapValue + } + } + case *typ.Map: + switch seg.Kind { + case constraint.SegmentField, constraint.SegmentIndexString: + if typ.IsAny(t.Key) || t.Key.Kind() == kind.String { + return t.Value + } + case constraint.SegmentIndexInt: + if typ.IsAny(t.Key) || t.Key.Kind() == kind.Integer || t.Key.Kind() == kind.Number { + return t.Value + } + } + case *typ.Array: + if seg.Kind == constraint.SegmentIndexInt { + return t.Element + } + } + + return nil +} + +func overwriteStructuredField(baseType typ.Type, field string, fieldType typ.Type) typ.Type { + if field == "" || fieldType == nil { + return baseType + } + + switch t := baseType.(type) { + case *typ.Alias: + updated := overwriteStructuredField(t.Target, field, fieldType) + if updated == nil || typ.TypeEquals(updated, t.Target) { + return baseType + } + return typ.NewAlias(t.Name, updated) + case *typ.Map: + return typ.NewRecord(). + SetOpen(true). + MapComponent(t.Key, t.Value). + Field(field, fieldType). + Build() + default: + return typ.ExtendRecordWithField(baseType, field, fieldType) + } +} + +func overwriteStructuredIntIndex(baseType typ.Type, elemType typ.Type) typ.Type { + if elemType == nil { + return baseType + } + + switch t := baseType.(type) { + case *typ.Alias: + updated := overwriteStructuredIntIndex(t.Target, elemType) + if updated == nil || typ.TypeEquals(updated, t.Target) { + return baseType + } + return typ.NewAlias(t.Name, updated) + case *typ.Array: + return typ.NewArray(elemType) + case *typ.Map: + return typ.NewMap(t.Key, elemType) + case *typ.Record: + builder := typ.NewRecord() + if t.Open { + builder.SetOpen(true) + } + for _, f := range t.Fields { + if f.Optional { + if f.Readonly { + builder.OptReadonlyField(f.Name, f.Type) + } else { + builder.OptField(f.Name, f.Type) + } + } else if f.Readonly { + builder.ReadonlyField(f.Name, f.Type) + } else { + builder.Field(f.Name, f.Type) + } + } + if t.Metatable != nil { + builder.Metatable(t.Metatable) + } + builder.MapComponent(typ.Integer, elemType) + return builder.Build() + default: + return typ.NewMap(typ.Integer, elemType) + } +} diff --git a/compiler/check/flowbuild/cond/condition.go b/compiler/check/flowbuild/cond/condition.go index 005bda0b..bf353130 100644 --- a/compiler/check/flowbuild/cond/condition.go +++ b/compiler/check/flowbuild/cond/condition.go @@ -77,16 +77,16 @@ type BranchConditions struct { // - SymResolver: symbol-to-type resolution (declared and narrowed types) // - TypeKeyRes: type name to TypeKey mapping (for type(x) == "T" patterns) // - ConstResolver: constant value lookup (for const-folded conditions) -// - EffectBySym: function effect lookup (for predicate/terminating functions) +// - RefinementBySym: function refinement lookup (for predicate/terminating functions) type ConditionExtractor struct { - P cfg.Point // Current CFG point - SC *scope.State // Scope state at this point - Inputs *flow.Inputs // Flow inputs being built - Synth func(ast.Expr, cfg.Point) typ.Type // Expression type synthesis - SymResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool) // Symbol type resolution - TypeKeyRes func(string, *scope.State) (narrow.TypeKey, bool) // Type name resolution - ConstResolver func(string) *flow.ConstValue // Constant value lookup - EffectBySym constraint.EffectLookupBySym // Function effect lookup + P cfg.Point // Current CFG point + SC *scope.State // Scope state at this point + Inputs *flow.Inputs // Flow inputs being built + Synth func(ast.Expr, cfg.Point) typ.Type // Expression type synthesis + SymResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool) // Symbol type resolution + TypeKeyRes func(string, *scope.State) (narrow.TypeKey, bool) // Type name resolution + ConstResolver func(string) *flow.ConstValue // Constant value lookup + RefinementBySym constraint.RefinementLookupBySym // Function refinement lookup } // constraintsFromBranch extracts type constraints from branch info. @@ -579,11 +579,11 @@ func (ce *ConditionExtractor) calleeHasEffect(call *ast.FuncCallExpr, want func( if call == nil { return false } - // Try effect lookup by symbol. + // Try refinement lookup by symbol. if ident, ok := call.Func.(*ast.IdentExpr); ok { - if bindings := ce.bindings(); bindings != nil && ce.EffectBySym != nil { + if bindings := ce.bindings(); bindings != nil && ce.RefinementBySym != nil { if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { - if eff := ce.EffectBySym(sym); eff != nil { + if eff := ce.RefinementBySym(sym); eff != nil { if row, ok := eff.Row.(effect.Row); ok && want(row) { return true } diff --git a/compiler/check/flowbuild/cond/extract.go b/compiler/check/flowbuild/cond/extract.go index 6f0acee3..1003a07b 100644 --- a/compiler/check/flowbuild/cond/extract.go +++ b/compiler/check/flowbuild/cond/extract.go @@ -50,11 +50,11 @@ func ExtractEdgeConstraints(fc *core.FlowContext, inputs *flow.Inputs) { ce := &ConditionExtractor{ P: p, SC: fc.Scopes[p], Inputs: inputs, - Synth: fc.Derived.Synth, - SymResolver: fc.Derived.SymResolver, - TypeKeyRes: fc.Derived.TypeKeyRes, - ConstResolver: constResolver, - EffectBySym: fc.Derived.EffectBySym, + Synth: fc.Derived.Synth, + SymResolver: fc.Derived.SymResolver, + TypeKeyRes: fc.Derived.TypeKeyRes, + ConstResolver: constResolver, + RefinementBySym: fc.Derived.RefinementBySym, } constraints := ce.ConstraintsFromBranch(info) @@ -363,7 +363,7 @@ func ExtractCallOnReturnConstraints( } for _, p := range fc.Graph.RPO() { - if !PointHasTerminatingCallSite(fc.Graph, p, fc.Derived.Synth, fc.Derived.SymResolver, fc.Derived.EffectBySym, fc.ModuleBindings) { + if !PointHasTerminatingCallSite(fc.Graph, p, fc.Derived.Synth, fc.Derived.SymResolver, fc.Derived.RefinementBySym, fc.ModuleBindings) { continue } for _, succ := range fc.Graph.Successors(p) { @@ -382,7 +382,7 @@ func ExtractCallOnReturnConstraints( sc := fc.Scopes[p] constResolver := predicate.BuildConstResolver(inputs, p) - cond := ConstraintsFromCallOnReturn(info, p, sc, inputs, fc.Derived.Synth, fc.Derived.TypeKeyRes, fc.Derived.EffectBySym, constResolver, fc.Derived.SymResolver, fc.Graph, fc.ModuleBindings) + cond := ConstraintsFromCallOnReturn(info, p, sc, inputs, fc.Derived.Synth, fc.Derived.TypeKeyRes, fc.Derived.RefinementBySym, constResolver, fc.Derived.SymResolver, fc.Graph, fc.ModuleBindings) if !cond.HasConstraints() { return } @@ -399,7 +399,7 @@ func ExtractCallOnReturnConstraints( fc.Graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { sc := fc.Scopes[p] constResolver := predicate.BuildConstResolver(inputs, p) - cond := ConstraintsFromAssignOnReturn(info, p, sc, inputs, fc.Derived.Synth, fc.Derived.TypeKeyRes, fc.Derived.EffectBySym, constResolver, fc.Derived.SymResolver, fc.Graph, fc.ModuleBindings) + cond := ConstraintsFromAssignOnReturn(info, p, sc, inputs, fc.Derived.Synth, fc.Derived.TypeKeyRes, fc.Derived.RefinementBySym, constResolver, fc.Derived.SymResolver, fc.Graph, fc.ModuleBindings) if !cond.HasConstraints() { return } @@ -424,7 +424,7 @@ func ConstraintsFromCallOnReturn( inputs *flow.Inputs, synthFn func(ast.Expr, cfg.Point) typ.Type, typeKeyResolver func(string, *scope.State) (narrow.TypeKey, bool), - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, constResolver func(string) *flow.ConstValue, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), graph *cfg.Graph, @@ -452,7 +452,7 @@ func ConstraintsFromCallOnReturn( } } - eff := ExtractFunctionEffect(info, p, synthFn, effectLookupSym, symResolver, graph, moduleBindings) + eff := ExtractFunctionRefinement(info, p, synthFn, refinementLookupSym, symResolver, graph, moduleBindings) if eff == nil || !eff.OnReturn.HasConstraints() { return constraint.Condition{} } @@ -464,11 +464,11 @@ func ConstraintsFromCallOnReturn( ce := &ConditionExtractor{ P: p, SC: sc, Inputs: inputs, - Synth: synthFn, - SymResolver: symResolver, - TypeKeyRes: typeKeyResolver, - ConstResolver: constResolver, - EffectBySym: effectLookupSym, + Synth: synthFn, + SymResolver: symResolver, + TypeKeyRes: typeKeyResolver, + ConstResolver: constResolver, + RefinementBySym: refinementLookupSym, } // OnReturn summarizes all normal-return paths. At call sites we may only @@ -811,17 +811,17 @@ func callConstraintFallbackFromArgs( return cond.MustConstraints(), true } -// ExtractFunctionEffect extracts the function effect from a call using symbol-based lookup. -// All functions in CFG have symbols, so this is the canonical effect resolution path. -func ExtractFunctionEffect( +// ExtractFunctionRefinement extracts the function refinement from a call using symbol-based lookup. +// All functions in CFG have symbols, so this is the canonical refinement resolution path. +func ExtractFunctionRefinement( info *cfg.CallInfo, p cfg.Point, synthFn func(ast.Expr, cfg.Point) typ.Type, - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), graph *cfg.Graph, moduleBindings *bind.BindingTable, -) *constraint.FunctionEffect { +) *constraint.FunctionRefinement { var bindings *bind.BindingTable if graph != nil { bindings = graph.Bindings() @@ -832,7 +832,7 @@ func ExtractFunctionEffect( graph, bindings, moduleBindings, - effectLookupSym, + refinementLookupSym, synthFn, symResolver, checkeffects.EffectFromType, @@ -840,13 +840,13 @@ func ExtractFunctionEffect( } // CallTerminates checks if a call is to a function that never returns. -// Uses symbol-based effect lookup - all functions have symbols. +// Uses symbol-based refinement lookup; all functions have symbols. func CallTerminates( info *cfg.CallInfo, p cfg.Point, synthFn func(ast.Expr, cfg.Point) typ.Type, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, graph *cfg.Graph, moduleBindings *bind.BindingTable, ) bool { @@ -863,7 +863,7 @@ func CallTerminates( graph, bindings, moduleBindings, - effectLookupSym, + refinementLookupSym, synthFn, symResolver, checkeffects.EffectFromType, @@ -880,14 +880,14 @@ func PointHasTerminatingCallSite( p cfg.Point, synthFn func(ast.Expr, cfg.Point) typ.Type, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, moduleBindings *bind.BindingTable, ) bool { if graph == nil { return false } for _, callInfo := range graph.CallSitesAt(p) { - if CallTerminates(callInfo, p, synthFn, symResolver, effectLookupSym, graph, moduleBindings) { + if CallTerminates(callInfo, p, synthFn, symResolver, refinementLookupSym, graph, moduleBindings) { return true } } @@ -902,7 +902,7 @@ func ConstraintsFromAssignOnReturn( inputs *flow.Inputs, synthFn func(ast.Expr, cfg.Point) typ.Type, typeKeyResolver func(string, *scope.State) (narrow.TypeKey, bool), - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, constResolver func(string) *flow.ConstValue, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), graph *cfg.Graph, @@ -913,7 +913,7 @@ func ConstraintsFromAssignOnReturn( } var combined constraint.Condition info.EachSourceCall(func(_ int, callInfo *cfg.CallInfo) { - if cond := ConstraintsFromCallOnReturn(callInfo, p, sc, inputs, synthFn, typeKeyResolver, effectLookupSym, constResolver, symResolver, graph, moduleBindings); cond.HasConstraints() { + if cond := ConstraintsFromCallOnReturn(callInfo, p, sc, inputs, synthFn, typeKeyResolver, refinementLookupSym, constResolver, symResolver, graph, moduleBindings); cond.HasConstraints() { if !combined.HasConstraints() { combined = cond } else { @@ -934,7 +934,7 @@ func ExtractPredicateLinkFromCallInfo( inputs *flow.Inputs, typeKeyResolver func(string, *scope.State) (narrow.TypeKey, bool), synthFn func(ast.Expr, cfg.Point) typ.Type, - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), graph *cfg.Graph, moduleBindings *bind.BindingTable, @@ -977,7 +977,7 @@ func ExtractPredicateLinkFromCallInfo( return nil } - eff := ExtractFunctionEffect(callInfo, p, synthFn, effectLookupSym, symResolver, graph, moduleBindings) + eff := ExtractFunctionRefinement(callInfo, p, synthFn, refinementLookupSym, symResolver, graph, moduleBindings) if eff == nil || !eff.HasPredicateSemantics() { return nil } @@ -1007,12 +1007,12 @@ func ComputeDeadPoints( graph *cfg.Graph, synthFn func(ast.Expr, cfg.Point) typ.Type, symResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool), - effectLookupSym constraint.EffectLookupBySym, + refinementLookupSym constraint.RefinementLookupBySym, moduleBindings *bind.BindingTable, ) map[cfg.Point]bool { dead := make(map[cfg.Point]bool) for _, p := range graph.RPO() { - if PointHasTerminatingCallSite(graph, p, synthFn, symResolver, effectLookupSym, moduleBindings) { + if PointHasTerminatingCallSite(graph, p, synthFn, symResolver, refinementLookupSym, moduleBindings) { for _, succ := range graph.Successors(p) { preds := graph.Predecessors(succ) if len(preds) == 1 { diff --git a/compiler/check/flowbuild/cond/extract_test.go b/compiler/check/flowbuild/cond/extract_test.go index a81fa961..b18b71a9 100644 --- a/compiler/check/flowbuild/cond/extract_test.go +++ b/compiler/check/flowbuild/cond/extract_test.go @@ -87,12 +87,12 @@ func TestConstraintsFromCallOnReturn_OnlyAppliesMustConstraints(t *testing.T) { Args: []ast.Expr{&ast.IdentExpr{Value: "x"}}, } - effectLookup := func(id typecfg.SymbolID) *constraint.FunctionEffect { + refinementLookup := func(id typecfg.SymbolID) *constraint.FunctionRefinement { if id != sym { return nil } p0 := constraint.NewPlaceholder(0) - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ OnReturn: constraint.FromDisjuncts([][]constraint.Constraint{ {constraint.Truthy{Path: p0}}, {constraint.Falsy{Path: p0}}, @@ -107,7 +107,7 @@ func TestConstraintsFromCallOnReturn_OnlyAppliesMustConstraints(t *testing.T) { nil, nil, nil, - effectLookup, + refinementLookup, nil, nil, nil, @@ -237,14 +237,14 @@ func TestCallTerminates_UsesCanonicalCandidatesWhenRawSymbolMissing(t *testing.T // from call expression/bindings and still detect termination. callInfo.CalleeSymbol = 0 - effectLookup := func(sym typecfg.SymbolID) *constraint.FunctionEffect { + refinementLookup := func(sym typecfg.SymbolID) *constraint.FunctionRefinement { if sym == errorSym { - return &constraint.FunctionEffect{Terminates: true} + return &constraint.FunctionRefinement{Terminates: true} } return nil } - if !CallTerminates(callInfo, point, nil, nil, effectLookup, graph, nil) { + if !CallTerminates(callInfo, point, nil, nil, refinementLookup, graph, nil) { t.Fatal("expected terminating call via canonical callee candidate") } } @@ -291,14 +291,14 @@ func TestCallTerminates_UsesModuleBindingNameFallback(t *testing.T) { callInfo.Callee = &ast.IdentExpr{Value: "error_alias"} callInfo.CalleeName = "error_alias" - effectLookup := func(sym typecfg.SymbolID) *constraint.FunctionEffect { + refinementLookup := func(sym typecfg.SymbolID) *constraint.FunctionRefinement { if sym == errorSym { - return &constraint.FunctionEffect{Terminates: true} + return &constraint.FunctionRefinement{Terminates: true} } return nil } - if !CallTerminates(callInfo, point, nil, nil, effectLookup, graph, moduleBindings) { + if !CallTerminates(callInfo, point, nil, nil, refinementLookup, graph, moduleBindings) { t.Fatal("expected terminating call via module-binding callee candidate") } } @@ -356,17 +356,17 @@ func TestPointHasTerminatingCallSite_AssignSourceCall(t *testing.T) { t.Fatal("expected x assignment with resolvable error() call symbol") } - effectLookup := func(sym typecfg.SymbolID) *constraint.FunctionEffect { + refinementLookup := func(sym typecfg.SymbolID) *constraint.FunctionRefinement { if sym == errorSym { - return &constraint.FunctionEffect{Terminates: true} + return &constraint.FunctionRefinement{Terminates: true} } return nil } - if !PointHasTerminatingCallSite(graph, xPoint, nil, nil, effectLookup, nil) { + if !PointHasTerminatingCallSite(graph, xPoint, nil, nil, refinementLookup, nil) { t.Fatal("expected terminating callsite at x assignment point") } - if yPoint != 0 && PointHasTerminatingCallSite(graph, yPoint, nil, nil, effectLookup, nil) { + if yPoint != 0 && PointHasTerminatingCallSite(graph, yPoint, nil, nil, refinementLookup, nil) { t.Fatal("did not expect terminating callsite at y assignment point") } } @@ -402,14 +402,14 @@ func TestComputeDeadPoints_AssignSourceCallTerminates(t *testing.T) { t.Fatal("expected x assignment with resolvable error() call symbol") } - effectLookup := func(sym typecfg.SymbolID) *constraint.FunctionEffect { + refinementLookup := func(sym typecfg.SymbolID) *constraint.FunctionRefinement { if sym == errorSym { - return &constraint.FunctionEffect{Terminates: true} + return &constraint.FunctionRefinement{Terminates: true} } return nil } - dead := ComputeDeadPoints(graph, nil, nil, effectLookup, nil) + dead := ComputeDeadPoints(graph, nil, nil, refinementLookup, nil) if len(dead) == 0 { t.Fatal("expected at least one dead point from terminating assignment call") } diff --git a/compiler/check/flowbuild/core/context.go b/compiler/check/flowbuild/core/context.go index 3c19d168..1edb7c82 100644 --- a/compiler/check/flowbuild/core/context.go +++ b/compiler/check/flowbuild/core/context.go @@ -65,10 +65,10 @@ type FlowServices interface { // Derived contains computed helpers populated during flow extraction. // These are intentionally separated to keep FlowContext immutable. type Derived struct { - Synth func(ast.Expr, cfg.Point) typ.Type - SymResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool) - TypeKeyRes func(string, *scope.State) (narrow.TypeKey, bool) - EffectBySym constraint.EffectLookupBySym + Synth func(ast.Expr, cfg.Point) typ.Type + SymResolver func(cfg.Point, cfg.SymbolID) (typ.Type, bool) + TypeKeyRes func(string, *scope.State) (narrow.TypeKey, bool) + RefinementBySym constraint.RefinementLookupBySym } // FlowServicesFuncs adapts function fields to FlowServices. diff --git a/compiler/check/flowbuild/extract_constraints_test.go b/compiler/check/flowbuild/extract_constraints_test.go index cc37a29c..2590bfdd 100644 --- a/compiler/check/flowbuild/extract_constraints_test.go +++ b/compiler/check/flowbuild/extract_constraints_test.go @@ -344,22 +344,22 @@ func TestFindBranchEdges_WithEdges(t *testing.T) { } } -func TestExtractFunctionEffect_EmptyInfo(t *testing.T) { +func TestExtractFunctionRefinement_EmptyInfo(t *testing.T) { info := &cfg.CallInfo{} - result := cond.ExtractFunctionEffect(info, 0, nil, nil, nil, nil, nil) + result := cond.ExtractFunctionRefinement(info, 0, nil, nil, nil, nil, nil) if result != nil { t.Error("expected nil for empty info") } } -func TestExtractFunctionEffect_WithCallee(t *testing.T) { +func TestExtractFunctionRefinement_WithCallee(t *testing.T) { info := &cfg.CallInfo{ Callee: &ast.IdentExpr{Value: "fn"}, } synth := func(expr ast.Expr, p cfg.Point) typ.Type { return &typ.Function{} } - result := cond.ExtractFunctionEffect(info, 0, synth, nil, nil, nil, nil) + result := cond.ExtractFunctionRefinement(info, 0, synth, nil, nil, nil, nil) // No refinement, so nil expected if result != nil { t.Error("expected nil for function without refinement") diff --git a/compiler/check/flowbuild/extract_core_test.go b/compiler/check/flowbuild/extract_core_test.go index 5fdcc40e..f5ea9307 100644 --- a/compiler/check/flowbuild/extract_core_test.go +++ b/compiler/check/flowbuild/extract_core_test.go @@ -182,8 +182,8 @@ func TestBuildContextTypeKeyResolver_UnknownType(t *testing.T) { } } -func TestBuildEffectLookup_NilCtx(t *testing.T) { - symLookup := resolve.BuildEffectLookup(nil) +func TestBuildRefinementLookup_NilCtx(t *testing.T) { + symLookup := resolve.BuildRefinementLookup(nil) if symLookup != nil { t.Error("expected nil symLookup for nil context") } diff --git a/compiler/check/flowbuild/extract_test.go b/compiler/check/flowbuild/extract_test.go index 3944b728..2775e8e5 100644 --- a/compiler/check/flowbuild/extract_test.go +++ b/compiler/check/flowbuild/extract_test.go @@ -83,7 +83,7 @@ local y = x } // Create assert_not_nil function with refinement - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) @@ -205,7 +205,7 @@ local y = x.field } // Create assert_is_nil function with refinement - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) @@ -215,7 +215,7 @@ local y = x.field Build() // Create assert_not_nil function with refinement - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/flowbuild/resolve/resolve.go b/compiler/check/flowbuild/resolve/resolve.go index cd815105..fb08551a 100644 --- a/compiler/check/flowbuild/resolve/resolve.go +++ b/compiler/check/flowbuild/resolve/resolve.go @@ -347,15 +347,15 @@ func BuildContextTypeKeyResolver(ctx api.BaseEnv) func(string, *scope.State) (na } } -// BuildEffectLookup creates the effect lookup function from Env. +// BuildRefinementLookup creates the refinement lookup function from Env. // Returns symbol-based lookup only - all functions have symbols. -func BuildEffectLookup(ctx api.BaseEnv) constraint.EffectLookupBySym { - if ctx == nil || ctx.Effects() == nil { +func BuildRefinementLookup(ctx api.BaseEnv) constraint.RefinementLookupBySym { + if ctx == nil || ctx.Refinements() == nil { return nil } - effects := ctx.Effects() - return func(sym cfg.SymbolID) *constraint.FunctionEffect { - return effects.LookupBySym(sym) + refinements := ctx.Refinements() + return func(sym cfg.SymbolID) *constraint.FunctionRefinement { + return refinements.LookupBySym(sym) } } diff --git a/compiler/check/flowbuild/resolve/resolve_test.go b/compiler/check/flowbuild/resolve/resolve_test.go index d22d904f..ef0ea529 100644 --- a/compiler/check/flowbuild/resolve/resolve_test.go +++ b/compiler/check/flowbuild/resolve/resolve_test.go @@ -518,8 +518,8 @@ func TestBuildContextTypeKeyResolver_UnknownType(t *testing.T) { } } -func TestBuildEffectLookup_NilContext(t *testing.T) { - result := resolve.BuildEffectLookup(nil) +func TestBuildRefinementLookup_NilContext(t *testing.T) { + result := resolve.BuildRefinementLookup(nil) if result != nil { t.Error("expected nil for nil context") } diff --git a/compiler/check/flowbuild/run.go b/compiler/check/flowbuild/run.go index 7af1b393..936470bc 100644 --- a/compiler/check/flowbuild/run.go +++ b/compiler/check/flowbuild/run.go @@ -90,9 +90,9 @@ func Run(fc *fbcore.FlowContext) *flow.Inputs { // Compute derived resolvers and store in a separate derived bundle. derived := &fbcore.Derived{ - SymResolver: resolve.BuildInputSymbolResolver(fc.CheckCtx, inputs), - TypeKeyRes: resolve.BuildContextTypeKeyResolver(fc.CheckCtx), - EffectBySym: resolve.BuildEffectLookup(fc.CheckCtx), + SymResolver: resolve.BuildInputSymbolResolver(fc.CheckCtx, inputs), + TypeKeyRes: resolve.BuildContextTypeKeyResolver(fc.CheckCtx), + RefinementBySym: resolve.BuildRefinementLookup(fc.CheckCtx), } if fc.API != nil { derived.Synth = fc.API.TypeOf @@ -130,7 +130,7 @@ func Run(fc *fbcore.FlowContext) *flow.Inputs { if fc.Derived == nil { continue } - if !cond.PointHasTerminatingCallSite(fc.Graph, p, fc.Derived.Synth, fc.Derived.SymResolver, fc.Derived.EffectBySym, fc.ModuleBindings) { + if !cond.PointHasTerminatingCallSite(fc.Graph, p, fc.Derived.Synth, fc.Derived.SymResolver, fc.Derived.RefinementBySym, fc.ModuleBindings) { continue } for _, succ := range fc.Graph.Successors(p) { diff --git a/compiler/check/hooks/assign_check.go b/compiler/check/hooks/assign_check.go index 38872e89..046b5c8b 100644 --- a/compiler/check/hooks/assign_check.go +++ b/compiler/check/hooks/assign_check.go @@ -163,6 +163,14 @@ func CheckAssignments(graph *cfg.Graph, scopes map[cfg.Point]*scope.State, narro if valueType == nil { return } + if flowQ != nil { + sourcePath := extractSourcePath(source, graph, p) + if !sourcePath.IsEmpty() { + if narrowed := flowQ.NarrowedTypeAt(p, sourcePath); !typ.IsAbsentOrUnknown(narrowed) { + valueType = preferPreciseSourcePathType(valueType, narrowed) + } + } + } if table, ok := source.(*ast.TableExpr); ok && !sourceUsesTarget { if result := tableCheck(table, declaredType, narrowSynth, p); result.Handled { @@ -245,6 +253,16 @@ func CheckAssignments(graph *cfg.Graph, scopes map[cfg.Point]*scope.State, narro return diags } +func preferPreciseSourcePathType(current, narrowed typ.Type) typ.Type { + if typ.IsAbsentOrUnknown(current) { + return narrowed + } + if subtype.IsSubtype(narrowed, current) { + return narrowed + } + return current +} + func extractSourcePath(source ast.Expr, graph *cfg.Graph, _ cfg.Point) constraint.Path { if graph == nil { return constraint.Path{} diff --git a/compiler/check/hooks/field_check.go b/compiler/check/hooks/field_check.go index 3bdb4b52..68be760d 100644 --- a/compiler/check/hooks/field_check.go +++ b/compiler/check/hooks/field_check.go @@ -601,26 +601,6 @@ func isStringKeyExpr(key ast.Expr) bool { return ok } -func isLiteralStringKeyType(t typ.Type) bool { - switch v := t.(type) { - case *typ.Literal: - return v.Base == kind.String - case *typ.Union: - if len(v.Members) == 0 { - return false - } - for _, m := range v.Members { - lit, ok := m.(*typ.Literal) - if !ok || lit.Base != kind.String { - return false - } - } - return true - default: - return false - } -} - func checkIndexAccess(e *ast.AttrGetExpr, p cfg.Point, narrowView api.BaseSynth, resolver fieldResolverImpl, objType typ.Type, sourceName string) []diag.Diagnostic { if e == nil || narrowView == nil { return nil @@ -636,12 +616,12 @@ func checkIndexAccess(e *ast.AttrGetExpr, p cfg.Point, narrowView api.BaseSynth, keyType = typ.Unknown } - if rec, ok := unwrap.Alias(objType).(*typ.Record); ok && !rec.HasMapComponent() && !rec.Open { + if rec := unwrap.Record(objType); rec != nil && !rec.HasMapComponent() && !rec.Open { // Closed records support dynamic string indexing (Lua table semantics). // Non-string keys remain invalid. keyKind := keyType.Kind() - allowsDynamicString := keyKind == kind.String || keyKind.IsPlaceholder() - if !allowsDynamicString && !isLiteralStringKeyType(keyType) { + allowsStringIndex := keyKind.IsPlaceholder() || subtype.IsSubtype(keyType, typ.String) + if !allowsStringIndex { return []diag.Diagnostic{indexError(objType, e, sourceName)} } } @@ -720,6 +700,10 @@ func memberTypeName(t typ.Type) string { return "{" + f.Name + ": ..., ...}" } return "{}" + case *typ.Recursive: + if v.Name != "" { + return v.Name + } case *typ.Instantiated: if v.Generic != nil && v.Generic.Name != "" { return v.Generic.Name + "<...>" @@ -750,6 +734,8 @@ func hasField(t typ.Type, field string) bool { return false case *typ.Alias: return hasField(v.Target, field) + case *typ.Recursive: + return v.Body != nil && v.Body != v && hasField(v.Body, field) case *typ.Optional: return hasField(v.Inner, field) } diff --git a/compiler/check/hooks/iface_test.go b/compiler/check/hooks/iface_test.go index 17565dcf..6efa4ac6 100644 --- a/compiler/check/hooks/iface_test.go +++ b/compiler/check/hooks/iface_test.go @@ -86,6 +86,7 @@ type flowQueryImpl struct{} func (f *flowQueryImpl) EffectiveTypeAt(p cfg.Point, sym cfg.SymbolID) flow.TypedValue { return flow.TypedValue{} } +func (f *flowQueryImpl) NarrowedTypeAt(p cfg.Point, path constraint.Path) typ.Type { return nil } func (f *flowQueryImpl) ExcludesTypeAt(p cfg.Point, path constraint.Path, declared typ.Type) bool { return false } diff --git a/compiler/check/hooks/table_check.go b/compiler/check/hooks/table_check.go index 88cc51f2..75a13a3e 100644 --- a/compiler/check/hooks/table_check.go +++ b/compiler/check/hooks/table_check.go @@ -77,14 +77,10 @@ func extractTableFields(table *ast.TableExpr, expected typ.Type, synth api.Synth fields := make([]ops.FieldDef, 0, len(table.Fields)) var arrayElems []typ.Type - var mapValueType typ.Type - if m, ok := unwrap.Alias(expected).(*typ.Map); ok { - mapValueType = m.Value - } - for _, field := range table.Fields { if field.Key == nil { - elemType := synth.SynthWithExpected(field.Value, p, mapValueType) + elemExpected := ops.ExpectedTableElementType(expected, len(arrayElems)) + elemType := synth.SynthWithExpected(field.Value, p, elemExpected) if elemType == nil { elemType = typ.Unknown } @@ -103,7 +99,8 @@ func extractTableFields(table *ast.TableExpr, expected typ.Type, synth api.Synth name = k.Value case *ast.NumberExpr: _ = k - elemType := synth.SynthWithExpected(field.Value, p, mapValueType) + elemExpected := ops.ExpectedTableElementType(expected, len(arrayElems)) + elemType := synth.SynthWithExpected(field.Value, p, elemExpected) arrayElems = append(arrayElems, elemType) continue default: diff --git a/compiler/check/infer/nested/processor.go b/compiler/check/infer/nested/processor.go index 162fe75b..e4c32d76 100644 --- a/compiler/check/infer/nested/processor.go +++ b/compiler/check/infer/nested/processor.go @@ -305,8 +305,17 @@ func (p *Processor) resolveSelfTypeForMethod( ) typ.Type { var selfType typ.Type + // Prefer the explicit type-space binding for `T` in `function T:m(...)`. + // The receiver value `T` is the class table; the instance/self contract + // lives in the type namespace binding with the same name. + if info != nil && info.FuncDef != nil && info.FuncDef.ReceiverName != "" && info.DefScope != nil { + if named, ok := info.DefScope.LookupType(info.FuncDef.ReceiverName); ok && named != nil { + selfType = named + } + } + // First try root result facts. - if rootResult != nil && rootResult.Facts != nil { + if selfType == nil && rootResult != nil && rootResult.Facts != nil { tv := rootResult.Facts.EffectiveTypeAt(info.NF.Point, sym) if tv.Type != nil && tv.State == flow.StateResolved { selfType = tv.Type diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index dfe032c6..06340f42 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -104,7 +104,7 @@ func New(cfg Config) *Inferencer { type RunContext struct { Ctx *db.QueryContext ParentFacts flow.TypeFacts - EffectLookup constraint.EffectLookupBySym + EffectLookup constraint.RefinementLookupBySym } // collectLocalFunctions gathers local function definitions from assignments and FuncDef nodes. diff --git a/compiler/check/modules/doc.go b/compiler/check/modules/doc.go index 89f71cff..69849ae9 100644 --- a/compiler/check/modules/doc.go +++ b/compiler/check/modules/doc.go @@ -17,7 +17,7 @@ // # Type Enrichment // // Exported types are enriched with: -// - Function effect annotations (terminates, type guards) +// - Function refinement annotations (terminates, type guards) // - Refined field types from flow analysis // - Method signatures from implementation analysis // diff --git a/compiler/check/modules/export.go b/compiler/check/modules/export.go index 76ee3ac6..4494a7ec 100644 --- a/compiler/check/modules/export.go +++ b/compiler/check/modules/export.go @@ -10,8 +10,8 @@ import ( ) // ExportType computes the module's exported type from return statements. -// Pass effectsBySym to enrich exported functions with effect summaries. -func ExportType(result *api.FuncResult, effectsBySym map[cfg.SymbolID]*constraint.FunctionEffect) typ.Type { +// Pass refinementsBySym to enrich exported functions with effect summaries. +func ExportType(result *api.FuncResult, refinementsBySym map[cfg.SymbolID]*constraint.FunctionRefinement) typ.Type { if result == nil || result.Graph == nil || result.NarrowSynth == nil { return typ.Nil } @@ -59,8 +59,8 @@ func ExportType(result *api.FuncResult, effectsBySym map[cfg.SymbolID]*constrain } }) - if export != nil && len(effectsBySym) > 0 && result.Graph != nil { - export = effects.EnrichExportWithEffects(export, exportRootName, effectsBySym, result.Graph) + if export != nil && len(refinementsBySym) > 0 && result.Graph != nil { + export = effects.EnrichExportWithEffects(export, exportRootName, refinementsBySym, result.Graph) } if export == nil { @@ -93,18 +93,18 @@ func ExportTypes(result *api.FuncResult) map[string]typ.Type { return types } -// CopyEffectsForExport returns a defensive copy of effects for manifest export. -func CopyEffectsForExport(effectsBySym map[cfg.SymbolID]*constraint.FunctionEffect) map[cfg.SymbolID]*constraint.FunctionEffect { - if len(effectsBySym) == 0 { +// CopyRefinementsForExport returns a defensive copy of refinements for manifest export. +func CopyRefinementsForExport(refinementsBySym map[cfg.SymbolID]*constraint.FunctionRefinement) map[cfg.SymbolID]*constraint.FunctionRefinement { + if len(refinementsBySym) == 0 { return nil } - effects := make(map[cfg.SymbolID]*constraint.FunctionEffect, len(effectsBySym)) - for sym, eff := range effectsBySym { - if eff != nil { - effects[sym] = eff + refinements := make(map[cfg.SymbolID]*constraint.FunctionRefinement, len(refinementsBySym)) + for sym, refinement := range refinementsBySym { + if refinement != nil { + refinements[sym] = refinement } } - return effects + return refinements } // ResolveExportTypeNames resolves type names to concrete types for a given scope. diff --git a/compiler/check/modules/modules.go b/compiler/check/modules/modules.go index 0b2db65a..875db293 100644 --- a/compiler/check/modules/modules.go +++ b/compiler/check/modules/modules.go @@ -40,7 +40,7 @@ import ( // 3. Defines any exported types // 4. Extracts function summaries for cross-module analysis // 5. Registers the manifest with the database -func Connect(database *db.DB, name string, exportType typ.Type, exportTypes map[string]typ.Type, graph *cfg.Graph, effectsBySym map[cfg.SymbolID]*constraint.FunctionEffect) *io.Manifest { +func Connect(database *db.DB, name string, exportType typ.Type, exportTypes map[string]typ.Type, graph *cfg.Graph, refinementsBySym map[cfg.SymbolID]*constraint.FunctionRefinement) *io.Manifest { manifest := io.NewManifest(name) manifest.SetExport(exportType) @@ -50,7 +50,7 @@ func Connect(database *db.DB, name string, exportType typ.Type, exportTypes map[ } } - ExportFunctionSummaries(manifest, exportType, graph, effectsBySym) + ExportFunctionSummaries(manifest, exportType, graph, refinementsBySym) database.Connect(name, manifest) @@ -65,8 +65,8 @@ func Connect(database *db.DB, name string, exportType typ.Type, exportTypes map[ // // OnReturn constraints encode assert-style narrowing (e.g. assert.not_nil), // enabling callers to narrow types based on imported module behavior. -func ExportFunctionSummaries(manifest *io.Manifest, exportType typ.Type, graph *cfg.Graph, effectsBySym map[cfg.SymbolID]*constraint.FunctionEffect) { - if graph == nil || len(effectsBySym) == 0 { +func ExportFunctionSummaries(manifest *io.Manifest, exportType typ.Type, graph *cfg.Graph, refinementsBySym map[cfg.SymbolID]*constraint.FunctionRefinement) { + if graph == nil || len(refinementsBySym) == 0 { return } @@ -75,12 +75,12 @@ func ExportFunctionSummaries(manifest *io.Manifest, exportType typ.Type, graph * return } - if len(effectsBySym) == 0 { + if len(refinementsBySym) == 0 { return } - for _, sym := range cfg.SortedSymbolIDs(effectsBySym) { - eff := effectsBySym[sym] - if eff == nil || !eff.OnReturn.HasConstraints() { + for _, sym := range cfg.SortedSymbolIDs(refinementsBySym) { + refinement := refinementsBySym[sym] + if refinement == nil || !refinement.OnReturn.HasConstraints() { continue } @@ -108,8 +108,8 @@ func ExportFunctionSummaries(manifest *io.Manifest, exportType typ.Type, graph * params[i] = p.Type } ioSummary := io.NewSummary(params, fn.Returns) - ioSummary.Ensures = eff.OnReturn - if row, ok := eff.Row.(effect.Row); ok { + ioSummary.Ensures = refinement.OnReturn + if row, ok := refinement.Row.(effect.Row); ok { ioSummary.Effects = row } manifest.DefineSummary(fieldName, ioSummary) diff --git a/compiler/check/modules/modules_test.go b/compiler/check/modules/modules_test.go index 7f93d3c8..194a4247 100644 --- a/compiler/check/modules/modules_test.go +++ b/compiler/check/modules/modules_test.go @@ -51,7 +51,7 @@ func TestExportFunctionSummaries_NilGraph(t *testing.T) { func TestExportFunctionSummaries_EmptyEffects(t *testing.T) { manifest := io.NewManifest("test") - ExportFunctionSummaries(manifest, typ.NewRecord().Build(), nil, make(map[cfg.SymbolID]*constraint.FunctionEffect)) + ExportFunctionSummaries(manifest, typ.NewRecord().Build(), nil, make(map[cfg.SymbolID]*constraint.FunctionRefinement)) } func TestExportFunctionSummaries_NonRecordExportType(t *testing.T) { diff --git a/compiler/check/nested/enrich.go b/compiler/check/nested/enrich.go index ed477bed..f6736612 100644 --- a/compiler/check/nested/enrich.go +++ b/compiler/check/nested/enrich.go @@ -16,7 +16,7 @@ import ( // This file provides type enrichment utilities for self-type resolution. // // When a method is defined in a table literal, the function's literal signature -// (with inferred effects and return types) may be more precise than the initially +// (with inferred refinements and return types) may be more precise than the initially // synthesized type. These utilities replace placeholder types with literal sigs. // EnrichTableTypeWithFuncTypes replaces method function types in a record diff --git a/compiler/check/overlaymut/collect.go b/compiler/check/overlaymut/collect.go new file mode 100644 index 00000000..0e2d2b03 --- /dev/null +++ b/compiler/check/overlaymut/collect.go @@ -0,0 +1,159 @@ +package overlaymut + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" + "github.com/wippyai/go-lua/types/typ" +) + +// CollectFieldAssignments scans the graph for field assignments and groups them by base symbol. +// Returns a map: symbolID -> map[fieldName]typ.Type representing fields assigned to each symbol. +// The synth function is used to synthesize field value types. +// If filterSyms is non-nil, only symbols in the filter are collected. +func CollectFieldAssignments( + graph *cfg.Graph, + synth func(ast.Expr, cfg.Point) typ.Type, + filterSyms map[cfg.SymbolID]bool, +) map[cfg.SymbolID]map[string]typ.Type { + result := make(map[cfg.SymbolID]map[string]typ.Type) + if graph == nil { + return result + } + + graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + if info == nil { + return + } + sources := info.Sources + for i, target := range info.Targets { + var source ast.Expr + if i < len(sources) { + source = sources[i] + } + var sym cfg.SymbolID + var fieldName string + + switch target.Kind { + case cfg.TargetField: + if target.BaseSymbol != 0 && len(target.FieldPath) == 1 { + sym = target.BaseSymbol + fieldName = target.FieldPath[0] + } + case cfg.TargetIndex: + if target.BaseSymbol != 0 && target.Key != nil { + if strKey, ok := target.Key.(*ast.StringExpr); ok && strKey.Value != "" { + sym = target.BaseSymbol + fieldName = strKey.Value + } + } + } + + if sym == 0 || fieldName == "" { + continue + } + if filterSyms != nil && !filterSyms[sym] { + continue + } + + var fieldType typ.Type + if source != nil && synth != nil { + fieldType = synth(source, p) + } + if fieldType == nil { + fieldType = typ.Unknown + } + + if result[sym] == nil { + result[sym] = make(map[string]typ.Type) + } + if existing := result[sym][fieldName]; existing != nil { + result[sym][fieldName] = typ.JoinPreferNonSoft(existing, fieldType) + } else { + result[sym][fieldName] = fieldType + } + } + }) + + return result +} + +// CollectIndexerAssignments scans the graph for dynamic index assignments (t[k] = v where k is non-const). +// Returns a map: symbolID -> []IndexerInfo representing index assignments to each symbol. +func CollectIndexerAssignments( + graph *cfg.Graph, + synth func(ast.Expr, cfg.Point) typ.Type, + bindings *bind.BindingTable, + filterSyms map[cfg.SymbolID]bool, +) map[cfg.SymbolID][]mutator.IndexerInfo { + result := make(map[cfg.SymbolID][]mutator.IndexerInfo) + if graph == nil { + return result + } + + graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + if info == nil { + return + } + sources := info.Sources + for i, target := range info.Targets { + var source ast.Expr + if i < len(sources) { + source = sources[i] + } + if target.Kind != cfg.TargetIndex { + continue + } + sym := target.BaseSymbol + if sym == 0 { + continue + } + if filterSyms != nil && !filterSyms[sym] { + continue + } + + // Skip string literal keys (handled by field assignments) + if _, ok := target.Key.(*ast.StringExpr); ok { + continue + } + + var keyType typ.Type + switch k := target.Key.(type) { + case *ast.IdentExpr: + if synth != nil { + keyType = synth(k, p) + } + case *ast.NumberExpr: + keyType = typ.Integer + default: + if synth != nil && target.Key != nil { + keyType = synth(target.Key, p) + } + } + keyType = canonicalDynamicKeyType(keyType) + + var valType typ.Type + if source != nil && synth != nil { + valType = synth(source, p) + } + if valType == nil { + valType = typ.Unknown + } + + result[sym] = append(result[sym], mutator.IndexerInfo{ + KeyType: keyType, + ValType: valType, + }) + } + }) + + return result +} + +func canonicalDynamicKeyType(keyType typ.Type) typ.Type { + if keyType == nil || keyType.Kind().IsPlaceholder() { + return typ.String + } + return keyType +} diff --git a/compiler/check/overlaymut/merge.go b/compiler/check/overlaymut/merge.go new file mode 100644 index 00000000..13d9c5ec --- /dev/null +++ b/compiler/check/overlaymut/merge.go @@ -0,0 +1,224 @@ +package overlaymut + +import ( + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" + "github.com/wippyai/go-lua/types/flow" + querycore "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +// MergeFieldAssignments merges src into dst. +func MergeFieldAssignments( + dst map[cfg.SymbolID]map[string]typ.Type, + src map[cfg.SymbolID]map[string]typ.Type, +) { + for _, sym := range cfg.SortedSymbolIDs(src) { + fields := src[sym] + if dst[sym] == nil { + dst[sym] = make(map[string]typ.Type) + } + for _, name := range cfg.SortedFieldNames(fields) { + fieldType := fields[name] + if existing := dst[sym][name]; existing != nil { + dst[sym][name] = typ.JoinPreferNonSoft(existing, fieldType) + } else { + dst[sym][name] = fieldType + } + } + } +} + +// ApplyFieldMergeToOverlay merges collected field assignments into symbol types in the overlay. +func ApplyFieldMergeToOverlay( + overlay map[cfg.SymbolID]typ.Type, + fieldAssignments map[cfg.SymbolID]map[string]typ.Type, +) { + for _, sym := range cfg.SortedSymbolIDs(fieldAssignments) { + fields := fieldAssignments[sym] + if len(fields) == 0 { + continue + } + baseType := overlay[sym] + merged := MergeFieldsIntoType(baseType, fields) + if merged != nil { + overlay[sym] = merged + } + } +} + +// MergeFieldsIntoType merges a set of field types into a base type. +func MergeFieldsIntoType(baseType typ.Type, fields map[string]typ.Type) typ.Type { + if len(fields) == 0 { + return baseType + } + + fieldNames := cfg.SortedFieldNames(fields) + + if baseType == nil { + builder := typ.NewRecord().SetOpen(true) + for _, name := range fieldNames { + builder.Field(name, fields[name]) + } + return builder.Build() + } + + switch v := baseType.(type) { + case *typ.Map: + builder := typ.NewRecord().SetOpen(true) + builder.MapComponent(v.Key, v.Value) + for _, name := range fieldNames { + builder.Field(name, fields[name]) + } + return builder.Build() + case *typ.Record: + builder := typ.NewRecord() + if v.Open { + builder.SetOpen(true) + } + existing := make(map[string]bool) + for _, f := range v.Fields { + builder.Field(f.Name, f.Type) + existing[f.Name] = true + } + for _, name := range fieldNames { + if !existing[name] { + builder.Field(name, fields[name]) + } + } + if v.Metatable != nil { + builder.Metatable(v.Metatable) + } + if v.HasMapComponent() { + builder.MapComponent(v.MapKey, v.MapValue) + } + return builder.Build() + default: + builder := typ.NewRecord().SetOpen(true) + for _, name := range fieldNames { + builder.Field(name, fields[name]) + } + return builder.Build() + } +} + +// ApplyIndexerMergeToOverlay adds map components to symbol types based on dynamic index assignments. +func ApplyIndexerMergeToOverlay( + overlay map[cfg.SymbolID]typ.Type, + indexerAssignments map[cfg.SymbolID][]mutator.IndexerInfo, +) { + for _, sym := range cfg.SortedSymbolIDs(indexerAssignments) { + infos := indexerAssignments[sym] + if len(infos) == 0 { + continue + } + + var keyType, valType typ.Type + for _, info := range infos { + keyType = typ.JoinPreferNonSoft(keyType, info.KeyType) + valType = JoinValueTypes(valType, info.ValType) + } + if keyType == nil { + keyType = typ.String + } + if valType == nil { + valType = typ.Unknown + } + + baseType := overlay[sym] + merged := MergeMapComponentIntoType(baseType, keyType, valType) + if merged != nil { + overlay[sym] = merged + } + } +} + +// JoinValueTypes joins two value types, preferring arrays over empty records. +func JoinValueTypes(a, b typ.Type) typ.Type { + if a == nil { + return b + } + if b == nil { + return a + } + + aIsEmptyRecord := unwrap.IsEmptyRecord(a) + bIsEmptyRecord := unwrap.IsEmptyRecord(b) + _, aIsArray := a.(*typ.Array) + _, bIsArray := b.(*typ.Array) + aIsPlaceholder := a.Kind().IsPlaceholder() + bIsPlaceholder := b.Kind().IsPlaceholder() + + if aIsEmptyRecord && bIsArray { + return b + } + if bIsEmptyRecord && aIsArray { + return a + } + if aIsPlaceholder && bIsArray { + return b + } + if bIsPlaceholder && aIsArray { + return a + } + + return typ.JoinPreferNonSoft(a, b) +} + +// MergeMapComponentIntoType adds a map component to a base type. +func MergeMapComponentIntoType(baseType, keyType, valType typ.Type) typ.Type { + if baseType == nil { + return typ.NewMap(keyType, valType) + } + + switch v := baseType.(type) { + case *typ.Map: + newKey := typ.JoinPreferNonSoft(v.Key, keyType) + newVal := typ.JoinPreferNonSoft(v.Value, valType) + return typ.NewMap(newKey, newVal) + case *typ.Record: + builder := typ.NewRecord() + if v.Open { + builder.SetOpen(true) + } + for _, f := range v.Fields { + builder.Field(f.Name, f.Type) + } + if v.Metatable != nil { + builder.Metatable(v.Metatable) + } + if v.HasMapComponent() { + newKey := typ.JoinPreferNonSoft(v.MapKey, keyType) + newVal := typ.JoinPreferNonSoft(v.MapValue, valType) + builder.MapComponent(newKey, newVal) + } else { + existingKey := querycore.KeyType(v) + if existingKey == nil { + existingKey = typ.String + } + builder.MapComponent(typ.JoinPreferNonSoft(existingKey, keyType), valType) + } + return builder.Build() + default: + return typ.NewMap(keyType, valType) + } +} + +// ApplyDirectMutationsToOverlay widens array element types based on table.insert mutations. +func ApplyDirectMutationsToOverlay( + overlay map[cfg.SymbolID]typ.Type, + mutations map[cfg.SymbolID]typ.Type, +) { + for _, sym := range cfg.SortedSymbolIDs(mutations) { + elemType := mutations[sym] + if elemType == nil { + continue + } + baseType := overlay[sym] + merged := flow.WidenArrayElementType(baseType, elemType, typ.JoinPreferNonSoft) + if merged != nil { + overlay[sym] = merged + } + } +} diff --git a/compiler/check/phase/flow.go b/compiler/check/phase/flow.go index 6403323e..2df1db33 100644 --- a/compiler/check/phase/flow.go +++ b/compiler/check/phase/flow.go @@ -125,19 +125,19 @@ func RunLiteral(input LiteralInput) LiteralOutput { } } -// InferEffect computes a FunctionEffect from solved flow analysis. +// InferRefinement computes a FunctionRefinement from solved flow analysis. // Examines return points to determine OnTrue/OnFalse/OnReturn constraints. -func InferEffect( +func InferRefinement( graph *cfg.Graph, solution *flow.Solution, params []flow.ParamInfo, returnType typ.Type, -) *constraint.FunctionEffect { +) *constraint.FunctionRefinement { if graph == nil || solution == nil { return nil } - return flow.InferFunctionEffect(solution, graph.CFG(), params, returnType) + return flow.InferFunctionRefinement(solution, graph.CFG(), params, returnType) } // ExtractParams extracts parameter info from a function expression. @@ -179,7 +179,7 @@ func ExtractParams(fn *ast.FunctionExpr, paramTypes map[cfg.SymbolID]typ.Type, g // EnrichWithKeysCollector detects if a function is a "keys collector" // (returns keys of a parameter) and adds KeyOf constraint to OnReturn. // This enables cross-module key-provenance tracking. -func EnrichWithKeysCollector(eff *constraint.FunctionEffect, fn *ast.FunctionExpr) *constraint.FunctionEffect { +func EnrichWithKeysCollector(eff *constraint.FunctionRefinement, fn *ast.FunctionExpr) *constraint.FunctionRefinement { if fn == nil { return eff } @@ -195,12 +195,12 @@ func EnrichWithKeysCollector(eff *constraint.FunctionEffect, fn *ast.FunctionExp } if eff == nil { - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(keyOf), } } - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ Row: eff.Row, OnReturn: constraint.And(eff.OnReturn, constraint.FromConstraints(keyOf)), OnTrue: eff.OnTrue, diff --git a/compiler/check/phase/flow_test.go b/compiler/check/phase/flow_test.go index 8c86e5fe..1d1e191f 100644 --- a/compiler/check/phase/flow_test.go +++ b/compiler/check/phase/flow_test.go @@ -66,17 +66,17 @@ func TestExtractParams_WithTypes(t *testing.T) { } } -func TestInferEffect_NilGraph(t *testing.T) { - result := InferEffect(nil, &flow.Solution{}, nil, nil) +func TestInferRefinement_NilGraph(t *testing.T) { + result := InferRefinement(nil, &flow.Solution{}, nil, nil) if result != nil { t.Errorf("expected nil for nil graph, got %v", result) } } -func TestInferEffect_NilSolution(t *testing.T) { +func TestInferRefinement_NilSolution(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{}} graph := cfg.Build(fn) - result := InferEffect(graph, nil, nil, nil) + result := InferRefinement(graph, nil, nil, nil) if result != nil { t.Errorf("expected nil for nil solution, got %v", result) } @@ -147,7 +147,7 @@ func TestEnrichWithKeysCollector_AppendsToExistingOnReturn(t *testing.T) { Stmts: body, } - existing := &constraint.FunctionEffect{ + existing := &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.NotNil{Path: constraint.RetPath(0)}), } result := EnrichWithKeysCollector(existing, fn) diff --git a/compiler/check/phase/narrow.go b/compiler/check/phase/narrow.go index 9429f517..a53b020b 100644 --- a/compiler/check/phase/narrow.go +++ b/compiler/check/phase/narrow.go @@ -69,13 +69,13 @@ func RunNarrow(input NarrowInput) NarrowOutput { input.ModuleAliases, ) - fnEffect := InferEffect(input.Graph, input.Solve.Solution, input.Extract.Params, input.Extract.ReturnType) + fnEffect := InferRefinement(input.Graph, input.Solve.Solution, input.Extract.Params, input.Extract.ReturnType) fnEffect = EnrichWithKeysCollector(fnEffect, input.Fn) return NarrowOutput{ - Facts: narrowingCtx.Types(), - Effect: fnEffect, - Synth: engine, + Facts: narrowingCtx.Types(), + Refinement: fnEffect, + Synth: engine, } } diff --git a/compiler/check/phase/resolve.go b/compiler/check/phase/resolve.go index d7331d53..4ecaad3e 100644 --- a/compiler/check/phase/resolve.go +++ b/compiler/check/phase/resolve.go @@ -9,8 +9,10 @@ package phase import ( "reflect" + "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth" basecfg "github.com/wippyai/go-lua/types/cfg" @@ -34,21 +36,22 @@ func RunResolve(input ResolveInput) ResolveOutput { return ResolveOutput{} } - initialSymbolTypes := BuildInitialSymbolTypes(input.Graph, input.GlobalTypes, nil) globalCtx := api.NewDeclaredEnv(api.DeclaredEnvConfig{ Graph: input.Graph, Bindings: input.Bindings, - DeclaredTypes: BuildDeclaredTypesFromSymbolTypes(input.Graph, initialSymbolTypes), + DeclaredTypes: BuildDeclaredTypesForResolve(input.Graph, input.GlobalTypes, nil), BaseScope: input.BaseScope, GlobalTypes: input.GlobalTypes, }) engine := synth.New(synth.Config{ - Ctx: input.Ctx, - Types: input.Types, - Manifests: input.Manifests, - Env: globalCtx, - Phase: api.PhaseTypeResolution, + Ctx: input.Ctx, + Types: input.Types, + Manifests: input.Manifests, + Env: globalCtx, + Phase: api.PhaseTypeResolution, + ModuleBindings: firstNonNilBindings(input.ModuleBindings, input.Bindings), + ModuleAliases: firstNonNilAliases(input.ModuleAliases, modules.CollectAliases(input.Graph)), }) return ResolveOutput{ @@ -66,23 +69,38 @@ func CreateTypeResolutionEngine( types core.TypeOps, manifests io.ManifestQuerier, ) *synth.Engine { - initialSymbolTypes := BuildInitialSymbolTypes(graph, globalTypes, paramTypes) checkCtx := api.NewDeclaredEnv(api.DeclaredEnvConfig{ Graph: graph, Bindings: graph.Bindings(), - DeclaredTypes: BuildDeclaredTypesFromSymbolTypes(graph, initialSymbolTypes), + DeclaredTypes: BuildDeclaredTypesForResolve(graph, globalTypes, paramTypes), BaseScope: base, GlobalTypes: globalTypes, }) return synth.New(synth.Config{ - Ctx: ctx, - Types: types, - Manifests: manifests, - Env: checkCtx, - Phase: api.PhaseTypeResolution, + Ctx: ctx, + Types: types, + Manifests: manifests, + Env: checkCtx, + Phase: api.PhaseTypeResolution, + ModuleBindings: graph.Bindings(), + ModuleAliases: modules.CollectAliases(graph), }) } +func firstNonNilBindings(primary, fallback *bind.BindingTable) *bind.BindingTable { + if primary != nil { + return primary + } + return fallback +} + +func firstNonNilAliases(primary, fallback map[cfg.SymbolID]string) map[cfg.SymbolID]string { + if len(primary) > 0 { + return primary + } + return fallback +} + // BuildInitialSymbolTypes creates SymbolTypes for globals and parameters at all CFG points. func BuildInitialSymbolTypes(graph *cfg.Graph, globalTypes map[string]typ.Type, paramTypes map[cfg.SymbolID]typ.Type) flow.SymbolTypes { if graph == nil { @@ -139,17 +157,14 @@ func BuildInitialSymbolTypes(graph *cfg.Graph, globalTypes map[string]typ.Type, bindings := graph.Bindings() out := make(flow.SymbolTypes) - // Compute types once at entry and reuse if symbols don't change - var prevTypesAt map[cfg.SymbolID]flow.TypedValue - var prevLocalsToken uintptr - hasPrevLocals := false + typesByLocalsToken := make(map[uintptr]map[cfg.SymbolID]flow.TypedValue) for _, p := range graph.RPO() { locals := graph.LocalSymbolsAt(p) localsToken := reflect.ValueOf(locals).Pointer() - if hasPrevLocals && localsToken == prevLocalsToken { - if prevTypesAt != nil { - out[p] = prevTypesAt + if cached, ok := typesByLocalsToken[localsToken]; ok { + if cached != nil { + out[p] = cached } continue } @@ -190,32 +205,12 @@ func BuildInitialSymbolTypes(graph *cfg.Graph, globalTypes map[string]typ.Type, } if len(typesAt) == 0 { - hasPrevLocals = true - prevLocalsToken = localsToken - prevTypesAt = nil + typesByLocalsToken[localsToken] = nil continue } - // Reuse previous map if identical content - if prevTypesAt != nil && len(typesAt) == len(prevTypesAt) { - identical := true - for sym, tv := range typesAt { - if prev, ok := prevTypesAt[sym]; !ok || prev != tv { - identical = false - break - } - } - if identical { - out[p] = prevTypesAt - hasPrevLocals = true - prevLocalsToken = localsToken - continue - } - } out[p] = typesAt - prevTypesAt = typesAt - hasPrevLocals = true - prevLocalsToken = localsToken + typesByLocalsToken[localsToken] = typesAt } if len(out) == 0 { @@ -233,10 +228,21 @@ func BuildDeclaredTypesFromSymbolTypes(graph basecfg.VersionedGraph, symbolTypes entry := graph.Entry() bestPoint := make(map[cfg.SymbolID]cfg.Point, len(symbolTypes)) + lowestPointByTypesToken := make(map[uintptr]cfg.Point, len(symbolTypes)) + typesByToken := make(map[uintptr]map[cfg.SymbolID]flow.TypedValue, len(symbolTypes)) for p, typesAt := range symbolTypes { - if p == entry { + if p == entry || typesAt == nil { continue } + token := reflect.ValueOf(typesAt).Pointer() + if prev, ok := lowestPointByTypesToken[token]; !ok || p < prev { + lowestPointByTypesToken[token] = p + typesByToken[token] = typesAt + } + } + + for token, typesAt := range typesByToken { + p := lowestPointByTypesToken[token] for sym, tv := range typesAt { if tv.State != flow.StateResolved || tv.Type == nil { continue @@ -261,3 +267,72 @@ func BuildDeclaredTypesFromSymbolTypes(graph basecfg.VersionedGraph, symbolTypes } return out } + +// BuildDeclaredTypesForResolve computes declared types directly for the resolve phase. +// +// This avoids materializing the full per-point SymbolTypes map when resolve only +// needs the collapsed DeclaredTypes result. +func BuildDeclaredTypesForResolve(graph *cfg.Graph, globalTypes map[string]typ.Type, paramTypes map[cfg.SymbolID]typ.Type) flow.DeclaredTypes { + if graph == nil || (len(globalTypes) == 0 && len(paramTypes) == 0) { + return nil + } + + out := make(flow.DeclaredTypes, len(globalTypes)+len(paramTypes)) + + paramNameTypes := make(map[string]typ.Type, len(paramTypes)) + for sym, t := range paramTypes { + if t == nil { + continue + } + out[sym] = t + if name := graph.NameOf(sym); name != "" { + paramNameTypes[name] = t + } + } + + for name, t := range globalTypes { + if t == nil { + continue + } + if sym, ok := graph.GlobalSymbol(name); ok && sym != 0 { + out[sym] = t + } + } + + if len(paramNameTypes) == 0 { + if len(out) == 0 { + return nil + } + return out + } + + localsByToken := make(map[uintptr]map[string]cfg.SymbolID) + lowestPointByToken := make(map[uintptr]cfg.Point) + for _, p := range graph.RPO() { + locals := graph.LocalSymbolsAt(p) + if len(locals) == 0 { + continue + } + token := reflect.ValueOf(locals).Pointer() + if prev, ok := lowestPointByToken[token]; !ok || p < prev { + lowestPointByToken[token] = p + localsByToken[token] = locals + } + } + + for _, locals := range localsByToken { + for name, sym := range locals { + if _, exists := out[sym]; exists { + continue + } + if t := paramNameTypes[name]; t != nil { + out[sym] = t + } + } + } + + if len(out) == 0 { + return nil + } + return out +} diff --git a/compiler/check/phase/resolve_test.go b/compiler/check/phase/resolve_test.go index 18918d09..ed7cf60d 100644 --- a/compiler/check/phase/resolve_test.go +++ b/compiler/check/phase/resolve_test.go @@ -1,6 +1,7 @@ package phase import ( + "reflect" "testing" "github.com/wippyai/go-lua/compiler/ast" @@ -196,6 +197,45 @@ func TestBuildDeclaredTypesFromSymbolTypes_EmptySymbolTypes(t *testing.T) { } } +func TestBuildDeclaredTypesForResolve_MatchesSymbolTypePipeline(t *testing.T) { + fn := &ast.FunctionExpr{ + ParList: &ast.ParList{ + Names: []string{"x"}, + }, + Stmts: []ast.Stmt{ + &ast.LocalAssignStmt{ + Names: []string{"print"}, + Exprs: []ast.Expr{&ast.NumberExpr{Value: "1"}}, + }, + &ast.LocalAssignStmt{ + Names: []string{"x"}, + Exprs: []ast.Expr{&ast.NumberExpr{Value: "2"}}, + }, + &ast.ReturnStmt{ + Exprs: []ast.Expr{ + &ast.IdentExpr{Value: "print"}, + &ast.IdentExpr{Value: "x"}, + }, + }, + }, + } + + graph := cfg.Build(fn, "print") + paramSyms := graph.ParamSymbols() + if len(paramSyms) != 1 || paramSyms[0] == 0 { + t.Fatal("expected one parameter symbol") + } + + globalTypes := map[string]typ.Type{"print": typ.String} + paramTypes := map[cfg.SymbolID]typ.Type{paramSyms[0]: typ.Number} + + want := BuildDeclaredTypesFromSymbolTypes(graph, BuildInitialSymbolTypes(graph, globalTypes, paramTypes)) + got := BuildDeclaredTypesForResolve(graph, globalTypes, paramTypes) + if !reflect.DeepEqual(got, want) { + t.Fatalf("BuildDeclaredTypesForResolve mismatch:\n got: %#v\nwant: %#v", got, want) + } +} + func TestBuildDeclaredTypesFromSymbolTypes_WithTypes(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{}} graph := cfg.Build(fn) diff --git a/compiler/check/phase/types.go b/compiler/check/phase/types.go index 9b1dace0..80d01d05 100644 --- a/compiler/check/phase/types.go +++ b/compiler/check/phase/types.go @@ -100,8 +100,8 @@ type PhaseEnv struct { // ModuleBindings is the binding table for the entire module. ModuleBindings *bind.BindingTable - // EffectStore provides function effect lookups for callee analysis. - EffectStore api.EffectStore + // RefinementStore provides function refinement lookups for callee analysis. + RefinementStore api.RefinementStore // Scopes maps CFG points to scope states (populated after scope phase). Scopes map[cfg.Point]*scope.State @@ -271,11 +271,11 @@ type NarrowInput struct { } // NarrowOutput contains outputs from the narrowing phase. -// Phase D outputs: TypeFacts and effects. +// Phase D outputs: TypeFacts and the inferred function refinement. type NarrowOutput struct { - Facts flow.TypeFacts - Effect *constraint.FunctionEffect - Synth synth.Synth + Facts flow.TypeFacts + Refinement *constraint.FunctionRefinement + Synth synth.Synth } // ContextBuilder constructs Env instances from phase outputs. @@ -385,7 +385,7 @@ func (b *ContextBuilder) BuildDeclared() *api.DeclaredEnvImpl { LiteralTypes: b.literalTypes, AnnotatedVars: b.annotatedVars, BaseScope: b.baseScope, - EffectStore: b.env.EffectStore, + RefinementStore: b.env.RefinementStore, ModuleAliases: b.env.ModuleAliases, GlobalTypes: b.env.GlobalTypes, ReturnSummaries: b.returnSummaries, @@ -403,7 +403,7 @@ func (b *ContextBuilder) BuildNarrow() *api.NarrowEnvImpl { AnnotatedVars: b.annotatedVars, Solution: b.solution, BaseScope: b.baseScope, - EffectStore: b.env.EffectStore, + RefinementStore: b.env.RefinementStore, ModuleAliases: b.env.ModuleAliases, GlobalTypes: b.env.GlobalTypes, NarrowReturnSummaries: b.narrowReturnSummaries, diff --git a/compiler/check/phase/types_test.go b/compiler/check/phase/types_test.go index 9acc3c13..4adde914 100644 --- a/compiler/check/phase/types_test.go +++ b/compiler/check/phase/types_test.go @@ -40,8 +40,8 @@ func TestPhaseEnv_Fields(t *testing.T) { if env.ModuleBindings != nil { t.Error("ModuleBindings should be nil by default") } - if env.EffectStore != nil { - t.Error("EffectStore should be nil by default") + if env.RefinementStore != nil { + t.Error("RefinementStore should be nil by default") } if env.Scopes != nil { t.Error("Scopes should be nil by default") @@ -246,7 +246,7 @@ func TestNarrowOutput_Fields(t *testing.T) { if out.Facts != nil { t.Error("Facts should be nil by default") } - if out.Effect != nil { + if out.Refinement != nil { t.Error("Effect should be nil by default") } if out.Synth != nil { diff --git a/compiler/check/pipeline/driver.go b/compiler/check/pipeline/driver.go index 1aab8fbf..ff653539 100644 --- a/compiler/check/pipeline/driver.go +++ b/compiler/check/pipeline/driver.go @@ -175,7 +175,7 @@ func (d *Driver) checkFunctionFixpoint(sess api.AnalysisSession, fn *ast.Functio funcSym = sym } } - d.storeFunctionEffect(store, result, funcSym) + d.storeFunctionRefinement(store, result, funcSym) interprocinfer.StoreFactsFromResult(store, fn, result, parent) d.processNestedFunctions(sess, store, graph, results, result) } @@ -234,15 +234,15 @@ func (d *Driver) runReturnInference( MaxIterations: returns.MaxReturnSummaryIterations, }) - var effectLookup constraint.EffectLookupBySym - if es := store.EffectStore(); es != nil { - effectLookup = es.LookupEffectBySym + var refinementLookup constraint.RefinementLookupBySym + if es := store.RefinementStore(); es != nil { + refinementLookup = es.LookupRefinementBySym } summaries, funcTypes, diags := inferencer.ComputeForGraph(returninfer.RunContext{ Ctx: sess.Context(), ParentFacts: d.parentFactsForGraph(sess, store, graph.ID()), - EffectLookup: effectLookup, + EffectLookup: refinementLookup, }, graph, parent) if len(diags) > 0 { sess.AppendDiagnostics(diags...) @@ -351,18 +351,18 @@ func (d *Driver) emitScopeDepthDiagnostic(sess api.AnalysisSession, fn *ast.Func scopeState[fn] = true } -func (d *Driver) storeFunctionEffect(store api.IterationStore, result *api.FuncResult, funcSym cfg.SymbolID) { +func (d *Driver) storeFunctionRefinement(store api.IterationStore, result *api.FuncResult, funcSym cfg.SymbolID) { if result == nil || store == nil || funcSym == 0 { return } - lookup := func(sym cfg.SymbolID) *constraint.FunctionEffect { - return effects.LookupEffectBySym(store.EffectStore(), store.ModuleBindings(), d.cfg.GlobalTypes, sym) + lookup := func(sym cfg.SymbolID) *constraint.FunctionRefinement { + return effects.LookupRefinementBySym(store.RefinementStore(), store.ModuleBindings(), d.cfg.GlobalTypes, sym) } fnEffect := effects.Propagate(result, lookup) if fnEffect == nil { return } - store.StoreFunctionEffect(funcSym, fnEffect) + store.StoreFunctionRefinement(funcSym, fnEffect) } func collectGlobalNames(globalTypes map[string]typ.Type) []string { diff --git a/compiler/check/pipeline/runner.go b/compiler/check/pipeline/runner.go index eb4e83b9..dd19e746 100644 --- a/compiler/check/pipeline/runner.go +++ b/compiler/check/pipeline/runner.go @@ -117,15 +117,15 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { localAliases := modules.CollectAliases(graph) mergedAliases := modules.MergeAliases(store.ModuleAliases(), localAliases) env := phase.PhaseEnv{ - Ctx: ctx, - Graph: graph, - Fn: fn, - Types: r.types, - Manifests: r.manifests, - GlobalTypes: r.globalTypes, - ModuleAliases: mergedAliases, - ModuleBindings: store.ModuleBindings(), - EffectStore: effectStoreFrom(store), + Ctx: ctx, + Graph: graph, + Fn: fn, + Types: r.types, + Manifests: r.manifests, + GlobalTypes: r.globalTypes, + ModuleAliases: mergedAliases, + ModuleBindings: store.ModuleBindings(), + RefinementStore: effectStoreFrom(store), } // Phase A: Resolve type annotations. @@ -221,7 +221,7 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { Facts: narrowOut.Facts, FlowInputs: extractOut.Inputs, FlowSolution: solveOut.Solution, - FnEffect: narrowOut.Effect, + FnRefinement: narrowOut.Refinement, NarrowSynth: narrowOut.Synth, LiteralSignatures: literalOut.Signatures, Extras: extras, diff --git a/compiler/check/pipeline/runner_stages.go b/compiler/check/pipeline/runner_stages.go index 8d0498e7..7ad22200 100644 --- a/compiler/check/pipeline/runner_stages.go +++ b/compiler/check/pipeline/runner_stages.go @@ -181,15 +181,15 @@ func (r *Runner) literalSigProvider(store api.StoreView, graph *cfg.Graph, paren } type effectStoreProvider interface { - EffectStore() api.EffectStore + RefinementStore() api.RefinementStore } -func effectStoreFrom(store api.StoreView) api.EffectStore { +func effectStoreFrom(store api.StoreView) api.RefinementStore { if store == nil { return nil } if provider, ok := store.(effectStoreProvider); ok { - return provider.EffectStore() + return provider.RefinementStore() } return nil } diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index 1d33abb5..37396615 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -119,6 +119,12 @@ func ReturnTypesElideOptional(a, b []typ.Type) bool { // The nil-only guard prevents a refined-but-empty-looking update from // regressing an already informative summary to just nil. func SelectPreferredReturnVector(a, b []typ.Type) ([]typ.Type, bool) { + if ReturnTypesRepairNever(a, b) { + return a, true + } + if ReturnTypesRepairNever(b, a) { + return b, true + } if ReturnTypesRefine(a, b) { if ReturnTypesAllNil(a) && !ReturnTypesAllNil(b) { return b, true @@ -196,6 +202,30 @@ func ReturnTypesFillNilSlots(a, b []typ.Type) bool { return strict } +// ReturnTypesRepairNever reports whether candidate is a runtime-possible repair +// of baseline by replacing nested never artifacts while otherwise widening +// compatibly. This lets post-flow summaries correct pre-flow bottoms such as +// `{data?: never}` -> `{data?: unknown}`. +func ReturnTypesRepairNever(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + strict := false + for i := range candidate { + if candidate[i] == nil || baseline[i] == nil { + return false + } + if typ.TypeEquals(candidate[i], baseline[i]) { + continue + } + if !typeRepairsNever(candidate[i], baseline[i]) { + return false + } + strict = true + } + return strict +} + // TypeExtendsRecord reports whether type a extends type b by adding record fields. // This treats record field supersets as refinements when b is a record or union of records. func TypeExtendsRecord(a, b typ.Type) bool { @@ -216,6 +246,290 @@ func TypeExtendsRecord(a, b typ.Type) bool { } } +func typeRepairsNever(candidate, baseline typ.Type) bool { + if candidate == nil || baseline == nil { + return false + } + if !typeContainsNever(baseline) || typeContainsNever(candidate) { + return false + } + ok, strict := typeNeverRepairRelation(candidate, baseline) + return ok && strict +} + +func typeNeverRepairRelation(candidate, baseline typ.Type) (bool, bool) { + if candidate == nil || baseline == nil { + return false, false + } + if typ.TypeEquals(candidate, baseline) { + return true, false + } + + candidate = unwrap.Alias(candidate) + baseline = unwrap.Alias(baseline) + if candidate == nil || baseline == nil { + return false, false + } + + if typ.IsNever(baseline) { + return !typ.IsNever(candidate), !typ.IsNever(candidate) + } + if !typeContainsNever(baseline) { + return false, false + } + + switch b := baseline.(type) { + case *typ.Optional: + c, ok := candidate.(*typ.Optional) + if !ok { + return false, false + } + return typeNeverRepairRelation(c.Inner, b.Inner) + case *typ.Union: + c, ok := candidate.(*typ.Union) + if !ok || len(c.Members) != len(b.Members) { + return false, false + } + used := make([]bool, len(c.Members)) + strict := false + for _, bm := range b.Members { + matched := false + for j, cm := range c.Members { + if used[j] || !typ.TypeEquals(cm, bm) { + continue + } + used[j] = true + matched = true + break + } + if matched { + continue + } + for j, cm := range c.Members { + if used[j] { + continue + } + ok, repaired := typeNeverRepairRelation(cm, bm) + if !ok { + continue + } + used[j] = true + matched = true + if repaired { + strict = true + } + break + } + if !matched { + return false, false + } + } + return true, strict + case *typ.Record: + c, ok := candidate.(*typ.Record) + if !ok || c.Open != b.Open || c.HasMapComponent() != b.HasMapComponent() || len(c.Fields) != len(b.Fields) { + return false, false + } + strict := false + for _, bf := range b.Fields { + cf := c.GetField(bf.Name) + if cf == nil || cf.Optional != bf.Optional || cf.Readonly != bf.Readonly { + return false, false + } + ok, repaired := typeNeverRepairRelation(cf.Type, bf.Type) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + if b.HasMapComponent() { + ok, repaired := typeNeverRepairRelation(c.MapKey, b.MapKey) + if !ok { + return false, false + } + if repaired { + strict = true + } + ok, repaired = typeNeverRepairRelation(c.MapValue, b.MapValue) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + if b.Metatable != nil || c.Metatable != nil { + if b.Metatable == nil || c.Metatable == nil { + return false, false + } + ok, repaired := typeNeverRepairRelation(c.Metatable, b.Metatable) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + return true, strict + case *typ.Array: + c, ok := candidate.(*typ.Array) + if !ok { + return false, false + } + return typeNeverRepairRelation(c.Element, b.Element) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok { + return false, false + } + keyOK, keyStrict := typeNeverRepairRelation(c.Key, b.Key) + if !keyOK { + return false, false + } + valOK, valStrict := typeNeverRepairRelation(c.Value, b.Value) + if !valOK { + return false, false + } + return true, keyStrict || valStrict + case *typ.Tuple: + c, ok := candidate.(*typ.Tuple) + if !ok || len(c.Elements) != len(b.Elements) { + return false, false + } + strict := false + for i := range b.Elements { + ok, repaired := typeNeverRepairRelation(c.Elements[i], b.Elements[i]) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + return true, strict + case *typ.Function: + c, ok := candidate.(*typ.Function) + if !ok || !sameFunctionShapeForFactMerge(c, b) || len(c.Returns) != len(b.Returns) { + return false, false + } + for i := range b.Params { + if c.Params[i].Name != b.Params[i].Name || + c.Params[i].Optional != b.Params[i].Optional || + !typ.TypeEquals(c.Params[i].Type, b.Params[i].Type) { + return false, false + } + } + switch { + case (c.Variadic == nil) != (b.Variadic == nil): + return false, false + case c.Variadic != nil && !typ.TypeEquals(c.Variadic, b.Variadic): + return false, false + } + strict := false + for i := range b.Returns { + ok, repaired := typeNeverRepairRelation(c.Returns[i], b.Returns[i]) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + return true, strict + default: + return false, false + } +} + +func typeContainsNever(t typ.Type) bool { + seen := make(map[typ.Type]bool) + return typeContainsNeverMemo(t, seen) +} + +func typeContainsNeverMemo(t typ.Type, seen map[typ.Type]bool) bool { + if t == nil { + return false + } + if seen[t] { + return false + } + seen[t] = true + t = unwrap.Alias(t) + if t == nil { + return false + } + if typ.IsNever(t) { + return true + } + return typ.Visit(t, typ.Visitor[bool]{ + Optional: func(o *typ.Optional) bool { + return typeContainsNeverMemo(o.Inner, seen) + }, + Union: func(u *typ.Union) bool { + for _, m := range u.Members { + if typeContainsNeverMemo(m, seen) { + return true + } + } + return false + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := range in.Members { + if typeContainsNeverMemo(m, seen) { + return true + } + } + return false + }, + Tuple: func(tup *typ.Tuple) bool { + for _, e := range tup.Elements { + if typeContainsNeverMemo(e, seen) { + return true + } + } + return false + }, + Array: func(a *typ.Array) bool { + return typeContainsNeverMemo(a.Element, seen) + }, + Map: func(m *typ.Map) bool { + return typeContainsNeverMemo(m.Key, seen) || typeContainsNeverMemo(m.Value, seen) + }, + Record: func(r *typ.Record) bool { + for _, f := range r.Fields { + if typeContainsNeverMemo(f.Type, seen) { + return true + } + } + if r.HasMapComponent() { + return typeContainsNeverMemo(r.MapKey, seen) || typeContainsNeverMemo(r.MapValue, seen) + } + return false + }, + Function: func(fn *typ.Function) bool { + for _, p := range fn.Params { + if typeContainsNeverMemo(p.Type, seen) { + return true + } + } + if fn.Variadic != nil && typeContainsNeverMemo(fn.Variadic, seen) { + return true + } + for _, ret := range fn.Returns { + if typeContainsNeverMemo(ret, seen) { + return true + } + } + return false + }, + Default: func(typ.Type) bool { + return false + }, + }) +} + func typeElidesOptional(a, b typ.Type) bool { if a == nil || b == nil { return false @@ -338,6 +652,12 @@ func MergeReturnSummary(existing, candidate []typ.Type) []typ.Type { if replaced, ok := replaceOpenTopWithStructured(existing, candidate); ok { existing = normalizeAndPruneReturnVector(replaced) } + if ReturnTypesRepairNever(existing, candidate) { + return existing + } + if ReturnTypesRepairNever(candidate, existing) { + return candidate + } // Higher-order summaries are merged monotonically for fixpoint stability. if shouldUseMonotoneReturnJoin(existing, candidate) { diff --git a/compiler/check/returns/join_test.go b/compiler/check/returns/join_test.go index fa114120..e41facdc 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/returns/join_test.go @@ -580,6 +580,70 @@ func TestAlignFunctionTypeWithSummary_DoesNotDowngradeStructuredToPlaceholder(t } } +func TestMergeReturnSummary_PrefersRuntimePossibleSummaryOverNeverArtifact(t *testing.T) { + bad := []typ.Type{ + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Never).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ), + } + good := []typ.Type{ + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Unknown).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ), + } + + got := MergeReturnSummary(bad, good) + if !ReturnTypesEqual(got, good) { + t.Fatalf("MergeReturnSummary(%v, %v) = %v, want %v", bad, good, got, good) + } +} + +func TestAlignFunctionTypeWithSummary_RepairsNestedNeverArtifact(t *testing.T) { + bad := typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Never).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ) + good := typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Unknown).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ) + + fn := typ.Func().Returns(bad).Build() + aligned, changed := AlignFunctionTypeWithSummary(fn, []typ.Type{good}) + if !changed { + t.Fatal("expected never-artifact repair to update function returns") + } + if aligned == nil || len(aligned.Returns) != 1 || !typ.TypeEquals(aligned.Returns[0], good) { + t.Fatalf("aligned returns = %v, want %v", aligned, good) + } +} + func TestRecordSuperset_NewHasMapComponentOldDoesNot(t *testing.T) { oldRec := typ.NewRecord().Field("x", typ.Number).Build() newRec := typ.NewRecord().Field("x", typ.Number).MapComponent(typ.String, typ.Any).Build() diff --git a/compiler/check/returns/kernel_test.go b/compiler/check/returns/kernel_test.go index 6576909b..665f285f 100644 --- a/compiler/check/returns/kernel_test.go +++ b/compiler/check/returns/kernel_test.go @@ -139,6 +139,55 @@ func TestReconcileFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testin } } +func TestReconcileFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { + bad := []typ.Type{ + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Never).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ), + } + good := []typ.Type{ + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Unknown).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ), + } + existingFunc := typ.Func().Returns(bad...).Build() + + out := ReconcileFunctionFact(ReconcileFunctionFactInput{ + ExistingSummary: bad, + ExistingNarrow: nil, + ExistingFunc: existingFunc, + CandidateNarrow: good, + }) + + if !ReturnTypesEqual(out.Summary, good) { + t.Fatalf("summary mismatch: got %v want %v", out.Summary, good) + } + if !ReturnTypesEqual(out.Narrow, good) { + t.Fatalf("narrow mismatch: got %v want %v", out.Narrow, good) + } + fn, ok := out.Func.(*typ.Function) + if !ok { + t.Fatalf("expected function fact, got %T", out.Func) + } + if !ReturnTypesEqual(fn.Returns, good) { + t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, good) + } +} + func TestMergeFunctionFactIntoFacts_ReadsLegacyAndWritesCanonical(t *testing.T) { sym := cfg.SymbolID(41) facts := &api.Facts{ diff --git a/compiler/check/returns/overlay.go b/compiler/check/returns/overlay.go index 6838aa5e..8388b922 100644 --- a/compiler/check/returns/overlay.go +++ b/compiler/check/returns/overlay.go @@ -3,10 +3,8 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" - "github.com/wippyai/go-lua/types/flow" - querycore "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/compiler/check/overlaymut" "github.com/wippyai/go-lua/types/typ" - "github.com/wippyai/go-lua/types/typ/unwrap" ) // This file provides utilities for applying type mutations (field assignments, @@ -25,20 +23,7 @@ func MergeFieldAssignments( dst map[cfg.SymbolID]map[string]typ.Type, src map[cfg.SymbolID]map[string]typ.Type, ) { - for _, sym := range cfg.SortedSymbolIDs(src) { - fields := src[sym] - if dst[sym] == nil { - dst[sym] = make(map[string]typ.Type) - } - for _, name := range cfg.SortedFieldNames(fields) { - fieldType := fields[name] - if existing := dst[sym][name]; existing != nil { - dst[sym][name] = typ.JoinPreferNonSoft(existing, fieldType) - } else { - dst[sym][name] = fieldType - } - } - } + overlaymut.MergeFieldAssignments(dst, src) } // ApplyFieldMergeToOverlay merges collected field assignments into symbol types in the overlay. @@ -54,17 +39,7 @@ func ApplyFieldMergeToOverlay( overlay map[cfg.SymbolID]typ.Type, fieldAssignments map[cfg.SymbolID]map[string]typ.Type, ) { - for _, sym := range cfg.SortedSymbolIDs(fieldAssignments) { - fields := fieldAssignments[sym] - if len(fields) == 0 { - continue - } - baseType := overlay[sym] - merged := MergeFieldsIntoType(baseType, fields) - if merged != nil { - overlay[sym] = merged - } - } + overlaymut.ApplyFieldMergeToOverlay(overlay, fieldAssignments) } // MergeFieldsIntoType merges a set of field types into a base type. @@ -78,63 +53,7 @@ func ApplyFieldMergeToOverlay( // Field names are sorted for deterministic output. Existing record fields // are preserved (not overwritten) since they represent more precise type info. func MergeFieldsIntoType(baseType typ.Type, fields map[string]typ.Type) typ.Type { - if len(fields) == 0 { - return baseType - } - - fieldNames := cfg.SortedFieldNames(fields) - - if baseType == nil { - // No base type - create a fresh record with just the fields - builder := typ.NewRecord().SetOpen(true) - for _, name := range fieldNames { - builder.Field(name, fields[name]) - } - return builder.Build() - } - - switch v := baseType.(type) { - case *typ.Map: - // Map base: create Record(open) with MapComponent + merged fields - builder := typ.NewRecord().SetOpen(true) - builder.MapComponent(v.Key, v.Value) - for _, name := range fieldNames { - builder.Field(name, fields[name]) - } - return builder.Build() - - case *typ.Record: - // Build merged record: existing fields + new fields - builder := typ.NewRecord() - if v.Open { - builder.SetOpen(true) - } - existing := make(map[string]bool) - for _, f := range v.Fields { - builder.Field(f.Name, f.Type) - existing[f.Name] = true - } - for _, name := range fieldNames { - if !existing[name] { - builder.Field(name, fields[name]) - } - } - if v.Metatable != nil { - builder.Metatable(v.Metatable) - } - if v.HasMapComponent() { - builder.MapComponent(v.MapKey, v.MapValue) - } - return builder.Build() - - default: - // Base is not a record or map; create one with just the field assignments - builder := typ.NewRecord().SetOpen(true) - for _, name := range fieldNames { - builder.Field(name, fields[name]) - } - return builder.Build() - } + return overlaymut.MergeFieldsIntoType(baseType, fields) } // ApplyIndexerMergeToOverlay adds map components to symbol types based on dynamic index assignments. @@ -151,31 +70,7 @@ func ApplyIndexerMergeToOverlay( overlay map[cfg.SymbolID]typ.Type, indexerAssignments map[cfg.SymbolID][]mutator.IndexerInfo, ) { - for _, sym := range cfg.SortedSymbolIDs(indexerAssignments) { - infos := indexerAssignments[sym] - if len(infos) == 0 { - continue - } - - // Join all key types and value types, preferring array types over empty records - var keyType, valType typ.Type - for _, info := range infos { - keyType = typ.JoinPreferNonSoft(keyType, info.KeyType) - valType = JoinValueTypes(valType, info.ValType) - } - if keyType == nil { - keyType = typ.String - } - if valType == nil { - valType = typ.Unknown - } - - baseType := overlay[sym] - merged := MergeMapComponentIntoType(baseType, keyType, valType) - if merged != nil { - overlay[sym] = merged - } - } + overlaymut.ApplyIndexerMergeToOverlay(overlay, indexerAssignments) } // JoinValueTypes joins two value types, preferring arrays over empty records. @@ -185,38 +80,7 @@ func ApplyIndexerMergeToOverlay( // then using it as an array via table.insert or indexed assignment. // The array type takes precedence because it carries more specific information. func JoinValueTypes(a, b typ.Type) typ.Type { - if a == nil { - return b - } - if b == nil { - return a - } - - // Check if one is an empty record and the other is an array - aIsEmptyRecord := unwrap.IsEmptyRecord(a) - bIsEmptyRecord := unwrap.IsEmptyRecord(b) - _, aIsArray := a.(*typ.Array) - _, bIsArray := b.(*typ.Array) - aIsPlaceholder := a.Kind().IsPlaceholder() - bIsPlaceholder := b.Kind().IsPlaceholder() - - // Prefer array over empty record - if aIsEmptyRecord && bIsArray { - return b - } - if bIsEmptyRecord && aIsArray { - return a - } - - // Prefer array over placeholder (unknown/any) when array evidence exists. - if aIsPlaceholder && bIsArray { - return b - } - if bIsPlaceholder && aIsArray { - return a - } - - return typ.JoinPreferNonSoft(a, b) + return overlaymut.JoinValueTypes(a, b) } // MergeMapComponentIntoType adds a map component to a base type. @@ -230,46 +94,7 @@ func JoinValueTypes(a, b typ.Type) typ.Type { // This is used when dynamic index assignments are detected, indicating the // variable is used as a map or has map-like access patterns. func MergeMapComponentIntoType(baseType, keyType, valType typ.Type) typ.Type { - if baseType == nil { - return typ.NewMap(keyType, valType) - } - - switch v := baseType.(type) { - case *typ.Map: - newKey := typ.JoinPreferNonSoft(v.Key, keyType) - newVal := typ.JoinPreferNonSoft(v.Value, valType) - return typ.NewMap(newKey, newVal) - - case *typ.Record: - builder := typ.NewRecord() - if v.Open { - builder.SetOpen(true) - } - for _, f := range v.Fields { - builder.Field(f.Name, f.Type) - } - if v.Metatable != nil { - builder.Metatable(v.Metatable) - } - if v.HasMapComponent() { - newKey := typ.JoinPreferNonSoft(v.MapKey, keyType) - newVal := typ.JoinPreferNonSoft(v.MapValue, valType) - builder.MapComponent(newKey, newVal) - } else { - // Preserve existing record key domain when adding first map component. - // For open records (`{...}`) this keeps the canonical string-key domain - // instead of degrading to unknown from dynamic-key assignments. - existingKey := querycore.KeyType(v) - if existingKey == nil { - existingKey = typ.String - } - builder.MapComponent(typ.JoinPreferNonSoft(existingKey, keyType), valType) - } - return builder.Build() - - default: - return typ.NewMap(keyType, valType) - } + return overlaymut.MergeMapComponentIntoType(baseType, keyType, valType) } // ApplyDirectMutationsToOverlay widens array element types based on table.insert mutations. @@ -284,15 +109,5 @@ func ApplyDirectMutationsToOverlay( overlay map[cfg.SymbolID]typ.Type, mutations map[cfg.SymbolID]typ.Type, ) { - for _, sym := range cfg.SortedSymbolIDs(mutations) { - elemType := mutations[sym] - if elemType == nil { - continue - } - baseType := overlay[sym] - merged := flow.WidenArrayElementType(baseType, elemType, typ.JoinPreferNonSoft) - if merged != nil { - overlay[sym] = merged - } - } + overlaymut.ApplyDirectMutationsToOverlay(overlay, mutations) } diff --git a/compiler/check/session.go b/compiler/check/session.go index 13e9511a..3475a9ae 100644 --- a/compiler/check/session.go +++ b/compiler/check/session.go @@ -219,7 +219,7 @@ func (s *Session) ScopeDepthDiagState() map[*ast.FunctionExpr]bool { func New(ctx *db.QueryContext, name string) *Session { store := store.NewSessionStore() api.AttachStore(ctx, store) - return &Session{ + sess := &Session{ Ctx: ctx, SourceName: name, Store: store, @@ -230,6 +230,8 @@ func New(ctx *db.QueryContext, name string) *Session { cfgCache: make(map[*ast.FunctionExpr]*cfg.Graph), scopeDepthDiagEmitted: make(map[*ast.FunctionExpr]bool), } + api.AttachGraphs(ctx, sess) + return sess } // GetOrBuildCFG returns a cached CFG for the function or builds and caches a new one. @@ -358,8 +360,8 @@ func (s *Session) PluginLoad(key any) any { // // EFFECT ENRICHMENT: // If the returned value is a table/record, its method types are enriched with -// function effects computed during analysis. This ensures exported functions -// carry their side effect annotations. +// function refinements computed during analysis. This ensures exported functions +// carry their refinement summaries. // // USAGE: // @@ -370,13 +372,13 @@ func (s *Session) ExportType() typ.Type { if s == nil { return typ.Nil } - var effects map[cfg.SymbolID]*constraint.FunctionEffect + var refinements map[cfg.SymbolID]*constraint.FunctionRefinement if s.Store != nil { if s.Store.InterprocPrev != nil { - effects = s.Store.InterprocPrev.Effects + refinements = s.Store.InterprocPrev.Refinements } } - return modules.ExportType(s.RootResult, effects) + return modules.ExportType(s.RootResult, refinements) } // Release frees heavy allocations to reduce memory pressure after analysis. @@ -414,12 +416,12 @@ func (s *Session) Release() { // Clear interproc snapshots if s.Store.InterprocPrev != nil { clear(s.Store.InterprocPrev.Facts) - clear(s.Store.InterprocPrev.Effects) + clear(s.Store.InterprocPrev.Refinements) clear(s.Store.InterprocPrev.ConstructorFields) } if s.Store.InterprocNext != nil { clear(s.Store.InterprocNext.Facts) - clear(s.Store.InterprocNext.Effects) + clear(s.Store.InterprocNext.Refinements) clear(s.Store.InterprocNext.ConstructorFields) } @@ -467,7 +469,7 @@ func (s *Session) ExportTypes() map[string]typ.Type { // ExportManifest builds a module manifest from this session using canonical export policy. // // The manifest includes: -// - Export type from root returns (with converged function effects applied) +// - Export type from root returns (with converged function refinements applied) // - Exported type definitions // - Function summaries for exported functions (ensures/effects for cross-module narrowing) // @@ -489,24 +491,24 @@ func (s *Session) ExportManifest(modulePath string) *io.Manifest { manifest.DefineType(typeName, t) } - modules.ExportFunctionSummaries(manifest, exportType, s.RootGraph(), s.EffectsForExport()) + modules.ExportFunctionSummaries(manifest, exportType, s.RootGraph(), s.RefinementsForExport()) return manifest } -// EffectsForExport extracts computed function effects for manifest generation. -// Returns effects from the final converged interproc snapshot. +// RefinementsForExport extracts computed function refinements for manifest generation. +// Returns refinements from the final converged interproc snapshot. // -// The returned map associates each function's SymbolID with its computed effect, +// The returned map associates each function's SymbolID with its computed refinement, // including IO effects (row), termination status, and conditional effects. // This enables importers to see side effect information for exported functions. -func (s *Session) EffectsForExport() map[cfg.SymbolID]*constraint.FunctionEffect { +func (s *Session) RefinementsForExport() map[cfg.SymbolID]*constraint.FunctionRefinement { if s == nil || s.Store == nil { return nil } if s.Store.InterprocPrev == nil { return nil } - return modules.CopyEffectsForExport(s.Store.InterprocPrev.Effects) + return modules.CopyRefinementsForExport(s.Store.InterprocPrev.Refinements) } // RootGraph returns the root function's control flow graph. diff --git a/compiler/check/session_test.go b/compiler/check/session_test.go index 87203b1e..703c3f47 100644 --- a/compiler/check/session_test.go +++ b/compiler/check/session_test.go @@ -351,6 +351,16 @@ func TestStoreFrom(t *testing.T) { } } +func TestGraphsFrom(t *testing.T) { + ctx := db.NewQueryContext(db.New()) + sess := New(ctx, "test.lua") + + graphs := api.GraphsFrom(ctx) + if graphs != sess { + t.Error("GraphsFrom should return the session graph provider") + } +} + func TestStoreFrom_NilContext(t *testing.T) { store := api.StoreFrom(nil) if store != nil { @@ -388,10 +398,10 @@ func TestAttachStore_NilStore(t *testing.T) { func TestSessionStore_EffectMaps(t *testing.T) { store := store.NewSessionStore() - if store.InterprocPrev == nil || store.InterprocPrev.Effects == nil { + if store.InterprocPrev == nil || store.InterprocPrev.Refinements == nil { t.Error("InterprocPrev effects not initialized") } - if store.InterprocNext == nil || store.InterprocNext.Effects == nil { + if store.InterprocNext == nil || store.InterprocNext.Refinements == nil { t.Error("InterprocNext effects not initialized") } } @@ -400,8 +410,8 @@ func TestFixpointChannelDiffs_IsolatedBetweenStores(t *testing.T) { storeA := store.NewSessionStore() storeB := store.NewSessionStore() - storeA.StoreFunctionEffect(cfg.SymbolID(42), &constraint.FunctionEffect{}) - storeB.StoreFunctionEffect(cfg.SymbolID(42), &constraint.FunctionEffect{}) + storeA.StoreFunctionRefinement(cfg.SymbolID(42), &constraint.FunctionRefinement{}) + storeB.StoreFunctionRefinement(cfg.SymbolID(42), &constraint.FunctionRefinement{}) if !storeA.FixpointSwap() { t.Fatal("expected storeA FixpointSwap to report change") @@ -419,15 +429,15 @@ func TestSessionStore_ClearIterationChannels(t *testing.T) { store := store.NewSessionStore() store.StoreConstructorFields(cfg.SymbolID(2), map[string]typ.Type{"name": typ.String}) - store.InterprocPrev.Effects[cfg.SymbolID(4)] = &constraint.FunctionEffect{} - store.StoreFunctionEffect(cfg.SymbolID(5), &constraint.FunctionEffect{}) + store.InterprocPrev.Refinements[cfg.SymbolID(4)] = &constraint.FunctionRefinement{} + store.StoreFunctionRefinement(cfg.SymbolID(5), &constraint.FunctionRefinement{}) store.ClearIterationChannels() if store.InterprocNext == nil || len(store.InterprocNext.ConstructorFields) != 0 { t.Fatal("expected constructor fields to be cleared") } - if len(store.InterprocPrev.Effects) != 0 || len(store.InterprocNext.Effects) != 0 { + if len(store.InterprocPrev.Refinements) != 0 || len(store.InterprocNext.Refinements) != 0 { t.Fatal("expected effects to be cleared") } } diff --git a/compiler/check/store/store.go b/compiler/check/store/store.go index 4c145220..ebd1534a 100644 --- a/compiler/check/store/store.go +++ b/compiler/check/store/store.go @@ -40,10 +40,10 @@ type SessionStore struct { phase api.Phase } -// InterprocState holds interprocedural facts and effects for an iteration snapshot. +// InterprocState holds interprocedural facts and refinements for an iteration snapshot. type InterprocState struct { Facts map[api.GraphKey]api.Facts - Effects map[cfg.SymbolID]*constraint.FunctionEffect + Refinements map[cfg.SymbolID]*constraint.FunctionRefinement ConstructorFields api.ConstructorFields } @@ -51,7 +51,7 @@ type InterprocState struct { func NewInterprocState() *InterprocState { return &InterprocState{ Facts: make(map[api.GraphKey]api.Facts), - Effects: make(map[cfg.SymbolID]*constraint.FunctionEffect), + Refinements: make(map[cfg.SymbolID]*constraint.FunctionRefinement), ConstructorFields: make(api.ConstructorFields), } } @@ -107,8 +107,8 @@ type IterationScratch struct { LiteralSigsByGraphID map[uint64]map[*ast.FunctionExpr]*typ.Function } -// effectsEqual compares two FunctionEffects for structural equality. -func effectsEqual(a, b *constraint.FunctionEffect) bool { +// effectsEqual compares two FunctionRefinements for structural equality. +func effectsEqual(a, b *constraint.FunctionRefinement) bool { if a == b { return true } @@ -119,7 +119,7 @@ func effectsEqual(a, b *constraint.FunctionEffect) bool { } // effectsMapEqual compares two effect maps for structural equality. -func effectsMapEqual(a, b map[cfg.SymbolID]*constraint.FunctionEffect) bool { +func effectsMapEqual(a, b map[cfg.SymbolID]*constraint.FunctionRefinement) bool { if len(a) != len(b) { return false } @@ -292,24 +292,24 @@ func (s *SessionStore) swapInterprocChannels() []string { s.ensureInterprocStates() // Channel policies: - // - Effects/ConstructorFields are overwrite channels (next snapshot replaces prev). + // - Refinements/ConstructorFields are overwrite channels (next snapshot replaces prev). // - InterprocFacts is a widening channel (monotone merge across iterations). channels := []struct { name string swap func() bool }{ { - name: "Effects", + name: "Refinements", swap: func() bool { return swapSnapshotChannel( - &s.InterprocPrev.Effects, - &s.InterprocNext.Effects, - func(_prev, next map[cfg.SymbolID]*constraint.FunctionEffect) map[cfg.SymbolID]*constraint.FunctionEffect { + &s.InterprocPrev.Refinements, + &s.InterprocNext.Refinements, + func(_prev, next map[cfg.SymbolID]*constraint.FunctionRefinement) map[cfg.SymbolID]*constraint.FunctionRefinement { return next }, effectsMapEqual, - func() map[cfg.SymbolID]*constraint.FunctionEffect { - return make(map[cfg.SymbolID]*constraint.FunctionEffect) + func() map[cfg.SymbolID]*constraint.FunctionRefinement { + return make(map[cfg.SymbolID]*constraint.FunctionRefinement) }, ) }, @@ -367,7 +367,7 @@ func (s *SessionStore) swapInterprocChannels() []string { // 5. Record which channels changed for diagnostic reporting // // CHANGE DETECTION: Each channel uses type-appropriate equality: -// - Effects: FunctionEffect.Equals (structural comparison) +// - Refinements: FunctionRefinement.Equals (structural comparison) // - ConstructorFields: typ.TypeEquals (structural equality) // // RETURN VALUE: Returns true if any channel changed, signaling another iteration @@ -416,31 +416,31 @@ func (s *SessionStore) BumpRevision() { s.Iteration.Revision++ } -// LookupEffectBySym returns the effect for a function by its SymbolID. -// Reads from the stable interproc effect snapshot for order-independent analysis. -func (s *SessionStore) LookupEffectBySym(sym cfg.SymbolID) *constraint.FunctionEffect { +// LookupRefinementBySym returns the refinement for a function by its SymbolID. +// Reads from the stable interproc refinement snapshot for order-independent analysis. +func (s *SessionStore) LookupRefinementBySym(sym cfg.SymbolID) *constraint.FunctionRefinement { if sym == 0 { return nil } - if s.InterprocPrev == nil || s.InterprocPrev.Effects == nil { + if s.InterprocPrev == nil || s.InterprocPrev.Refinements == nil { return nil } - return s.InterprocPrev.Effects[sym] + return s.InterprocPrev.Refinements[sym] } -// StoreFunctionEffect records a function effect for the current iteration. -func (s *SessionStore) StoreFunctionEffect(sym cfg.SymbolID, eff *constraint.FunctionEffect) { +// StoreFunctionRefinement records a function refinement for the current iteration. +func (s *SessionStore) StoreFunctionRefinement(sym cfg.SymbolID, eff *constraint.FunctionRefinement) { if s == nil || sym == 0 || eff == nil { return } s.ensureInterprocStates() - if s.InterprocNext.Effects == nil { - s.InterprocNext.Effects = make(map[cfg.SymbolID]*constraint.FunctionEffect) + if s.InterprocNext.Refinements == nil { + s.InterprocNext.Refinements = make(map[cfg.SymbolID]*constraint.FunctionRefinement) } - if existing := s.InterprocNext.Effects[sym]; effectsEqual(existing, eff) { + if existing := s.InterprocNext.Refinements[sym]; effectsEqual(existing, eff) { return } - s.InterprocNext.Effects[sym] = eff + s.InterprocNext.Refinements[sym] = eff } // StoreConstructorFields stores constructor fields for a class symbol. @@ -495,12 +495,12 @@ func (s *SessionStore) ClearIterationChannels() { s.lastSwapDiffs = nil } -// EffectStore returns a view over the stable interproc effect snapshot. -func (s *SessionStore) EffectStore() api.EffectStore { +// RefinementStore returns a view over the stable interproc refinement snapshot. +func (s *SessionStore) RefinementStore() api.RefinementStore { if s == nil || s.InterprocPrev == nil { - return &snapshotEffectStore{effects: nil} + return &snapshotRefinementStore{refinements: nil} } - return &snapshotEffectStore{effects: s.InterprocPrev.Effects} + return &snapshotRefinementStore{refinements: s.InterprocPrev.Refinements} } // ModuleBindings returns the module binding table. @@ -905,20 +905,20 @@ func (s *SessionStore) GetCapturedContainerMutationsSnapshot( return s.GetInterprocFactsSnapshot(graph, parent).CapturedContainers } -// snapshotEffectStore implements api.EffectStore using the stable snapshot. -type snapshotEffectStore struct { - effects map[cfg.SymbolID]*constraint.FunctionEffect +// snapshotRefinementStore implements api.RefinementStore using the stable snapshot. +type snapshotRefinementStore struct { + refinements map[cfg.SymbolID]*constraint.FunctionRefinement } -func (o *snapshotEffectStore) LookupEffectBySym(sym cfg.SymbolID) *constraint.FunctionEffect { +func (o *snapshotRefinementStore) LookupRefinementBySym(sym cfg.SymbolID) *constraint.FunctionRefinement { if o == nil || sym == 0 { return nil } - if o.effects == nil { + if o.refinements == nil { return nil } - if eff := o.effects[sym]; eff != nil { - return eff + if refinement := o.refinements[sym]; refinement != nil { + return refinement } return nil } diff --git a/compiler/check/store/store_test.go b/compiler/check/store/store_test.go index e93371a6..7855c401 100644 --- a/compiler/check/store/store_test.go +++ b/compiler/check/store/store_test.go @@ -20,8 +20,8 @@ func TestNewInterprocState(t *testing.T) { if state.Facts == nil { t.Error("Facts map should be initialized") } - if state.Effects == nil { - t.Error("Effects map should be initialized") + if state.Refinements == nil { + t.Error("Refinements map should be initialized") } if state.ConstructorFields == nil { t.Error("ConstructorFields map should be initialized") @@ -35,7 +35,7 @@ func TestEffectsEqual_BothNil(t *testing.T) { } func TestEffectsEqual_OneNil(t *testing.T) { - eff := &constraint.FunctionEffect{} + eff := &constraint.FunctionRefinement{} if effectsEqual(eff, nil) { t.Error("non-nil and nil should not be equal") } @@ -45,7 +45,7 @@ func TestEffectsEqual_OneNil(t *testing.T) { } func TestEffectsEqual_Same(t *testing.T) { - eff := &constraint.FunctionEffect{Terminates: true} + eff := &constraint.FunctionRefinement{Terminates: true} if !effectsEqual(eff, eff) { t.Error("same reference should be equal") } @@ -55,14 +55,14 @@ func TestEffectsMapEqual_Empty(t *testing.T) { if !effectsMapEqual(nil, nil) { t.Error("two nils should be equal") } - if !effectsMapEqual(map[cfg.SymbolID]*constraint.FunctionEffect{}, map[cfg.SymbolID]*constraint.FunctionEffect{}) { + if !effectsMapEqual(map[cfg.SymbolID]*constraint.FunctionRefinement{}, map[cfg.SymbolID]*constraint.FunctionRefinement{}) { t.Error("two empty maps should be equal") } } func TestEffectsMapEqual_DifferentLength(t *testing.T) { - a := map[cfg.SymbolID]*constraint.FunctionEffect{1: {}} - b := map[cfg.SymbolID]*constraint.FunctionEffect{} + a := map[cfg.SymbolID]*constraint.FunctionRefinement{1: {}} + b := map[cfg.SymbolID]*constraint.FunctionRefinement{} if effectsMapEqual(a, b) { t.Error("maps of different length should not be equal") } @@ -245,7 +245,7 @@ func TestIterationScratch_Fields(t *testing.T) { func TestFixpointSwap_TracksChannelDiffsAndResetsNext(t *testing.T) { s := NewSessionStore() - s.InterprocNext.Effects[1] = &constraint.FunctionEffect{Terminates: true} + s.InterprocNext.Refinements[1] = &constraint.FunctionRefinement{Terminates: true} s.InterprocNext.Facts[api.GraphKey{GraphID: 7, ParentHash: 11}] = api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Summary: []typ.Type{typ.String}}, @@ -263,15 +263,15 @@ func TestFixpointSwap_TracksChannelDiffsAndResetsNext(t *testing.T) { if len(diffs) != 3 { t.Fatalf("expected 3 channel diffs, got %v", diffs) } - if diffs[0] != "Effects" || diffs[1] != "InterprocFacts" || diffs[2] != "ConstructorFields" { + if diffs[0] != "Refinements" || diffs[1] != "InterprocFacts" || diffs[2] != "ConstructorFields" { t.Fatalf("unexpected diff order/content: %v", diffs) } - if len(s.InterprocPrev.Effects) != 1 || s.InterprocPrev.Effects[1] == nil { - t.Fatalf("expected prev effects populated, got %#v", s.InterprocPrev.Effects) + if len(s.InterprocPrev.Refinements) != 1 || s.InterprocPrev.Refinements[1] == nil { + t.Fatalf("expected prev effects populated, got %#v", s.InterprocPrev.Refinements) } - if len(s.InterprocNext.Effects) != 0 { - t.Fatalf("expected next effects reset, got %#v", s.InterprocNext.Effects) + if len(s.InterprocNext.Refinements) != 0 { + t.Fatalf("expected next effects reset, got %#v", s.InterprocNext.Refinements) } if len(s.InterprocPrev.Facts) != 1 { t.Fatalf("expected prev facts populated, got %#v", s.InterprocPrev.Facts) @@ -315,7 +315,7 @@ func TestBumpRevision_InitializesIterationStore(t *testing.T) { func TestFixpointChannelDiffs_ReturnsCopy(t *testing.T) { s := NewSessionStore() - s.StoreFunctionEffect(1, &constraint.FunctionEffect{Terminates: true}) + s.StoreFunctionRefinement(1, &constraint.FunctionRefinement{Terminates: true}) if !s.FixpointSwap() { t.Fatal("expected change from effect swap") } diff --git a/compiler/check/synth/engine.go b/compiler/check/synth/engine.go index 1a323d79..42246409 100644 --- a/compiler/check/synth/engine.go +++ b/compiler/check/synth/engine.go @@ -55,6 +55,7 @@ type Config struct { Paths api.PathFromExprFunc PreCache api.Cache NarrowCache api.Cache + Graphs api.GraphProvider Phase api.Phase ModuleBindings *bind.BindingTable ModuleAliases map[cfg.SymbolID]string @@ -105,6 +106,10 @@ func New(cfg Config) *Engine { if narrowCache == nil && isNarrowing { narrowCache = make(api.Cache) } + graphs := cfg.Graphs + if graphs == nil { + graphs = api.GraphsFrom(cfg.Ctx) + } deps := &extract.Deps{ Ctx: cfg.Ctx, @@ -112,6 +117,7 @@ func New(cfg Config) *Engine { Scopes: cfg.Scopes, Manifests: cfg.Manifests, CheckCtx: cfg.Env, + Graphs: graphs, Flow: cfg.Flow, Paths: cfg.Paths, PreCache: preCache, @@ -247,6 +253,8 @@ func (e *Engine) ResolveTypeDefAt(name string, typeExpr ast.TypeExpr, typeParams ExprSynth: func(expr ast.Expr, _ cfg.Point) typ.Type { return e.SynthExprAt(expr, p, sc) }, + ModuleBindings: e.deps.ModuleBindings, + ModuleAliases: e.deps.ModuleAliases, }) return resolver.ResolveTypeDef(name, typeExpr, typeParams, sc) } diff --git a/compiler/check/synth/ops/call.go b/compiler/check/synth/ops/call.go index 8ee8688b..cbac1ae6 100644 --- a/compiler/check/synth/ops/call.go +++ b/compiler/check/synth/ops/call.go @@ -67,8 +67,9 @@ import ( // Type is the computed return type (typ.Tuple for multiple returns, typ.Nil for void). // Errors contains type mismatches, arity problems, and other call-related issues. type CallResult struct { - Type typ.Type // Return type (or tuple for multiple returns) - Errors []CallError // Type errors detected + Type typ.Type // Packed return type (or tuple for multiple returns) + Returns []typ.Type // Expression-adjusted return vector + Errors []CallError // Type errors detected } // CallError describes a type error in a function call. @@ -258,9 +259,7 @@ func resolveCallee(ctx *db.QueryContext, def CallDef) (*resolvedCallee, *CallRes // unwrapCallee performs alias, generic body, and instantiated unwrapping. func unwrapCallee(callee typ.Type) typ.Type { - if callee.Kind() == kind.Alias { - callee = callee.(*typ.Alias).Target - } + callee = unwrap.Alias(callee) if g, ok := callee.(*typ.Generic); ok { callee = g.Body @@ -273,7 +272,7 @@ func unwrapCallee(callee typ.Type) typ.Type { } } - return callee + return unwrap.Alias(callee) } // inferAndCall performs generic type inference and calls the instantiated function. @@ -286,7 +285,7 @@ func inferAndCall(ctx *db.QueryContext, fn *typ.Function, def CallDef, isMethod typeArgs, err = InferTypeArgsWithExpectedAndMode(fn, def.Args, isMethod, receiver, nil, false) if err != nil { errors = append(errors, CallError{Kind: ErrTypeInference, Message: err.Error()}) - return CallResult{Type: typ.Unknown, Errors: errors} + return singleValueCallResult(typ.Unknown, errors) } } @@ -551,18 +550,18 @@ func computeExpectedArgs(ctx *db.QueryContext, query core.TypeOps, fn *typ.Funct // Otherwise, performs full argument checking and return type computation. func FinishCall(ctx *db.QueryContext, def CallDef, infer InferResult) CallResult { if infer.ShortCircuit != nil { - return CallResult{Type: infer.ShortCircuit, Errors: infer.Errors} + return singleValueCallResult(infer.ShortCircuit, infer.Errors) } switch infer.Kind { case InferKindNotCallable: - return CallResult{Type: typ.Unknown, Errors: infer.Errors} + return singleValueCallResult(typ.Unknown, infer.Errors) case InferKindAny: - return CallResult{Type: typ.Any, Errors: infer.Errors} + return singleValueCallResult(typ.Any, infer.Errors) case InferKindUnknown: - return CallResult{Type: typ.Unknown, Errors: infer.Errors} + return singleValueCallResult(typ.Unknown, infer.Errors) case InferKindUnion: return callUnionWithGenericInference( @@ -584,12 +583,12 @@ func FinishCall(ctx *db.QueryContext, def CallDef, infer InferResult) CallResult fn = infer.Function } if fn == nil { - return CallResult{Type: typ.Unknown, Errors: infer.Errors} + return singleValueCallResult(typ.Unknown, infer.Errors) } return callFunction(ctx, def.Query, fn, def.Args, infer.Receiver, infer.IsMethod, infer.ForceMethodReceiver, infer.Errors) } - return CallResult{Type: typ.Unknown, Errors: infer.Errors} + return singleValueCallResult(typ.Unknown, infer.Errors) } // ReInfer performs re-inference after arguments have been updated. @@ -642,6 +641,7 @@ func (r *InferResult) ExpectedArgType(idx int) typ.Type { // The return type is the intersection of all member return types. func callIntersection(ctx *db.QueryContext, query core.TypeOps, inter *typ.Intersection, args []typ.Type, receiver typ.Type, isMethod bool, forceMethodReceiver bool, baseErrors []CallError) CallResult { var returnTypes []typ.Type + var returnVectors [][]typ.Type for _, member := range inter.Members { if member.Kind().IsPlaceholder() { @@ -651,7 +651,8 @@ func callIntersection(ctx *db.QueryContext, query core.TypeOps, inter *typ.Inter fn, ok := member.(*typ.Function) if !ok { return CallResult{ - Type: typ.Unknown, + Type: typ.Unknown, + Returns: []typ.Type{typ.Unknown}, Errors: append(baseErrors, CallError{Kind: ErrNotCallable, Message: fmt.Sprintf("intersection member is not callable: %s", typ.FormatShort(member))}), } } @@ -663,25 +664,34 @@ func callIntersection(ctx *db.QueryContext, query core.TypeOps, inter *typ.Inter } returnTypes = append(returnTypes, result.Type) + returnVectors = append(returnVectors, normalizedCallReturns(result)) } if len(returnTypes) == 0 { - return CallResult{Type: typ.Unknown, Errors: baseErrors} + return singleValueCallResult(typ.Unknown, baseErrors) } if len(returnTypes) == 1 { - return CallResult{Type: returnTypes[0], Errors: baseErrors} + return callResultFromReturns(returnVectors[0], baseErrors) + } + + if returns, ok := intersectReturnVectors(returnVectors); ok { + return callResultFromReturns(returns, baseErrors) } - return CallResult{Type: typ.NewIntersection(returnTypes...), Errors: baseErrors} + return CallResult{ + Type: typ.NewIntersection(returnTypes...), + Returns: []typ.Type{typ.NewIntersection(returnTypes...)}, + Errors: baseErrors, + } } // callUnionWithGenericInference handles calling a union of functions where each // member may be generic. Per-member generic inference is applied before calling. // Union semantics: the call succeeds if any member succeeds. func callUnionWithGenericInference(ctx *db.QueryContext, u *typ.Union, def CallDef, isMethod bool, receiver typ.Type, forceMethodReceiver bool, baseErrors []CallError) CallResult { - var validTypes []typ.Type - var allTypes []typ.Type + var validReturns [][]typ.Type + var allReturns [][]typ.Type var hardErrors []CallError for _, member := range u.Members { @@ -698,70 +708,58 @@ func callUnionWithGenericInference(ctx *db.QueryContext, u *typ.Union, def CallD } else { result = inferAndCall(ctx, fn, def, isMethod, receiver, seedErrors) } - allTypes = append(allTypes, result.Type) + allReturns = append(allReturns, normalizedCallReturns(result)) if hasHardErrors(result.Errors[len(seedErrors):]) { hardErrors = append(hardErrors, result.Errors...) continue } - validTypes = append(validTypes, result.Type) + validReturns = append(validReturns, normalizedCallReturns(result)) } - if len(validTypes) > 0 { - return CallResult{Type: mergeReturnTypes(validTypes), Errors: baseErrors} + if len(validReturns) > 0 { + return callResultFromReturns(mergeReturnVectors(validReturns), baseErrors) } - if len(allTypes) > 0 { - return CallResult{Type: mergeReturnTypes(allTypes), Errors: uniqueCallErrors(hardErrors)} + if len(allReturns) > 0 { + return callResultFromReturns(mergeReturnVectors(allReturns), uniqueCallErrors(hardErrors)) } - return CallResult{Type: typ.Unknown, Errors: uniqueCallErrors(hardErrors)} + return singleValueCallResult(typ.Unknown, uniqueCallErrors(hardErrors)) } -// mergeReturnTypes merges multiple return types position-wise for tuples. -// If all types are tuples with the same arity, merges them position-wise: -// (A, B) and (A, C) become (A, B | C). -// Otherwise falls back to creating a union. -func mergeReturnTypes(types []typ.Type) typ.Type { - if len(types) == 0 { - return typ.Unknown +func mergeReturnVectors(vectors [][]typ.Type) []typ.Type { + if len(vectors) == 0 { + return []typ.Type{typ.Unknown} } - if len(types) == 1 { - return types[0] + if len(vectors) == 1 { + return copyTypeSlice(vectors[0]) } - // Check if all types are tuples - var tuples []*typ.Tuple maxLen := 0 - for _, t := range types { - tuple, ok := t.(*typ.Tuple) - if !ok { - return typ.NewUnion(types...) - } - tuples = append(tuples, tuple) - if len(tuple.Elements) > maxLen { - maxLen = len(tuple.Elements) + for _, returns := range vectors { + if len(returns) > maxLen { + maxLen = len(returns) } } + if maxLen == 0 { + return []typ.Type{typ.Nil} + } - // Merge position-wise merged := make([]typ.Type, maxLen) for i := 0; i < maxLen; i++ { - var posTypes []typ.Type - for _, tuple := range tuples { - if i < len(tuple.Elements) { - posTypes = append(posTypes, tuple.Elements[i]) + slotTypes := make([]typ.Type, 0, len(vectors)) + for _, returns := range vectors { + if i < len(returns) { + slotTypes = append(slotTypes, returns[i]) + } else { + slotTypes = append(slotTypes, typ.Nil) } } - if len(posTypes) == 1 { - merged[i] = posTypes[0] - } else { - merged[i] = typ.NewUnion(posTypes...) - } + merged[i] = typ.NewUnion(slotTypes...) } - - return typ.NewTuple(merged...) + return merged } func methodConsumesReceiver(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, receiver typ.Type, isMethod bool, forceMethodReceiver bool) bool { @@ -786,7 +784,7 @@ func methodConsumesReceiverSimple(fn *typ.Function, receiver typ.Type, isMethod func callFunction(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, args []typ.Type, receiver typ.Type, isMethod bool, forceMethodReceiver bool, errors []CallError) CallResult { if fn == nil { - return CallResult{Type: typ.Unknown, Errors: append(errors, CallError{Kind: ErrNotCallable, Message: "nil function"})} + return singleValueCallResult(typ.Unknown, append(errors, CallError{Kind: ErrNotCallable, Message: "nil function"})) } argCount := len(args) @@ -797,13 +795,14 @@ func callFunction(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, ar minArgs := typ.MinRequiredArgs(fn) hasVariadic := fn.Variadic != nil + allowExtraArgs := len(fn.Params) == 0 && !hasVariadic if argCount < minArgs { errors = append(errors, CallError{ Kind: ErrWrongArity, Message: "not enough arguments", }) - } else if !hasVariadic && argCount > len(fn.Params) { + } else if !hasVariadic && !allowExtraArgs && argCount > len(fn.Params) { errors = append(errors, CallError{ Kind: ErrWrongArity, Message: "too many arguments", @@ -868,14 +867,82 @@ func callFunction(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, ar } if len(returns) == 0 { - return CallResult{Type: typ.Nil, Errors: errors} + return singleValueCallResult(typ.Nil, errors) } + return callResultFromReturns(returns, errors) +} + +func normalizedCallReturns(result CallResult) []typ.Type { + if len(result.Returns) > 0 { + return copyTypeSlice(result.Returns) + } + return []typ.Type{result.Type} +} + +func callResultFromReturns(returns []typ.Type, errors []CallError) CallResult { + if len(returns) == 0 { + return singleValueCallResult(typ.Nil, errors) + } if len(returns) == 1 { - return CallResult{Type: returns[0], Errors: errors} + return CallResult{ + Type: returns[0], + Returns: copyTypeSlice(returns), + Errors: errors, + } + } + return CallResult{ + Type: typ.NewTuple(returns...), + Returns: copyTypeSlice(returns), + Errors: errors, + } +} + +func singleValueCallResult(t typ.Type, errors []CallError) CallResult { + return CallResult{ + Type: t, + Returns: []typ.Type{t}, + Errors: errors, + } +} + +func intersectReturnVectors(vectors [][]typ.Type) ([]typ.Type, bool) { + if len(vectors) == 0 { + return nil, false + } + if len(vectors) == 1 { + return copyTypeSlice(vectors[0]), true + } + + arity := len(vectors[0]) + for _, returns := range vectors[1:] { + if len(returns) != arity { + return nil, false + } } - return CallResult{Type: typ.NewTuple(returns...), Errors: errors} + merged := make([]typ.Type, arity) + for i := 0; i < arity; i++ { + slotTypes := make([]typ.Type, 0, len(vectors)) + for _, returns := range vectors { + slotTypes = append(slotTypes, returns[i]) + } + if len(slotTypes) == 1 { + merged[i] = slotTypes[0] + continue + } + merged[i] = typ.NewIntersection(slotTypes...) + } + return merged, true +} + +func copyTypeSlice(types []typ.Type) []typ.Type { + if len(types) == 0 { + return nil + } + out := make([]typ.Type, len(types)) + copy(out, types) + return out } func hasHardErrors(errors []CallError) bool { diff --git a/compiler/check/synth/ops/call_test.go b/compiler/check/synth/ops/call_test.go index b7ca6fe9..61541ad9 100644 --- a/compiler/check/synth/ops/call_test.go +++ b/compiler/check/synth/ops/call_test.go @@ -67,6 +67,23 @@ func TestCallWithGenericInference_TooManyArgs(t *testing.T) { } } +func TestCallWithGenericInference_ZeroParamAllowsExtraArgs(t *testing.T) { + fn := typ.Func(). + Returns(typ.Boolean). + Build() + + ctx := db.NewQueryContext(db.New()) + def := CallDef{ + Callee: fn, + Args: []typ.Type{typ.Integer, typ.String}, + } + + result := CallWithGenericInference(ctx, def) + if len(result.Errors) > 0 { + t.Fatalf("zero-param function should accept extra args, got: %v", result.Errors) + } +} + func TestCallWithGenericInference_Variadic(t *testing.T) { fn := typ.Func(). Param("x", typ.Integer). @@ -176,6 +193,79 @@ func TestCallWithGenericInference_MultiReturn(t *testing.T) { if tuple.Elements[1] != typ.Boolean { t.Errorf("second return should be boolean, got %v", tuple.Elements[1]) } + + if len(result.Returns) != 2 { + t.Fatalf("expected 2 return slots, got %d", len(result.Returns)) + } + if result.Returns[0] != typ.String || result.Returns[1] != typ.Boolean { + t.Fatalf("unexpected return vector: %v", result.Returns) + } +} + +func TestCallWithGenericInference_SingleTupleReturnPreservesArityOne(t *testing.T) { + tupleValue := typ.NewTuple(typ.Integer, typ.String) + fn := typ.Func(). + Returns(tupleValue). + Build() + + ctx := db.NewQueryContext(db.New()) + def := CallDef{Callee: fn} + + result := CallWithGenericInference(ctx, def) + if result.Type != tupleValue { + t.Fatalf("expected tuple-valued return, got %v", result.Type) + } + if len(result.Returns) != 1 { + t.Fatalf("expected 1 return slot, got %d", len(result.Returns)) + } + if result.Returns[0] != tupleValue { + t.Fatalf("expected first return slot to be tuple value, got %v", result.Returns[0]) + } +} + +func TestCallWithGenericInference_NestedAliasFunctionIsCallable(t *testing.T) { + fn := typ.Func(). + Param("x", typ.String). + Returns(typ.Integer). + Build() + moduleAlias := typ.NewAlias("ModuleHandler", fn) + localAlias := typ.NewAlias("Handler", moduleAlias) + + ctx := db.NewQueryContext(db.New()) + def := CallDef{ + Callee: localAlias, + Args: []typ.Type{typ.String}, + } + + result := CallWithGenericInference(ctx, def) + if len(result.Errors) > 0 { + t.Fatalf("expected nested alias callee to unwrap to function, got errors: %v", result.Errors) + } + if result.Type != typ.Integer { + t.Fatalf("expected integer return through nested aliases, got %v", result.Type) + } +} + +func TestCallWithGenericInference_UnionCarriesMergedReturnVector(t *testing.T) { + one := typ.Func().Returns(typ.String).Build() + two := typ.Func().Returns(typ.String, typ.Boolean).Build() + + ctx := db.NewQueryContext(db.New()) + def := CallDef{ + Callee: typ.NewUnion(one, two), + } + + result := CallWithGenericInference(ctx, def) + if len(result.Returns) != 2 { + t.Fatalf("expected merged arity 2, got %d", len(result.Returns)) + } + if result.Returns[0] != typ.String { + t.Fatalf("expected first merged slot string, got %v", result.Returns[0]) + } + wantSecond := typ.NewUnion(typ.Nil, typ.Boolean) + if !typ.TypeEquals(result.Returns[1], wantSecond) { + t.Fatalf("expected second merged slot %v, got %v", wantSecond, result.Returns[1]) + } } func TestCallWithGenericInference_NoReturn(t *testing.T) { @@ -525,3 +615,19 @@ func TestCallFunction_MethodAlwaysConsumesReceiver(t *testing.T) { t.Fatal("expected method call to fail when signature does not accept receiver") } } + +func TestCallFunction_ZeroParamAllowsExtraArgs(t *testing.T) { + fn := typ.Func(). + Returns(typ.Boolean). + Build() + + ctx := db.NewQueryContext(db.New()) + result := callFunction(ctx, nil, fn, []typ.Type{typ.Number, typ.String}, nil, false, false, nil) + + if len(result.Errors) != 0 { + t.Fatalf("zero-param function should accept extra args, got: %v", result.Errors) + } + if result.Type != typ.Boolean { + t.Fatalf("expected boolean return type, got: %v", result.Type) + } +} diff --git a/compiler/check/synth/ops/check.go b/compiler/check/synth/ops/check.go index e7366150..e4529c14 100644 --- a/compiler/check/synth/ops/check.go +++ b/compiler/check/synth/ops/check.go @@ -48,6 +48,10 @@ func CheckTable(fields []FieldDef, arrayElems []typ.Type, expected typ.Type) Che return CheckResult{Type: tableConstructor(fields, arrayElems)} } + if rec := unwrap.Record(expected); rec != nil { + return checkTableAsRecord(fields, arrayElems, rec) + } + if alias, ok := expected.(*typ.Alias); ok { return CheckTable(fields, arrayElems, alias.Target) } diff --git a/compiler/check/synth/ops/check_test.go b/compiler/check/synth/ops/check_test.go index 69321d9e..8e376bc4 100644 --- a/compiler/check/synth/ops/check_test.go +++ b/compiler/check/synth/ops/check_test.go @@ -208,3 +208,40 @@ func TestCheckError_Fields(t *testing.T) { t.Error("wrong field") } } + +func TestExpectedTableElementType_Array(t *testing.T) { + expected := typ.NewArray(typ.String) + got := ExpectedTableElementType(expected, 0) + if got != typ.String { + t.Fatalf("got %v, want string", got) + } +} + +func TestExpectedTableElementType_TupleUsesIndex(t *testing.T) { + expected := typ.NewTuple(typ.String, typ.Integer) + if got := ExpectedTableElementType(expected, 0); got != typ.String { + t.Fatalf("index 0 got %v, want string", got) + } + if got := ExpectedTableElementType(expected, 1); got != typ.Integer { + t.Fatalf("index 1 got %v, want integer", got) + } +} + +func TestExpectedTableElementType_UnionCollectsMembers(t *testing.T) { + expected := typ.NewUnion( + typ.NewArray(typ.String), + typ.NewTuple(typ.Integer, typ.Boolean), + ) + got := ExpectedTableElementType(expected, 0) + if !typ.TypeEquals(got, typ.NewUnion(typ.String, typ.Integer)) { + t.Fatalf("got %v, want string|integer", got) + } +} + +func TestExpectedTableElementType_NumericMap(t *testing.T) { + expected := typ.NewMap(typ.Integer, typ.Boolean) + got := ExpectedTableElementType(expected, 0) + if got != typ.Boolean { + t.Fatalf("got %v, want boolean", got) + } +} diff --git a/compiler/check/synth/ops/generic_test.go b/compiler/check/synth/ops/generic_test.go index 23235c59..2ee5be65 100644 --- a/compiler/check/synth/ops/generic_test.go +++ b/compiler/check/synth/ops/generic_test.go @@ -170,3 +170,131 @@ func TestInferTypeArgs_CannotInfer(t *testing.T) { t.Errorf("expected unknown for unresolved T, got %v", args[0]) } } + +func TestInferTypeArgs_ExpectedInstantiatedUnionReturn(t *testing.T) { + tParam := typ.NewTypeParam("T", nil) + user := typ.NewRecord(). + Field("id", typ.String). + Field("email", typ.String). + Build() + resultGeneric := typ.NewGeneric("Result", []*typ.TypeParam{tParam}, + typ.NewUnion( + typ.NewRecord(). + Field("ok", typ.True). + Field("value", tParam). + Build(), + typ.NewRecord(). + Field("ok", typ.False). + Field("error", typ.String). + Build(), + ), + ) + fn := typ.Func(). + TypeParam("T", nil). + Param("message", typ.String). + Returns(typ.Instantiate(resultGeneric, tParam)). + Build() + + typeArgs, err := InferTypeArgsWithExpectedAndMode( + fn, + []typ.Type{typ.String}, + false, + nil, + typ.Instantiate(resultGeneric, user), + false, + ) + if err != nil { + t.Fatalf("InferTypeArgs error: %v", err) + } + if len(typeArgs) != 1 { + t.Fatalf("type args len = %d, want 1", len(typeArgs)) + } + if !typ.TypeEquals(typeArgs[0], user) { + t.Fatalf("T = %v, want %v", typeArgs[0], user) + } +} + +func TestInferTypeArgs_FunctionParamWithInstantiatedUnionReturn(t *testing.T) { + tParam := typ.NewTypeParam("T", nil) + uParam := typ.NewTypeParam("U", nil) + user := typ.NewRecord(). + Field("id", typ.String). + Field("email", typ.String). + Build() + resultGeneric := typ.NewGeneric("Result", []*typ.TypeParam{tParam}, + typ.NewUnion( + typ.NewRecord(). + Field("ok", typ.True). + Field("value", tParam). + Build(), + typ.NewRecord(). + Field("ok", typ.False). + Field("error", typ.String). + Build(), + ), + ) + fn := typ.Func(). + TypeParam("T", nil). + TypeParam("U", nil). + Param("r", typ.Instantiate(resultGeneric, tParam)). + Param("mapper", typ.Func(). + Param("value", tParam). + Returns(typ.Instantiate(resultGeneric, uParam)). + Build()). + Returns(typ.Instantiate(resultGeneric, uParam)). + Build() + callback := typ.Func(). + Param("value", user). + Returns(typ.Instantiate(resultGeneric, user)). + Build() + + typeArgs, err := InferTypeArgsWithExpectedAndMode( + fn, + []typ.Type{ + typ.Instantiate(resultGeneric, user), + callback, + }, + false, + nil, + nil, + false, + ) + if err != nil { + t.Fatalf("InferTypeArgs error: %v", err) + } + if len(typeArgs) != 2 { + t.Fatalf("type args len = %d, want 2", len(typeArgs)) + } + if !typ.TypeEquals(typeArgs[0], user) { + t.Fatalf("T = %v, want %v", typeArgs[0], user) + } + if !typ.TypeEquals(typeArgs[1], user) { + t.Fatalf("U = %v, want %v", typeArgs[1], user) + } +} + +func TestInferTypeArgs_ExpectedExplicitUnionPrefersSpecificMember(t *testing.T) { + tParam := typ.NewTypeParam("T", nil) + fn := typ.Func(). + TypeParam("T", nil). + Returns(typ.NewUnion(tParam, typ.Nil)). + Build() + + typeArgs, err := InferTypeArgsWithExpectedAndMode( + fn, + nil, + false, + nil, + typ.NewUnion(typ.String, typ.Nil), + false, + ) + if err != nil { + t.Fatalf("InferTypeArgs error: %v", err) + } + if len(typeArgs) != 1 { + t.Fatalf("type args len = %d, want 1", len(typeArgs)) + } + if !typ.TypeEquals(typeArgs[0], typ.String) { + t.Fatalf("T = %v, want string", typeArgs[0]) + } +} diff --git a/compiler/check/synth/ops/table_expected.go b/compiler/check/synth/ops/table_expected.go new file mode 100644 index 00000000..d8b6e634 --- /dev/null +++ b/compiler/check/synth/ops/table_expected.go @@ -0,0 +1,84 @@ +package ops + +import ( + "github.com/wippyai/go-lua/types/kind" + querycore "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/typ" +) + +// ExpectedTableElementType returns the contextual expected type for the element +// at index in a table literal used in array-like position. +func ExpectedTableElementType(expected typ.Type, index int) typ.Type { + return expectedTableElementType(expected, index) +} + +func expectedTableElementType(expected typ.Type, index int) typ.Type { + if expected == nil { + return nil + } + + expected = typ.UnwrapAnnotated(expected) + + switch v := expected.(type) { + case *typ.Alias: + return expectedTableElementType(v.Target, index) + case *typ.Optional: + return expectedTableElementType(v.Inner, index) + case *typ.Instantiated: + if resolved, err := querycore.ResolveInstantiated(v); err == nil { + return expectedTableElementType(resolved, index) + } + return nil + case *typ.Array: + return v.Element + case *typ.Tuple: + if index < 0 || index >= len(v.Elements) { + return nil + } + return v.Elements[index] + case *typ.Map: + if v.Key == nil { + return nil + } + switch v.Key.Kind() { + case kind.Integer, kind.Number: + return v.Value + default: + return nil + } + case *typ.Record: + if !v.HasMapComponent() || v.MapKey == nil { + return nil + } + switch v.MapKey.Kind() { + case kind.Integer, kind.Number: + return v.MapValue + default: + return nil + } + case *typ.Union: + var members []typ.Type + for _, member := range v.Members { + if elem := expectedTableElementType(member, index); elem != nil { + members = append(members, elem) + } + } + if len(members) == 0 { + return nil + } + return typ.NewUnion(members...) + case *typ.Intersection: + var members []typ.Type + for _, member := range v.Members { + if elem := expectedTableElementType(member, index); elem != nil { + members = append(members, elem) + } + } + if len(members) == 0 { + return nil + } + return typ.NewIntersection(members...) + default: + return nil + } +} diff --git a/compiler/check/synth/phase/extract/call.go b/compiler/check/synth/phase/extract/call.go index c8eb4737..e5a4a904 100644 --- a/compiler/check/synth/phase/extract/call.go +++ b/compiler/check/synth/phase/extract/call.go @@ -116,11 +116,23 @@ func (s *Synthesizer) GetCallQuery() core.TypeOps { // // For method calls (obj:method()), dispatches to synthMethodCallCoreWithExpected. func (s *Synthesizer) SynthCallCore(ex *ast.FuncCallExpr, p cfg.Point, sc *scope.State, narrower api.FlowOps, recurse ExprSynth) []typ.Type { - return s.synthCallCoreWithNarrower(ex, p, sc, narrower, recurse, nil) + return s.synthCallCoreWithCaptureTypes(ex, p, sc, narrower, recurse, nil, nil) } // synthCallCoreWithNarrower synthesizes call with narrower context preserved. func (s *Synthesizer) synthCallCoreWithNarrower(ex *ast.FuncCallExpr, p cfg.Point, sc *scope.State, narrower api.FlowOps, recurse ExprSynth, expected typ.Type) []typ.Type { + return s.synthCallCoreWithCaptureTypes(ex, p, sc, narrower, recurse, expected, nil) +} + +func (s *Synthesizer) synthCallCoreWithCaptureTypes( + ex *ast.FuncCallExpr, + p cfg.Point, + sc *scope.State, + narrower api.FlowOps, + recurse ExprSynth, + expected typ.Type, + captureTypes map[cfg.SymbolID]typ.Type, +) []typ.Type { if callsite.IsMethodLikeExpr(ex) { return s.synthMethodCallCoreWithExpected(ex, p, sc, recurse, expected) } @@ -137,6 +149,9 @@ func (s *Synthesizer) synthCallCoreWithNarrower(ex *ast.FuncCallExpr, p cfg.Poin } calleeType := recurse(ex.Func) + if specialized := s.specializedLocalFunctionCalleeType(ex, p, sc, calleeType, captureTypes); specialized != nil { + calleeType = specialized + } args := synthArgs(ex.Args, recurse) typeArgs := s.resolveTypeArgs(ex.TypeArgs, sc) @@ -164,6 +179,47 @@ func (s *Synthesizer) synthCallCoreWithNarrower(ex *ast.FuncCallExpr, p cfg.Poin return intercept.ApplyOverride(returns, specOverride) } +func (s *Synthesizer) specializedLocalFunctionCalleeType( + ex *ast.FuncCallExpr, + p cfg.Point, + sc *scope.State, + current typ.Type, + captureTypes map[cfg.SymbolID]typ.Type, +) typ.Type { + if s == nil || ex == nil || s.deps.CheckCtx == nil { + return nil + } + if specialized := s.stableLocalFunctionValueType(ex.Func, p, sc, current, captureTypes); specialized != nil { + return specialized + } + graph, ok := s.deps.CheckCtx.Graph().(*compcfg.Graph) + if !ok || graph == nil { + return nil + } + bindings := graph.Bindings() + if bindings == nil { + bindings = s.deps.ModuleBindings + } + if bindings == nil { + return nil + } + info := graph.CallSiteAt(p, ex) + if info == nil { + return nil + } + for _, sym := range callsite.CallableCalleeSymbolCandidates(info, graph, bindings, nil) { + fn := callsite.FunctionLiteralForGraphSymbol(graph, sym) + if fn == nil { + continue + } + expectedFn, _ := unwrap.Optional(unwrap.Alias(current)).(*typ.Function) + if fnType := s.synthFunctionTypeWithCapturePoint(fn, sc, expectedFn, p, captureTypes); fnType != nil { + return fnType + } + } + return nil +} + // SynthCallCoreWithExpected synthesizes call with optional expected return type for generic inference. func (s *Synthesizer) SynthCallCoreWithExpected(ex *ast.FuncCallExpr, p cfg.Point, sc *scope.State, recurse ExprSynth, expected typ.Type) []typ.Type { return s.synthCallCoreWithNarrower(ex, p, sc, nil, recurse, expected) @@ -390,6 +446,9 @@ func (s *Synthesizer) specReturnOverride(fnType typ.Type, astArgs []ast.Expr, ar // unwrapCallResult converts CallResult to a slice of types. func unwrapCallResult(result ops.CallResult) []typ.Type { + if len(result.Returns) > 0 { + return CopyTypes(result.Returns) + } if tuple, ok := result.Type.(*typ.Tuple); ok { return CopyTypes(tuple.Elements) } @@ -494,17 +553,19 @@ func (s *Synthesizer) withEnvOverlay(overlay map[string]typ.Type) *Synthesizer { overlaidCtx = overlaidCtx.WithGlobalOverlay(overlay) } overlaidDeps := &Deps{ - Ctx: s.deps.Ctx, - Types: s.deps.Types, - Scopes: s.deps.Scopes, - Manifests: s.deps.Manifests, - CheckCtx: overlaidCtx, - Flow: s.deps.Flow, - Paths: s.deps.Paths, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), - ModuleBindings: s.deps.ModuleBindings, - ModuleAliases: s.deps.ModuleAliases, + Ctx: s.deps.Ctx, + Types: s.deps.Types, + Scopes: s.deps.Scopes, + Manifests: s.deps.Manifests, + CheckCtx: overlaidCtx, + Graphs: s.deps.Graphs, + Flow: s.deps.Flow, + Paths: s.deps.Paths, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + FunctionTypeInProgress: s.deps.FunctionTypeInProgress, + ModuleBindings: s.deps.ModuleBindings, + ModuleAliases: s.deps.ModuleAliases, } return NewSynthesizer(overlaidDeps, s.phase) } diff --git a/compiler/check/synth/phase/extract/deps.go b/compiler/check/synth/phase/extract/deps.go index 32201424..53ab09c0 100644 --- a/compiler/check/synth/phase/extract/deps.go +++ b/compiler/check/synth/phase/extract/deps.go @@ -1,6 +1,7 @@ package extract import ( + "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/scope" @@ -31,6 +32,7 @@ type Deps struct { DefaultScope *scope.State Manifests io.ManifestQuerier CheckCtx api.BaseEnv + Graphs api.GraphProvider Flow api.FlowOps Paths api.PathFromExprFunc @@ -38,6 +40,10 @@ type Deps struct { PreCache api.Cache NarrowCache api.Cache + // FunctionTypeInProgress guards call-point local function specialization + // against recursion across temporary synthesizers. + FunctionTypeInProgress map[functionTypeProgressKey]bool + // Module-level bindings for nested function CFG building. ModuleBindings *bind.BindingTable @@ -45,16 +51,23 @@ type Deps struct { ModuleAliases map[cfg.SymbolID]string } +type functionTypeProgressKey struct { + Func *ast.FunctionExpr + CapturePoint cfg.Point +} + // NewDeps creates a new Deps instance. func NewDeps(ctx *db.QueryContext, types core.TypeOps, scopes api.ScopeMap, manifests io.ManifestQuerier, checkCtx api.BaseEnv) *Deps { return &Deps{ - Ctx: ctx, - Types: types, - Scopes: scopes, - Manifests: manifests, - CheckCtx: checkCtx, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), + Ctx: ctx, + Types: types, + Scopes: scopes, + Manifests: manifests, + CheckCtx: checkCtx, + Graphs: api.GraphsFrom(ctx), + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + FunctionTypeInProgress: make(map[functionTypeProgressKey]bool), } } diff --git a/compiler/check/synth/phase/extract/expr.go b/compiler/check/synth/phase/extract/expr.go index 1eb3f75e..c6d54a3a 100644 --- a/compiler/check/synth/phase/extract/expr.go +++ b/compiler/check/synth/phase/extract/expr.go @@ -100,16 +100,12 @@ func (s *Synthesizer) synthAttrGetCore(ex *ast.AttrGetExpr, p cfg.Point, sc *sco if !path.IsEmpty() { narrowed := narrower.NarrowedTypeAt(p, path) if narrowed != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, narrowed, nil); specialized != nil { + return specialized + } if typ.IsUnknown(unwrap.Alias(narrowed)) && typ.IsAny(unwrap.Alias(objType)) { goto skipNarrowedAttr } - if _, isStringKey := ex.Key.(*ast.StringExpr); isStringKey { - if mapValueType(objType) != nil { - if opt, ok := narrowed.(*typ.Optional); ok { - return opt.Inner - } - } - } return narrowed } } @@ -130,22 +126,34 @@ skipNarrowedAttr: case *ast.StringExpr: if ft, ok := s.deps.Types.Field(s.deps.Ctx, objType, key.Value); ok { if manifestPath != "" { - return enrichWithManifest(s.deps.Manifests, ft, manifestPath, key.Value) + ft = enrichWithManifest(s.deps.Manifests, ft, manifestPath, key.Value) + } + if specialized := s.stableLocalFunctionValueType(ex, p, sc, ft, nil); specialized != nil { + return specialized } return ft } if ft := fieldOnPartialUnion(objType, key.Value, s.deps.Types, s.deps.Ctx); ft != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, ft, nil); specialized != nil { + return specialized + } return ft } if vt := mapValueType(objType); vt != nil { return vt } if it, ok := s.deps.Types.Index(s.deps.Ctx, objType, typ.LiteralString(key.Value)); ok { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, it, nil); specialized != nil { + return specialized + } return it } case *ast.NumberExpr: keyType := ops.ParseNumber(key.Value) if it, ok := s.deps.Types.Index(s.deps.Ctx, objType, keyType); ok { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, it, nil); specialized != nil { + return specialized + } return it } case *ast.IdentExpr: @@ -194,6 +202,9 @@ skipNarrowedAttr: } } } + if specialized := s.stableLocalFunctionValueType(ex, p, sc, it, nil); specialized != nil { + return specialized + } return it } if derived := s.indexFromKeyOf(objType, ex.Object, key, p, sc, narrower); derived != nil { @@ -209,6 +220,9 @@ skipNarrowedAttr: } } } + if specialized := s.stableLocalFunctionValueType(ex, p, sc, it, nil); specialized != nil { + return specialized + } return it } } @@ -556,6 +570,13 @@ func (s *Synthesizer) synthExprWithSpec(expr ast.Expr, p cfg.Point, specTypes ap if expr == nil { return typ.Nil } + if call, ok := expr.(*ast.FuncCallExpr); ok { + multi := s.synthMultiWithSpec(call, p, specTypes) + if len(multi) == 0 || multi[0] == nil { + return typ.Unknown + } + return multi[0] + } if ident, ok := expr.(*ast.IdentExpr); ok { if sym := s.LookupSymbol(ident); sym != 0 { if t, exists := specTypes[sym]; exists { @@ -583,7 +604,7 @@ func (s *Synthesizer) synthMultiWithSpec(expr ast.Expr, p cfg.Point, specTypes a } } } - return s.SynthCallCore(call, p, sc, nil, recurse) + return s.synthCallCoreWithCaptureTypes(call, p, sc, nil, recurse, nil, specTypes) }, ) } diff --git a/compiler/check/synth/phase/extract/function.go b/compiler/check/synth/phase/extract/function.go index ba191dc9..ed3f9293 100644 --- a/compiler/check/synth/phase/extract/function.go +++ b/compiler/check/synth/phase/extract/function.go @@ -40,8 +40,11 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/erreffect" + "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" + "github.com/wippyai/go-lua/compiler/check/overlaymut" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth/phase/core" + "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/contract" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/typ" @@ -72,9 +75,48 @@ func (s *Synthesizer) FunctionType(fn *ast.FunctionExpr, sc *scope.State) *typ.F // // If fn is nil, returns nil. If scope is nil, returns an empty function type. func (s *Synthesizer) SynthFunctionTypeWithExpected(fn *ast.FunctionExpr, sc *scope.State, expected *typ.Function) *typ.Function { + return s.synthFunctionTypeWithCapturePoint(fn, sc, expected, 0, nil) +} + +func (s *Synthesizer) getOrBuildFunctionGraph(fn *ast.FunctionExpr) *cfg.Graph { + if fn == nil { + return nil + } + if s.deps.CheckCtx != nil { + if g, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && g != nil && g.Func() == fn { + return g + } + } + if s.deps.Graphs != nil { + if g := s.deps.Graphs.GetOrBuildCFG(fn); g != nil { + return g + } + } + if s.deps.ModuleBindings != nil { + return cfg.BuildWithBindings(fn, s.deps.ModuleBindings) + } + return cfg.Build(fn) +} + +func (s *Synthesizer) synthFunctionTypeWithCapturePoint( + fn *ast.FunctionExpr, + sc *scope.State, + expected *typ.Function, + capturePoint cfg.Point, + captureTypes map[cfg.SymbolID]typ.Type, +) *typ.Function { if fn == nil { return nil } + if s.deps.FunctionTypeInProgress == nil { + s.deps.FunctionTypeInProgress = make(map[functionTypeProgressKey]bool) + } + progressKey := functionTypeProgressKey{Func: fn, CapturePoint: capturePoint} + if s.deps.FunctionTypeInProgress[progressKey] { + return s.buildFunctionTypeSummaryFallback(fn, sc, expected) + } + s.deps.FunctionTypeInProgress[progressKey] = true + defer delete(s.deps.FunctionTypeInProgress, progressKey) builder := typ.Func() @@ -117,17 +159,8 @@ func (s *Synthesizer) SynthFunctionTypeWithExpected(fn *ast.FunctionExpr, sc *sc // Build CFG once, shared between overlay inference and return inference. var fnGraph *cfg.Graph - if s.deps.CheckCtx != nil { - if g, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && g != nil && g.Func() == fn { - fnGraph = g - } - } - if fnGraph == nil && fn.Stmts != nil && len(fn.Stmts) > 0 { - if s.deps.ModuleBindings != nil { - fnGraph = cfg.BuildWithBindings(fn, s.deps.ModuleBindings) - } else { - fnGraph = cfg.Build(fn) - } + if fn.Stmts != nil && len(fn.Stmts) > 0 { + fnGraph = s.getOrBuildFunctionGraph(fn) } // Infer callback env overlays (runs before return types). @@ -140,7 +173,7 @@ func (s *Synthesizer) SynthFunctionTypeWithExpected(fn *ast.FunctionExpr, sc *sc returns := s.ResolveReturnTypes(fn.ReturnTypes, resolveScope) builder = builder.Returns(returns...) } else { - if bodyReturns, hasErrorReturn := s.inferReturnTypesFromBody(fn, resolveScope, expected, fnGraph); len(bodyReturns) > 0 { + if bodyReturns, hasErrorReturn := s.inferReturnTypesFromBody(fn, resolveScope, expected, fnGraph, capturePoint, captureTypes); len(bodyReturns) > 0 { inferredErrorReturn = hasErrorReturn if expected != nil && len(expected.Returns) > 0 { if typ.IsUnknownOnlyOrEmpty(bodyReturns) { @@ -167,6 +200,8 @@ func (s *Synthesizer) inferReturnTypesFromBody( parentScope *scope.State, expected *typ.Function, fnGraph *cfg.Graph, + capturePoint cfg.Point, + captureTypes map[cfg.SymbolID]typ.Type, ) ([]typ.Type, bool) { if len(fn.Stmts) == 0 { return nil, false @@ -200,7 +235,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( if rt := returnSummaries[fnSym]; len(rt) > 0 { if typ.HasKnownType(rt) { summaryFallback = rt - if !s.IsNarrowing() { + if !s.IsNarrowing() && capturePoint == 0 && len(captureTypes) == 0 { return rt, false } } @@ -208,18 +243,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( } if fnGraph == nil { - if s.deps.CheckCtx != nil { - if g, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && g != nil && g.Func() == fn { - fnGraph = g - } - } - } - if fnGraph == nil { - if s.deps.ModuleBindings != nil { - fnGraph = cfg.BuildWithBindings(fn, s.deps.ModuleBindings) - } else { - fnGraph = cfg.Build(fn) - } + fnGraph = s.getOrBuildFunctionGraph(fn) } if fnGraph == nil { return nil, false @@ -265,9 +289,11 @@ func (s *Synthesizer) inferReturnTypesFromBody( // This allows nested local functions to call sibling locals defined in the parent scope. if s.deps.CheckCtx != nil { if types := s.deps.CheckCtx.Types(); types != nil { - p := cfg.Point(0) + p := capturePoint if g := s.deps.CheckCtx.Graph(); g != nil { - p = g.Entry() + if p == 0 { + p = g.Entry() + } } if bindings := fnGraph.Bindings(); bindings != nil { for _, sym := range bindings.CapturedSymbols(fn) { @@ -277,7 +303,17 @@ func (s *Synthesizer) inferReturnTypesFromBody( if _, ok := overlay[sym]; ok { continue } - if tv := types.DeclaredAt(p, sym); tv.State == flow.StateResolved && tv.Type != nil { + if t := captureTypes[sym]; t != nil { + overlay[sym] = t + continue + } + if solution := s.deps.CheckCtx.Consts(); solution != nil { + if t := solution.TypeAt(p, constraint.Path{Symbol: sym}); t != nil { + overlay[sym] = t + continue + } + } + if tv := types.EffectiveTypeAt(p, sym); tv.State == flow.StateResolved && tv.Type != nil { overlay[sym] = tv.Type } } @@ -367,16 +403,18 @@ func (s *Synthesizer) inferReturnTypesFromBody( }) prelimDeps := &Deps{ - Ctx: s.deps.Ctx, - Types: s.deps.Types, - DefaultScope: resolveScope, - Manifests: s.deps.Manifests, - CheckCtx: prelimCtx, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), - ModuleBindings: s.deps.ModuleBindings, - ModuleAliases: moduleAliases, - Paths: s.deps.Paths, + Ctx: s.deps.Ctx, + Types: s.deps.Types, + DefaultScope: resolveScope, + Manifests: s.deps.Manifests, + CheckCtx: prelimCtx, + Graphs: s.deps.Graphs, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + FunctionTypeInProgress: s.deps.FunctionTypeInProgress, + ModuleBindings: s.deps.ModuleBindings, + ModuleAliases: moduleAliases, + Paths: s.deps.Paths, } prelimSynth = NewSynthesizer(prelimDeps, s.phase) return prelimSynth @@ -463,6 +501,37 @@ func (s *Synthesizer) inferReturnTypesFromBody( } } + // Apply the same direct mutation enrichment law used by return inference: + // returned locals must reflect visible field/index/direct container writes + // before return expressions are synthesized. + mutationBindings := fnGraph.Bindings() + if mutationBindings == nil { + mutationBindings = s.deps.ModuleBindings + } + if mutationBindings != nil { + enrichedSynth := func(expr ast.Expr, p cfg.Point) typ.Type { + if ident, ok := expr.(*ast.IdentExpr); ok { + if sym, found := mutationBindings.SymbolOf(ident); found && sym != 0 { + if t := overlay[sym]; t != nil { + return t + } + } + } + return ensurePrelimSynth().SynthExpr(expr, p, nil) + } + + fieldAssignments := overlaymut.CollectFieldAssignments(fnGraph, enrichedSynth, nil) + overlaymut.ApplyFieldMergeToOverlay(overlay, fieldAssignments) + + indexerAssignments := overlaymut.CollectIndexerAssignments(fnGraph, enrichedSynth, mutationBindings, nil) + tableMutations := mutator.CollectTableInsertMutations(fnGraph, enrichedSynth, mutationBindings) + mutator.MergeIndexerMutations(indexerAssignments, tableMutations) + overlaymut.ApplyIndexerMergeToOverlay(overlay, indexerAssignments) + + directMutations := mutator.CollectTableInsertOnDirect(fnGraph, enrichedSynth, mutationBindings) + overlaymut.ApplyDirectMutationsToOverlay(overlay, directMutations) + } + // Phase 2: build final context with enriched overlay for return inference. fnCheckCtx := api.NewReturnInferenceEnv(api.ReturnInferenceEnvConfig{ Graph: fnGraph, @@ -474,19 +543,21 @@ func (s *Synthesizer) inferReturnTypesFromBody( }) tempDeps := &Deps{ - Ctx: s.deps.Ctx, - Types: s.deps.Types, - DefaultScope: resolveScope, - Manifests: s.deps.Manifests, - CheckCtx: fnCheckCtx, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), - ModuleBindings: s.deps.ModuleBindings, - ModuleAliases: moduleAliases, - Paths: s.deps.Paths, - } - if s.IsNarrowing() && s.deps.Flow != nil { - if fnCheckCtx != nil && fnCheckCtx.Graph() == fnGraph { + Ctx: s.deps.Ctx, + Types: s.deps.Types, + DefaultScope: resolveScope, + Manifests: s.deps.Manifests, + CheckCtx: fnCheckCtx, + Graphs: s.deps.Graphs, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + FunctionTypeInProgress: s.deps.FunctionTypeInProgress, + ModuleBindings: s.deps.ModuleBindings, + ModuleAliases: moduleAliases, + Paths: s.deps.Paths, + } + if s.IsNarrowing() && s.deps.Flow != nil && s.deps.CheckCtx != nil { + if currentGraph, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && currentGraph == fnGraph { tempDeps.Flow = s.deps.Flow } } @@ -711,6 +782,43 @@ func (s *Synthesizer) buildFunctionTypeWithSummary( return join.WithReturnsOrUnknown(sig, returnTypes) } +func (s *Synthesizer) buildFunctionTypeSummaryFallback( + fn *ast.FunctionExpr, + sc *scope.State, + expected *typ.Function, +) *typ.Function { + if fn == nil { + return nil + } + sig := s.ResolveFunctionSignature(fn, sc) + if sig == nil { + return nil + } + if expected != nil && len(sig.Returns) == 0 && len(expected.Returns) > 0 { + sig = join.WithReturns(sig, expected.Returns) + } + var summaries map[cfg.SymbolID][]typ.Type + if s.deps.CheckCtx != nil { + if s.IsNarrowing() { + if ctx, ok := s.deps.CheckCtx.(api.NarrowEnv); ok { + summaries = ctx.NarrowReturnSummaries() + } + } else if ctx, ok := s.deps.CheckCtx.(api.DeclaredEnv); ok { + summaries = ctx.ReturnSummaries() + } + } + var fnSym cfg.SymbolID + if s.deps.CheckCtx != nil { + if pg, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && pg != nil { + fnSym = localFunctionSymbol(pg, fn) + } + } + if fnSym != 0 { + return join.WithReturnsOrUnknown(sig, summaries[fnSym]) + } + return join.WithReturnsOrUnknown(sig, nil) +} + func (s *Synthesizer) buildParamOverlay(fnGraph *cfg.Graph, sc *scope.State, expected *typ.Function) map[cfg.SymbolID]typ.Type { paramSlots := fnGraph.ParamSlotsReadOnly() overlay := make(map[cfg.SymbolID]typ.Type, len(paramSlots)) @@ -786,6 +894,7 @@ func (s *Synthesizer) inferCallbackOverlaySpec( DefaultScope: sc, Manifests: s.deps.Manifests, CheckCtx: fnCheckCtx, + Graphs: s.deps.Graphs, PreCache: make(api.Cache), NarrowCache: make(api.Cache), ModuleBindings: s.deps.ModuleBindings, diff --git a/compiler/check/synth/phase/extract/function_test.go b/compiler/check/synth/phase/extract/function_test.go index 19904a78..2131e4a5 100644 --- a/compiler/check/synth/phase/extract/function_test.go +++ b/compiler/check/synth/phase/extract/function_test.go @@ -4,10 +4,23 @@ import ( "testing" "github.com/wippyai/go-lua/compiler/ast" + ccfg "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/scope" + "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/typ" ) +type countingGraphProvider struct { + graph *ccfg.Graph + calls int +} + +func (p *countingGraphProvider) GetOrBuildCFG(*ast.FunctionExpr) *ccfg.Graph { + p.calls++ + return p.graph +} + func TestSynthFunctionType_Nil(t *testing.T) { s := newTestSynthesizer() result := s.FunctionType(nil, nil) @@ -178,6 +191,33 @@ func TestSynthFunctionTypeWithExpected_VariadicInference(t *testing.T) { } } +func TestSynthFunctionType_UsesAttachedGraphProvider(t *testing.T) { + fn := &ast.FunctionExpr{ + Stmts: []ast.Stmt{ + &ast.ReturnStmt{Exprs: []ast.Expr{&ast.StringExpr{Value: "ok"}}}, + }, + } + provider := &countingGraphProvider{graph: ccfg.Build(fn)} + ctx := db.NewQueryContext(db.New()) + api.AttachGraphs(ctx, provider) + + s := NewSynthesizer(&Deps{ + Ctx: ctx, + Types: mockTypeQuerier{}, + Scopes: make(api.ScopeMap), + Graphs: provider, + PreCache: make(api.Cache), + }, api.PhaseTypeResolution) + + result := s.FunctionType(fn, scope.New()) + if result == nil { + t.Fatal("expected non-nil function") + } + if provider.calls == 0 { + t.Fatal("expected function synthesis to use attached graph provider") + } +} + func TestSynthFunctionType_TypeParams(t *testing.T) { s := newTestSynthesizer() sc := scope.New() @@ -297,7 +337,7 @@ func TestReturnInference_ArityUnion_ErrorObject(t *testing.T) { }, } - result, _ := s.inferReturnTypesFromBody(fn, sc, nil, nil) + result, _ := s.inferReturnTypesFromBody(fn, sc, nil, nil, 0, nil) if len(result) != 2 { t.Fatalf("got %d return types, want 2", len(result)) } @@ -353,7 +393,7 @@ func TestReturnInference_LastExprExpands(t *testing.T) { }, } - result, _ := s.inferReturnTypesFromBody(fn, sc, nil, nil) + result, _ := s.inferReturnTypesFromBody(fn, sc, nil, nil, 0, nil) if len(result) != 2 { t.Fatalf("got %d return types, want 2", len(result)) } @@ -395,7 +435,7 @@ func TestReturnInference_ZeroReturn(t *testing.T) { }, } - result, _ := s.inferReturnTypesFromBody(fn, sc, nil, nil) + result, _ := s.inferReturnTypesFromBody(fn, sc, nil, nil, 0, nil) if len(result) != 1 { t.Fatalf("got %d return types, want 1", len(result)) } diff --git a/compiler/check/synth/phase/extract/named_function.go b/compiler/check/synth/phase/extract/named_function.go index 31b82b4f..2721b2f9 100644 --- a/compiler/check/synth/phase/extract/named_function.go +++ b/compiler/check/synth/phase/extract/named_function.go @@ -3,7 +3,15 @@ package extract import ( "github.com/wippyai/go-lua/compiler/ast" compcfg "github.com/wippyai/go-lua/compiler/cfg" + cfganalysis "github.com/wippyai/go-lua/compiler/cfg/analysis" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" + "github.com/wippyai/go-lua/compiler/check/scope" + "github.com/wippyai/go-lua/types/cfg" + "github.com/wippyai/go-lua/types/flow" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" ) // functionLiteralForIdent resolves an identifier to its underlying function @@ -54,3 +62,219 @@ func (s *Synthesizer) functionLiteralForIdent(ident *ast.IdentExpr) *ast.Functio return nil } + +// graphLocalFunctionLiteralForExpr resolves an expression to a graph-local stable +// function literal when one exists. +// +// Canonical boundary: +// - include alias-expanded graph-local function definitions and local identifier +// assignments of function literals +// - exclude mutable field-path symbols, which must continue to read their +// current callable type from value flow +func (s *Synthesizer) graphLocalFunctionForExpr(expr ast.Expr) (compcfg.SymbolID, *ast.FunctionExpr, bool) { + if expr == nil || s == nil || s.deps.CheckCtx == nil { + return 0, nil, false + } + + graph, ok := s.deps.CheckCtx.Graph().(*compcfg.Graph) + if !ok || graph == nil { + return 0, nil, false + } + + bindings := graph.Bindings() + if bindings == nil { + bindings = s.deps.ModuleBindings + } + moduleBindings := s.deps.ModuleBindings + + hasGraphLocalLiteral := func(sym compcfg.SymbolID) bool { + return callsite.FunctionLiteralForGraphSymbol(graph, sym) != nil + } + + raw := callsite.SymbolFromExpr(expr, bindings) + if raw == 0 && moduleBindings != nil && moduleBindings != bindings { + raw = callsite.SymbolFromExpr(expr, moduleBindings) + } + + sym := callsite.CanonicalSymbolFromExprWithAliases( + expr, + raw, + graph, + bindings, + moduleBindings, + hasGraphLocalLiteral, + ) + if sym == 0 { + return 0, nil, false + } + + fn := callsite.FunctionLiteralForGraphSymbol(graph, sym) + if fn == nil { + return 0, nil, false + } + + captureBindings := bindings + if captureBindings == nil { + captureBindings = moduleBindings + } + hasCaptures := captureBindings != nil && len(captureBindings.CapturedSymbols(fn)) > 0 + + return sym, fn, hasCaptures +} + +func (s *Synthesizer) graphLocalFunctionLiteralForExpr(expr ast.Expr) *ast.FunctionExpr { + _, fn, _ := s.graphLocalFunctionForExpr(expr) + return fn +} + +func (s *Synthesizer) hasDominatingDirectFunctionRebind(sym compcfg.SymbolID, stableFn *ast.FunctionExpr, p cfg.Point) bool { + if s == nil || sym == 0 || stableFn == nil || s.deps == nil || s.deps.CheckCtx == nil { + return false + } + + graph, ok := s.deps.CheckCtx.Graph().(*compcfg.Graph) + if !ok || graph == nil { + return false + } + + idom, _ := cfganalysis.ComputeDominators(graph.CFG()) + rebound := false + + graph.EachAssign(func(assignPoint cfg.Point, info *compcfg.AssignInfo) { + if rebound || info == nil || assignPoint == p || !cfganalysis.StrictlyDominates(idom, assignPoint, p) { + return + } + + info.EachTarget(func(_ int, target compcfg.AssignTarget) { + if rebound || target.Symbol != sym { + return + } + if target.Kind == compcfg.TargetField || target.Kind == compcfg.TargetIndex { + rebound = true + } + }) + }) + + if rebound { + return true + } + + graph.EachFuncDef(func(defPoint cfg.Point, info *compcfg.FuncDefInfo) { + if rebound || info == nil || info.Symbol != sym || info.FuncExpr == nil || info.FuncExpr == stableFn { + return + } + if !cfganalysis.StrictlyDominates(idom, defPoint, p) { + return + } + if info.TargetKind == compcfg.FuncDefField || info.TargetKind == compcfg.FuncDefGlobal { + rebound = true + } + }) + + return rebound +} + +func (s *Synthesizer) expectedGraphLocalFunctionValueType( + expr ast.Expr, + p cfg.Point, + sc *scope.State, + expected *typ.Function, + captureTypes map[cfg.SymbolID]typ.Type, +) typ.Type { + if s == nil || expected == nil { + return nil + } + + sym, fn, _ := s.graphLocalFunctionForExpr(expr) + if fn == nil { + return nil + } + if s.hasDominatingDirectFunctionRebind(sym, fn, p) { + return nil + } + + return s.synthFunctionTypeWithCapturePoint(fn, sc, expected, p, captureTypes) +} + +func (s *Synthesizer) stableGraphLocalFunctionSnapshotType(sym compcfg.SymbolID) typ.Type { + if s == nil || sym == 0 || s.deps == nil || s.deps.Ctx == nil || s.deps.CheckCtx == nil { + return nil + } + + store := api.StoreFrom(s.deps.Ctx) + if store == nil { + return nil + } + + graph, ok := s.deps.CheckCtx.Graph().(*compcfg.Graph) + if !ok || graph == nil { + return nil + } + + fallbackParent := s.deps.DefaultScope + if fallbackParent == nil { + fallbackParent = s.deps.CheckCtx.TypeNames() + } + parent := api.ParentScopeForGraph(store, graph.ID(), fallbackParent) + if parent == nil { + return nil + } + + var fnTypes map[cfg.SymbolID]typ.Type + load := func() { + fnTypes = store.GetLocalFuncTypesSnapshot(graph, parent) + } + if phaser, ok := store.(interface{ WithPhase(api.Phase, func()) }); ok { + phaser.WithPhase(api.PhaseScopeCompute, load) + } else { + load() + } + if len(fnTypes) == 0 { + return nil + } + + return fnTypes[sym] +} + +func (s *Synthesizer) stableLocalFunctionValueType( + expr ast.Expr, + p cfg.Point, + sc *scope.State, + current typ.Type, + captureTypes map[cfg.SymbolID]typ.Type, +) typ.Type { + sym, fn, hasCaptures := s.graphLocalFunctionForExpr(expr) + if fn == nil { + return nil + } + + authoritative := current + if s.deps != nil && s.deps.CheckCtx != nil { + if types := s.deps.CheckCtx.Types(); types != nil { + if tv := types.EffectiveTypeAt(p, sym); tv.State == flow.StateResolved && tv.Type != nil { + authoritative = tv.Type + } + } + } + if snapshot := s.stableGraphLocalFunctionSnapshotType(sym); snapshot != nil { + if authoritative == nil || subtype.IsSubtype(snapshot, authoritative) { + authoritative = snapshot + } + } + if !hasCaptures && authoritative != nil { + return authoritative + } + + expectedFn, _ := unwrap.Optional(unwrap.Alias(authoritative)).(*typ.Function) + specialized := s.synthFunctionTypeWithCapturePoint(fn, sc, expectedFn, p, captureTypes) + if authoritative != nil && specialized != nil { + if subtype.IsSubtype(specialized, authoritative) { + return specialized + } + return authoritative + } + if authoritative != nil { + return authoritative + } + return specialized +} diff --git a/compiler/check/synth/phase/extract/named_function_test.go b/compiler/check/synth/phase/extract/named_function_test.go index f897c05b..2a25e35d 100644 --- a/compiler/check/synth/phase/extract/named_function_test.go +++ b/compiler/check/synth/phase/extract/named_function_test.go @@ -122,3 +122,249 @@ func TestFunctionLiteralForIdent_ResolvesAliasChainLiteral(t *testing.T) { t.Fatalf("functionLiteralForIdent(alias chain) = %v, want %v", got, want) } } + +func TestGraphLocalFunctionLiteralForExpr_ResolvesFieldDefinitionAttr(t *testing.T) { + stmts, err := parse.ParseString(` + local M = {} + function M.run() + return 1 + end + local f = M.run + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + fn := &ast.FunctionExpr{Stmts: stmts} + localBindings := bind.Bind(fn, nil) + graph := ccfg.BuildWithBindings(fn, localBindings) + checkCtx := api.NewDeclaredEnv(api.DeclaredEnvConfig{ + Graph: graph, + Bindings: localBindings, + }) + synth := NewSynthesizer(&Deps{ + CheckCtx: checkCtx, + ModuleBindings: localBindings, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + }, api.PhaseTypeResolution) + + var ( + attr *ast.AttrGetExpr + want *ast.FunctionExpr + ) + graph.EachAssign(func(_ ccfg.Point, info *ccfg.AssignInfo) { + if attr != nil || info == nil { + return + } + info.EachTargetSource(func(_ int, _ ccfg.AssignTarget, source ast.Expr) { + if attr != nil { + return + } + if candidate, ok := source.(*ast.AttrGetExpr); ok { + attr = candidate + } + }) + }) + graph.EachFuncDef(func(_ ccfg.Point, info *ccfg.FuncDefInfo) { + if want != nil || info == nil || info.Name != "run" { + return + } + want = info.FuncExpr + }) + if attr == nil || want == nil { + t.Fatalf("expected alias assignment attr and run func def, got attr=%v want=%v", attr, want) + } + + got := synth.graphLocalFunctionLiteralForExpr(attr) + if got != want { + t.Fatalf("graphLocalFunctionLiteralForExpr(M.run) = %v, want %v", got, want) + } +} + +func TestGraphLocalFunctionLiteralForExpr_IgnoresMutableFieldPathAttr(t *testing.T) { + stmts, err := parse.ParseString(` + local M = { + dep = { + get = function() + return nil + end, + }, + } + M.dep = { + get = function() + return 1 + end, + } + local f = M.dep.get + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + fn := &ast.FunctionExpr{Stmts: stmts} + localBindings := bind.Bind(fn, nil) + graph := ccfg.BuildWithBindings(fn, localBindings) + checkCtx := api.NewDeclaredEnv(api.DeclaredEnvConfig{ + Graph: graph, + Bindings: localBindings, + }) + synth := NewSynthesizer(&Deps{ + CheckCtx: checkCtx, + ModuleBindings: localBindings, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + }, api.PhaseTypeResolution) + + var attr *ast.AttrGetExpr + graph.EachAssign(func(_ ccfg.Point, info *ccfg.AssignInfo) { + if attr != nil || info == nil { + return + } + info.EachTargetSource(func(_ int, _ ccfg.AssignTarget, source ast.Expr) { + if attr != nil { + return + } + if candidate, ok := source.(*ast.AttrGetExpr); ok { + attr = candidate + } + }) + }) + if attr == nil { + t.Fatal("expected alias assignment attr source") + } + + if got := synth.graphLocalFunctionLiteralForExpr(attr); got != nil { + t.Fatalf("graphLocalFunctionLiteralForExpr(M.dep.get) = %v, want nil", got) + } +} + +func TestHasDominatingDirectFunctionRebind_FalseWhenOnlyCapturedFieldChanges(t *testing.T) { + stmts, err := parse.ParseString(` + local M = { + dep = { + get = function() + return nil + end, + }, + } + function M.run() + return M.dep.get() + end + M.dep = { + get = function() + return { answer = "ok" } + end, + } + local f = M.run + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + fn := &ast.FunctionExpr{Stmts: stmts} + localBindings := bind.Bind(fn, nil) + graph := ccfg.BuildWithBindings(fn, localBindings) + checkCtx := api.NewDeclaredEnv(api.DeclaredEnvConfig{ + Graph: graph, + Bindings: localBindings, + }) + synth := NewSynthesizer(&Deps{ + CheckCtx: checkCtx, + ModuleBindings: localBindings, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + }, api.PhaseTypeResolution) + + var ( + attr *ast.AttrGetExpr + at ccfg.Point + ) + graph.EachAssign(func(p ccfg.Point, info *ccfg.AssignInfo) { + if attr != nil || info == nil { + return + } + info.EachTargetSource(func(_ int, _ ccfg.AssignTarget, source ast.Expr) { + if attr != nil { + return + } + if candidate, ok := source.(*ast.AttrGetExpr); ok { + attr = candidate + at = p + } + }) + }) + if attr == nil { + t.Fatal("expected alias assignment attr source") + } + + sym, stableFn, _ := synth.graphLocalFunctionForExpr(attr) + if stableFn == nil || sym == 0 { + t.Fatalf("expected stable graph-local function for attr, got sym=%d fn=%v", sym, stableFn) + } + if synth.hasDominatingDirectFunctionRebind(sym, stableFn, at) { + t.Fatal("captured field mutation should not invalidate field-defined wrapper value") + } +} + +func TestHasDominatingDirectFunctionRebind_TrueWhenFieldIsReassigned(t *testing.T) { + stmts, err := parse.ParseString(` + local M = { + dep = { + get = function() + return nil + end, + }, + } + function M.run() + return M.dep.get() + end + M.run = function() + return nil + end + local f = M.run + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + fn := &ast.FunctionExpr{Stmts: stmts} + localBindings := bind.Bind(fn, nil) + graph := ccfg.BuildWithBindings(fn, localBindings) + checkCtx := api.NewDeclaredEnv(api.DeclaredEnvConfig{ + Graph: graph, + Bindings: localBindings, + }) + synth := NewSynthesizer(&Deps{ + CheckCtx: checkCtx, + ModuleBindings: localBindings, + PreCache: make(api.Cache), + NarrowCache: make(api.Cache), + }, api.PhaseTypeResolution) + + var ( + attr *ast.AttrGetExpr + at ccfg.Point + ) + graph.EachAssign(func(p ccfg.Point, info *ccfg.AssignInfo) { + if attr != nil || info == nil { + return + } + info.EachTargetSource(func(_ int, _ ccfg.AssignTarget, source ast.Expr) { + if attr != nil { + return + } + if candidate, ok := source.(*ast.AttrGetExpr); ok { + attr = candidate + at = p + } + }) + }) + if attr == nil { + t.Fatal("expected alias assignment attr source") + } + + sym, stableFn, _ := synth.graphLocalFunctionForExpr(attr) + if stableFn == nil || sym == 0 { + t.Fatalf("expected stable graph-local function for attr, got sym=%d fn=%v", sym, stableFn) + } + if !synth.hasDominatingDirectFunctionRebind(sym, stableFn, at) { + t.Fatal("direct dominating field reassignment should invalidate field-defined wrapper value") + } +} diff --git a/compiler/check/synth/phase/extract/synthesizer.go b/compiler/check/synth/phase/extract/synthesizer.go index 79dedf88..f745063d 100644 --- a/compiler/check/synth/phase/extract/synthesizer.go +++ b/compiler/check/synth/phase/extract/synthesizer.go @@ -191,7 +191,9 @@ func (s *Synthesizer) Resolver() *resolve.Resolver { ExprSynth: func(expr ast.Expr, p cfg.Point) typ.Type { return s.SynthExpr(expr, p, nil) }, - Bindings: s.deps.ModuleBindings, + Bindings: s.deps.ModuleBindings, + ModuleBindings: s.deps.ModuleBindings, + ModuleAliases: s.deps.ModuleAliases, }) } @@ -357,6 +359,9 @@ func (s *Synthesizer) synthIdentCore(ex *ast.IdentExpr, p cfg.Point, sc *scope.S if s.IsNarrowing() && narrower != nil { path := constraint.Path{Root: ex.Value, Symbol: sym} if narrowed := narrower.NarrowedTypeAt(p, path); narrowed != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, narrowed, nil); specialized != nil { + return specialized + } // Guard against unsound narrowing for annotated symbols by ensuring // the narrowed type remains a subtype of the declared type. Function // signatures from declared overlays are also authoritative and should @@ -384,6 +389,9 @@ fallback: if types := ctx.Types(); types != nil { tv := types.EffectiveTypeAt(p, sym) if tv.State == flow.StateResolved && tv.Type != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, tv.Type, nil); specialized != nil { + return specialized + } // Prefer concrete resolved types over module aliases. // Allow module aliases to override unknown/any placeholders. if tv.Type.Kind().IsPlaceholder() { @@ -403,6 +411,9 @@ fallback: if moduleSym != 0 && moduleSym != sym { moduleTV := types.EffectiveTypeAt(p, moduleSym) if moduleTV.State == flow.StateResolved && moduleTV.Type != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, moduleTV.Type, nil); specialized != nil { + return specialized + } if moduleTV.Type.Kind().IsPlaceholder() { // keep looking for better sources } else { @@ -426,11 +437,17 @@ fallback: if types := ctx.Types(); types != nil { tv := types.EffectiveTypeAt(p, sym) if tv.State == flow.StateResolved && tv.Type != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, tv.Type, nil); specialized != nil { + return specialized + } return tv.Type } if moduleSym != 0 && moduleSym != sym { moduleTV := types.EffectiveTypeAt(p, moduleSym) if moduleTV.State == flow.StateResolved && moduleTV.Type != nil { + if specialized := s.stableLocalFunctionValueType(ex, p, sc, moduleTV.Type, nil); specialized != nil { + return specialized + } return moduleTV.Type } } diff --git a/compiler/check/synth/phase/extract/table.go b/compiler/check/synth/phase/extract/table.go index ae0bb1b4..a07f8c03 100644 --- a/compiler/check/synth/phase/extract/table.go +++ b/compiler/check/synth/phase/extract/table.go @@ -3,6 +3,7 @@ package extract import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/check/scope" + "github.com/wippyai/go-lua/compiler/check/synth/ops" phasecore "github.com/wippyai/go-lua/compiler/check/synth/phase/core" querycore "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/typ" @@ -63,14 +64,22 @@ func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, if ft == nil { ft = typ.Unknown } - selfBuilder.Field(k.Value, ft) + if inner, optional := typ.SplitNilableFieldType(ft); optional { + selfBuilder.OptField(k.Value, inner) + } else { + selfBuilder.Field(k.Value, ft) + } fieldCount++ case *ast.IdentExpr: ft := recurse(field.Value) if ft == nil { ft = typ.Unknown } - selfBuilder.Field(k.Value, ft) + if inner, optional := typ.SplitNilableFieldType(ft); optional { + selfBuilder.OptField(k.Value, inner) + } else { + selfBuilder.Field(k.Value, ft) + } fieldCount++ } } @@ -80,6 +89,7 @@ func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, } builder := typ.NewRecord() + var fieldDefs []ops.FieldDef var arrayElements []typ.Type hasVararg := false fieldCount := 0 @@ -89,7 +99,8 @@ func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, if _, ok := field.Value.(*ast.Comma3Expr); ok { hasVararg = true } - elemType := recurse(field.Value) + elemExpected := ops.ExpectedTableElementType(expected, len(arrayElements)) + elemType := s.synthFieldValueWithExpected(field.Value, sc, recurse, elemExpected, selfType) if elemType == nil { elemType = typ.Unknown } @@ -103,17 +114,28 @@ func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, if ft == nil { ft = typ.Unknown } - builder.Field(k.Value, ft) + fieldDefs = append(fieldDefs, ops.FieldDef{Name: k.Value, Type: ft}) + if inner, optional := typ.SplitNilableFieldType(ft); optional { + builder.OptField(k.Value, inner) + } else { + builder.Field(k.Value, ft) + } fieldCount++ case *ast.IdentExpr: ft := s.synthFieldValueWithExpected(field.Value, sc, recurse, expectedFields[k.Value], selfType) if ft == nil { ft = typ.Unknown } - builder.Field(k.Value, ft) + fieldDefs = append(fieldDefs, ops.FieldDef{Name: k.Value, Type: ft}) + if inner, optional := typ.SplitNilableFieldType(ft); optional { + builder.OptField(k.Value, inner) + } else { + builder.Field(k.Value, ft) + } fieldCount++ case *ast.NumberExpr: - elemType := recurse(field.Value) + elemExpected := ops.ExpectedTableElementType(expected, len(arrayElements)) + elemType := s.synthFieldValueWithExpected(field.Value, sc, recurse, elemExpected, selfType) if elemType == nil { elemType = typ.Unknown } @@ -122,16 +144,25 @@ func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, } if len(arrayElements) > 0 && fieldCount == 0 { + var result typ.Type if hasVararg { - return typ.NewArray(typ.NewUnion(arrayElements...)) + result = typ.NewArray(typ.NewUnion(arrayElements...)) + } else if querycore.IsArrayLike(expected) { + result = typ.NewArray(typ.NewUnion(arrayElements...)) + } else { + result = typ.NewTuple(arrayElements...) } - if querycore.IsArrayLike(expected) { - return typ.NewArray(typ.NewUnion(arrayElements...)) + if expected != nil && len(ops.CheckTable(nil, arrayElements, expected).Errors) == 0 { + return expected } - return typ.NewTuple(arrayElements...) + return result } - return builder.Build() + result := builder.Build() + if expected != nil && len(ops.CheckTable(fieldDefs, arrayElements, expected).Errors) == 0 { + return expected + } + return result } // synthFieldValueWithExpected synthesizes type for a table field value with optional expected type. diff --git a/compiler/check/synth/phase/extract/union_expected.go b/compiler/check/synth/phase/extract/union_expected.go index 0600a774..6e581e67 100644 --- a/compiler/check/synth/phase/extract/union_expected.go +++ b/compiler/check/synth/phase/extract/union_expected.go @@ -96,6 +96,13 @@ func (s *Synthesizer) synthExprWithExpectedSingle( return expected } return inferred + case *ast.AttrGetExpr: + if expectedFn, ok := unwrap.Alias(expected).(*typ.Function); ok { + if fnType := s.expectedGraphLocalFunctionValueType(ex, p, sc, expectedFn, nil); fnType != nil { + return fnType + } + } + return s.synthExprCore(expr, sc, p, s.deps.Flow, recurse) default: return s.synthExprCore(expr, sc, p, s.deps.Flow, recurse) } diff --git a/compiler/check/synth/phase/resolve/resolver.go b/compiler/check/synth/phase/resolve/resolver.go index 4849052f..179d5a4e 100644 --- a/compiler/check/synth/phase/resolve/resolver.go +++ b/compiler/check/synth/phase/resolve/resolver.go @@ -22,9 +22,11 @@ import ( "fmt" "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth/phase/core" + typecfg "github.com/wippyai/go-lua/types/cfg" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/io" "github.com/wippyai/go-lua/types/kind" @@ -45,6 +47,8 @@ type Resolver struct { manifests io.ManifestQuerier exprSynth api.ExprSynth bindings core.ParamSymbolLookup + moduleBindings *bind.BindingTable + moduleAliases map[typecfg.SymbolID]string } // Config configures a Resolver. @@ -52,14 +56,18 @@ type Config struct { Manifests io.ManifestQuerier ExprSynth api.ExprSynth Bindings core.ParamSymbolLookup + ModuleBindings *bind.BindingTable + ModuleAliases map[typecfg.SymbolID]string } // New creates a new type resolver. func New(c Config) *Resolver { return &Resolver{ - manifests: c.Manifests, - exprSynth: c.ExprSynth, - bindings: c.Bindings, + manifests: c.Manifests, + exprSynth: c.ExprSynth, + bindings: c.Bindings, + moduleBindings: c.ModuleBindings, + moduleAliases: c.ModuleAliases, } } @@ -169,7 +177,131 @@ func (r *Resolver) ResolveTypeDef(name string, typeExpr ast.TypeExpr, typeParams body := r.ResolveType(typeExpr, bodyScope) return typ.NewGeneric(name, params, body) } - return r.ResolveType(typeExpr, sc) + return r.resolveNonGenericTypeDef(name, typeExpr, sc) +} + +// resolveNonGenericTypeDef resolves a non-generic type alias, preserving +// self-recursive aliases as canonical recursive types. +// +// The resolution uses two passes: +// - Pass 1 builds a provisional body to detect self recursion and seed the +// recursive placeholder with a structurally correct body. +// - Pass 2 rebuilds the body once the placeholder has a real body so any +// enclosing function/record hashes that mention self are finalized against +// the completed recursive shape rather than an empty placeholder. +func (r *Resolver) resolveNonGenericTypeDef(name string, typeExpr ast.TypeExpr, sc *scope.State) typ.Type { + self := typ.NewRecursivePlaceholder(name) + bodyScope := sc.WithType(name, self) + + provisional := r.ResolveType(typeExpr, bodyScope) + if !containsRecursiveRef(provisional, self, 0) { + return provisional + } + + self.SetBody(provisional) + finalBody := r.ResolveType(typeExpr, bodyScope) + self.SetBody(finalBody) + return self +} + +func containsRecursiveRef(t typ.Type, self *typ.Recursive, depth int) bool { + if t == nil || self == nil || typ.DepthExceeded(depth) { + return false + } + if t == self { + return true + } + + return typ.Visit(t, typ.Visitor[bool]{ + Optional: func(o *typ.Optional) bool { + return containsRecursiveRef(o.Inner, self, depth+1) + }, + Union: func(u *typ.Union) bool { + for _, m := range u.Members { + if containsRecursiveRef(m, self, depth+1) { + return true + } + } + return false + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := range in.Members { + if containsRecursiveRef(m, self, depth+1) { + return true + } + } + return false + }, + Array: func(a *typ.Array) bool { + return containsRecursiveRef(a.Element, self, depth+1) + }, + Map: func(m *typ.Map) bool { + return containsRecursiveRef(m.Key, self, depth+1) || + containsRecursiveRef(m.Value, self, depth+1) + }, + Tuple: func(tup *typ.Tuple) bool { + for _, elem := range tup.Elements { + if containsRecursiveRef(elem, self, depth+1) { + return true + } + } + return false + }, + Function: func(fn *typ.Function) bool { + for _, p := range fn.Params { + if containsRecursiveRef(p.Type, self, depth+1) { + return true + } + } + if containsRecursiveRef(fn.Variadic, self, depth+1) { + return true + } + for _, ret := range fn.Returns { + if containsRecursiveRef(ret, self, depth+1) { + return true + } + } + return false + }, + Record: func(rec *typ.Record) bool { + for _, f := range rec.Fields { + if containsRecursiveRef(f.Type, self, depth+1) { + return true + } + } + return containsRecursiveRef(rec.Metatable, self, depth+1) || + containsRecursiveRef(rec.MapKey, self, depth+1) || + containsRecursiveRef(rec.MapValue, self, depth+1) + }, + Alias: func(a *typ.Alias) bool { + return containsRecursiveRef(a.Target, self, depth+1) + }, + Interface: func(iface *typ.Interface) bool { + for _, m := range iface.Methods { + if containsRecursiveRef(m.Type, self, depth+1) { + return true + } + } + return false + }, + Instantiated: func(inst *typ.Instantiated) bool { + if inst.Generic != nil && containsRecursiveRef(inst.Generic.Body, self, depth+1) { + return true + } + for _, arg := range inst.TypeArgs { + if containsRecursiveRef(arg, self, depth+1) { + return true + } + } + return false + }, + Recursive: func(rec *typ.Recursive) bool { + return rec == self + }, + Default: func(typ.Type) bool { + return false + }, + }) } func (r *Resolver) resolveTypeDepth(expr ast.TypeExpr, sc *scope.State, depth int) typ.Type { @@ -407,6 +539,10 @@ func (r *Resolver) resolveFunction(te *ast.FunctionTypeExpr, sc *scope.State, de for _, p := range te.Params { paramType := r.resolveTypeDepth(p.Type, sc, depth+1) + if _, ok := p.Type.(*ast.OptionalTypeExpr); ok { + builder.OptParam(p.Name, paramType) + continue + } builder.Param(p.Name, paramType) } @@ -452,7 +588,7 @@ func (r *Resolver) resolveFunction(te *ast.FunctionTypeExpr, sc *scope.State, de return builder.Build() } -func (r *Resolver) buildAssertEffect(paramIdx int, narrowTo ast.TypeExpr, sc *scope.State, depth int) *constraint.FunctionEffect { +func (r *Resolver) buildAssertEffect(paramIdx int, narrowTo ast.TypeExpr, sc *scope.State, depth int) *constraint.FunctionRefinement { placeholder := fmt.Sprintf("$%d", paramIdx) path := constraint.Path{Root: placeholder} @@ -472,7 +608,7 @@ func (r *Resolver) buildAssertEffect(paramIdx int, narrowTo ast.TypeExpr, sc *sc if len(constraints) == 0 { return nil } - return constraint.NewEffect(constraints, nil, nil) + return constraint.NewRefinement(constraints, nil, nil) } func (r *Resolver) resolveRef(te *ast.TypeRefExpr, sc *scope.State) typ.Type { @@ -493,7 +629,7 @@ func (r *Resolver) resolveRef(te *ast.TypeRefExpr, sc *scope.State) typ.Type { return typ.NewRef("", name) } - module := te.Path[0] + module := r.resolveModuleAliasPath(te.Path[0]) for i := 1; i < len(te.Path)-1; i++ { module += "." + te.Path[i] } @@ -510,6 +646,37 @@ func (r *Resolver) resolveRef(te *ast.TypeRefExpr, sc *scope.State) typ.Type { return typ.NewRef(module, typeName) } +func (r *Resolver) resolveModuleAliasPath(name string) string { + if name == "" || r == nil || r.moduleBindings == nil || len(r.moduleAliases) == 0 { + return name + } + + syms := r.moduleBindings.SymbolsByName(name) + if len(syms) == 0 { + return name + } + + resolved := "" + for _, sym := range syms { + path := r.moduleAliases[sym] + if path == "" { + continue + } + if resolved == "" { + resolved = path + continue + } + if resolved != path { + return name + } + } + + if resolved == "" { + return name + } + return resolved +} + func (r *Resolver) resolveGeneric(te *ast.GenericTypeExpr, sc *scope.State, depth int) typ.Type { if te.Base == nil || len(te.Base.Path) == 0 { return typ.Unknown diff --git a/compiler/check/synth/phase/resolve/resolver_test.go b/compiler/check/synth/phase/resolve/resolver_test.go index 7ad5b226..dbda0dc1 100644 --- a/compiler/check/synth/phase/resolve/resolver_test.go +++ b/compiler/check/synth/phase/resolve/resolver_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/check/scope" typecfg "github.com/wippyai/go-lua/types/cfg" "github.com/wippyai/go-lua/types/io" @@ -94,6 +95,31 @@ func TestResolveType_SelfNoScope(t *testing.T) { } } +func TestResolveType_ModuleAliasPrefixUsesRequireAliasPath(t *testing.T) { + manifest := io.NewManifest("store") + storeType := typ.NewAlias("Store", typ.NewRecord(). + Field("cache", typ.NewMap(typ.String, typ.String)). + Build()) + manifest.DefineType("Store", storeType) + + bt := bind.NewBindingTable() + const sym typecfg.SymbolID = 42 + bt.SetName(sym, "store_mod") + + r := New(Config{ + Manifests: manifestQuerierStub{ + manifests: map[string]*io.Manifest{"store": manifest}, + }, + ModuleBindings: bt, + ModuleAliases: map[typecfg.SymbolID]string{sym: "store"}, + }) + + result := r.ResolveType(&ast.TypeRefExpr{Path: []string{"store_mod", "Store"}}, scope.New()) + if !typ.TypeEquals(result, storeType) { + t.Fatalf("got %s, want %s", typ.FormatShort(result), typ.FormatShort(storeType)) + } +} + func TestResolveType_Optional(t *testing.T) { r := newTestResolver() sc := scope.New() @@ -264,6 +290,33 @@ func TestResolveType_Function(t *testing.T) { } } +func TestResolveType_FunctionOptionalParam(t *testing.T) { + r := newTestResolver() + sc := scope.New() + + expr := &ast.FunctionTypeExpr{ + Params: []ast.FunctionParamExpr{ + {Name: "n", Type: &ast.OptionalTypeExpr{Inner: &ast.PrimitiveTypeExpr{Name: "number"}}}, + }, + Returns: []ast.TypeExpr{&ast.PrimitiveTypeExpr{Name: "string"}}, + } + result := r.ResolveType(expr, sc) + + fn, ok := result.(*typ.Function) + if !ok { + t.Fatalf("got %T, want function", result) + } + if len(fn.Params) != 1 { + t.Fatalf("got %d params, want 1", len(fn.Params)) + } + if !fn.Params[0].Optional { + t.Fatal("expected function type optional param to preserve optional arity") + } + if !typ.TypeEquals(fn.Params[0].Type, typ.NewOptional(typ.Number)) { + t.Fatalf("got param type %v, want number?", fn.Params[0].Type) + } +} + func TestResolveType_FunctionVariadic(t *testing.T) { r := newTestResolver() sc := scope.New() @@ -731,3 +784,61 @@ func TestResolveTypeDef_Generic(t *testing.T) { t.Fatalf("got %d type params, want 1", len(generic.TypeParams)) } } + +func TestResolveTypeDef_RecursiveAliasBodyUsesResolvedSelfType(t *testing.T) { + r := newTestResolver() + sc := scope.New() + + typeExpr := &ast.RecordTypeExpr{ + Fields: []ast.RecordFieldExpr{ + { + Name: "f", + Type: &ast.FunctionTypeExpr{ + Params: []ast.FunctionParamExpr{ + {Name: "self", Type: &ast.PrimitiveTypeExpr{Name: "Node"}}, + }, + Returns: []ast.TypeExpr{ + &ast.PrimitiveTypeExpr{Name: "Node"}, + }, + }, + }, + { + Name: "g", + Type: &ast.FunctionTypeExpr{ + Params: []ast.FunctionParamExpr{ + {Name: "self", Type: &ast.PrimitiveTypeExpr{Name: "Node"}}, + }, + Returns: []ast.TypeExpr{ + &ast.PrimitiveTypeExpr{Name: "number"}, + }, + }, + }, + }, + } + + result := r.ResolveTypeDef("Node", typeExpr, nil, sc) + rec, ok := result.(*typ.Recursive) + if !ok { + t.Fatalf("got %T, want recursive type", result) + } + + body, ok := rec.Body.(*typ.Record) + if !ok { + t.Fatalf("body: got %T, want record", rec.Body) + } + + fField := body.GetField("f") + if fField == nil { + t.Fatal("missing f field") + } + fType, ok := fField.Type.(*typ.Function) + if !ok { + t.Fatalf("f: got %T, want function", fField.Type) + } + if len(fType.Params) != 1 || fType.Params[0].Type != rec { + t.Fatalf("f self param: got %v, want recursive self type", fType.Params[0].Type) + } + if len(fType.Returns) != 1 || fType.Returns[0] != rec { + t.Fatalf("f return: got %v, want recursive self type", fType.Returns[0]) + } +} diff --git a/compiler/check/synth/transform/effect_return.go b/compiler/check/synth/transform/effect_return.go index 1eb937f4..5c968798 100644 --- a/compiler/check/synth/transform/effect_return.go +++ b/compiler/check/synth/transform/effect_return.go @@ -61,6 +61,13 @@ func ApplyEffectTransform(fn *typ.Function, args []typ.Type, returnIdx int, base } } return baseReturn + case effect.StringUnpackValue: + if resolved := resolveParamType(args, transform.Format); resolved != nil { + if unpacked := unpackFirstValueType(resolved); unpacked != nil { + return unpacked + } + } + return baseReturn case effect.CallbackReturn: if resolved := resolveParamType(args, transform.CallbackParam); resolved != nil { if cbRet := callbackReturnType(resolved); cbRet != nil { diff --git a/compiler/check/synth/transform/effect_return_test.go b/compiler/check/synth/transform/effect_return_test.go index 3c64fcb6..520b0127 100644 --- a/compiler/check/synth/transform/effect_return_test.go +++ b/compiler/check/synth/transform/effect_return_test.go @@ -174,3 +174,60 @@ func TestApplyEffectTransform_ArrayOfCallbackReturn(t *testing.T) { t.Fatalf("expected array(callback return) transform to produce %v, got %v", want, got) } } + +func TestApplyEffectTransform_StringUnpackValue_IntegerFormat(t *testing.T) { + spec := contract.NewSpec().WithEffects(effect.Return{ + ReturnIndex: 0, + Transform: effect.StringUnpackValue{Format: effect.ParamRef{Index: 0}}, + }) + fn := typ.Func(). + Param("fmt", typ.String). + Param("s", typ.String). + OptParam("pos", typ.Integer). + Returns(typ.Any). + Spec(spec). + Build() + + got := ApplyEffectTransform(fn, []typ.Type{typ.LiteralString(">I4"), typ.String, typ.Integer}, 0, typ.Any) + if !typ.TypeEquals(got, typ.Integer) { + t.Fatalf("expected string.unpack integer format to produce integer, got %v", got) + } +} + +func TestApplyEffectTransform_StringUnpackValue_StringFormat(t *testing.T) { + spec := contract.NewSpec().WithEffects(effect.Return{ + ReturnIndex: 0, + Transform: effect.StringUnpackValue{Format: effect.ParamRef{Index: 0}}, + }) + fn := typ.Func(). + Param("fmt", typ.String). + Param("s", typ.String). + OptParam("pos", typ.Integer). + Returns(typ.Any). + Spec(spec). + Build() + + got := ApplyEffectTransform(fn, []typ.Type{typ.LiteralString("z"), typ.String, typ.Integer}, 0, typ.Any) + if !typ.TypeEquals(got, typ.String) { + t.Fatalf("expected string.unpack string format to produce string, got %v", got) + } +} + +func TestApplyEffectTransform_StringUnpackValue_UnsupportedFormatFallsBack(t *testing.T) { + spec := contract.NewSpec().WithEffects(effect.Return{ + ReturnIndex: 0, + Transform: effect.StringUnpackValue{Format: effect.ParamRef{Index: 0}}, + }) + fn := typ.Func(). + Param("fmt", typ.String). + Param("s", typ.String). + OptParam("pos", typ.Integer). + Returns(typ.Any). + Spec(spec). + Build() + + got := ApplyEffectTransform(fn, []typ.Type{typ.LiteralString("X"), typ.String, typ.Integer}, 0, typ.Any) + if !typ.TypeEquals(got, typ.Any) { + t.Fatalf("expected unsupported string.unpack format to fall back to any, got %v", got) + } +} diff --git a/compiler/check/synth/transform/string_unpack.go b/compiler/check/synth/transform/string_unpack.go new file mode 100644 index 00000000..6ae40549 --- /dev/null +++ b/compiler/check/synth/transform/string_unpack.go @@ -0,0 +1,127 @@ +package transform + +import ( + "github.com/wippyai/go-lua/types/kind" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +func unpackFirstValueType(formatType typ.Type) typ.Type { + switch v := unwrap.Alias(formatType).(type) { + case *typ.Literal: + if v.Base != kind.String { + return nil + } + format, ok := v.Value.(string) + if !ok { + return nil + } + return unpackFirstValueTypeForFormat(format) + case *typ.Optional: + return unpackFirstValueType(v.Inner) + case *typ.Union: + var members []typ.Type + for _, member := range v.Members { + resolved := unpackFirstValueType(member) + if resolved == nil { + return nil + } + members = append(members, resolved) + } + if len(members) == 0 { + return nil + } + return typ.NewUnion(members...) + case *typ.Intersection: + format, ok := literalStringValue(v) + if !ok { + return nil + } + return unpackFirstValueTypeForFormat(format) + default: + return nil + } +} + +func literalStringValue(t typ.Type) (string, bool) { + switch v := unwrap.Alias(t).(type) { + case *typ.Literal: + if v.Base != kind.String { + return "", false + } + s, ok := v.Value.(string) + return s, ok + case *typ.Intersection: + var ( + value string + found bool + ) + for _, member := range v.Members { + s, ok := literalStringValue(member) + if !ok { + continue + } + if found && s != value { + return "", false + } + value = s + found = true + } + return value, found + default: + return "", false + } +} + +func unpackFirstValueTypeForFormat(format string) typ.Type { + for i := 0; i < len(format); { + switch ch := format[i]; ch { + case ' ', '\t', '\n', '\r', '<', '>', '=': + i++ + case '!': + i++ + start := i + for i < len(format) && isASCIIDigit(format[i]) { + i++ + } + if start == i { + return nil + } + case 'x': + i++ + case 'b', 'B', 'h', 'H', 'l', 'L', 'j', 'J', 'T': + return typ.Integer + case 'i', 'I': + i++ + for i < len(format) && isASCIIDigit(format[i]) { + i++ + } + return typ.Integer + case 'f', 'd', 'n': + return typ.Number + case 'c': + i++ + start := i + for i < len(format) && isASCIIDigit(format[i]) { + i++ + } + if start == i { + return nil + } + return typ.String + case 'z', 's': + i++ + for i < len(format) && isASCIIDigit(format[i]) { + i++ + } + return typ.String + default: + return nil + } + } + return nil +} + +func isASCIIDigit(b byte) bool { + return b >= '0' && b <= '9' +} diff --git a/compiler/check/tests/core/assign_test.go b/compiler/check/tests/core/assign_test.go index f731c08e..5fda5687 100644 --- a/compiler/check/tests/core/assign_test.go +++ b/compiler/check/tests/core/assign_test.go @@ -276,7 +276,8 @@ func TestAssign_ErrorMessage(t *testing.T) { } // TestAssign_DynamicTableIndexing tests that tables populated via t[key] = value -// are properly typed and can be accessed. +// keep sound index semantics: exact dominating writes can be definite, while +// arbitrary map lookups remain optional until proven present. func TestAssign_DynamicTableIndexing(t *testing.T) { tests := []testutil.Case{ { @@ -288,7 +289,7 @@ func TestAssign_DynamicTableIndexing(t *testing.T) { for _, m in ipairs(methods) do method_names[m.name] = true end - local exists: boolean = method_names["greet"] + local exists: boolean? = method_names["greet"] `, WantError: false, Stdlib: true, diff --git a/compiler/check/tests/errors/error_correlation_test.go b/compiler/check/tests/errors/error_correlation_test.go index e0bec57e..cfac14ca 100644 --- a/compiler/check/tests/errors/error_correlation_test.go +++ b/compiler/check/tests/errors/error_correlation_test.go @@ -41,7 +41,7 @@ func TestAssertIsNilNarrowsSiblingRequire(t *testing.T) { Field("is_nil", typ.Func(). Param("val", typ.Any). OptParam("msg", typ.String). - WithRefinement(constraint.NewEffect( + WithRefinement(constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, )). diff --git a/compiler/check/tests/errors/never_after_assertion_test.go b/compiler/check/tests/errors/never_after_assertion_test.go index 77989621..6e15126a 100644 --- a/compiler/check/tests/errors/never_after_assertion_test.go +++ b/compiler/check/tests/errors/never_after_assertion_test.go @@ -14,7 +14,7 @@ import ( // TestFunctionRefinementSetup verifies that function refinements are properly set. func TestFunctionRefinementSetup(t *testing.T) { // Create function with refinement - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) @@ -26,9 +26,9 @@ func TestFunctionRefinementSetup(t *testing.T) { if fn.Refinement == nil { t.Fatal("Refinement is nil") } - eff, ok := fn.Refinement.(*constraint.FunctionEffect) + eff, ok := fn.Refinement.(*constraint.FunctionRefinement) if !ok { - t.Fatalf("Refinement is %T, not *constraint.FunctionEffect", fn.Refinement) + t.Fatalf("Refinement is %T, not *constraint.FunctionRefinement", fn.Refinement) } if len(eff.OnReturn.MustConstraints()) != 1 { t.Fatalf("OnReturn.Len() = %d, want 1", len(eff.OnReturn.MustConstraints())) @@ -52,7 +52,7 @@ func TestFunctionRefinementSetup(t *testing.T) { if fieldFn.Refinement == nil { t.Fatal("field function Refinement is nil after record lookup") } - fieldEff, ok := fieldFn.Refinement.(*constraint.FunctionEffect) + fieldEff, ok := fieldFn.Refinement.(*constraint.FunctionRefinement) if !ok { t.Fatalf("field Refinement is %T", fieldFn.Refinement) } @@ -64,7 +64,7 @@ func TestFunctionRefinementSetup(t *testing.T) { // TestScopeLookupWithManifest verifies that manifest symbols are accessible in scope. func TestScopeLookupWithManifest(t *testing.T) { - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) @@ -123,7 +123,7 @@ func TestNeverAfterAssertion(t *testing.T) { Build() // assert.not_nil narrows first param to non-nil via OnReturn - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, // OnReturn nil, // OnTrue nil, // OnFalse @@ -135,7 +135,7 @@ func TestNeverAfterAssertion(t *testing.T) { Build() // assert.is_nil narrows first param to nil via OnReturn - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, // OnReturn nil, // OnTrue nil, // OnFalse @@ -289,7 +289,7 @@ func TestFlowSolutionDebug(t *testing.T) { Build() // assert.not_nil narrows first param to non-nil via OnReturn - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/errors/never_type_unit_test.go b/compiler/check/tests/errors/never_type_unit_test.go index 869d9399..a005d1dd 100644 --- a/compiler/check/tests/errors/never_type_unit_test.go +++ b/compiler/check/tests/errors/never_type_unit_test.go @@ -25,7 +25,7 @@ func TestNeverType(t *testing.T) { Build() // Create assert manifest with refinement - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) @@ -35,7 +35,7 @@ func TestNeverType(t *testing.T) { WithRefinement(notNilEffect). Build() - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/flow/fixpoint_unification_test.go b/compiler/check/tests/flow/fixpoint_unification_test.go index 535bec38..5b0cb346 100644 --- a/compiler/check/tests/flow/fixpoint_unification_test.go +++ b/compiler/check/tests/flow/fixpoint_unification_test.go @@ -217,7 +217,7 @@ end // Verify effects exist and A's effect has Terminates == true. foundA := false - for sym, eff := range sess.Store.InterprocPrev.Effects { + for sym, eff := range sess.Store.InterprocPrev.Refinements { if eff == nil { continue } @@ -267,7 +267,7 @@ end } foundA := false - for sym, eff := range sess.Store.InterprocPrev.Effects { + for sym, eff := range sess.Store.InterprocPrev.Refinements { if eff == nil { continue } @@ -307,7 +307,7 @@ end } foundA := false - for sym, eff := range sess.Store.InterprocPrev.Effects { + for sym, eff := range sess.Store.InterprocPrev.Refinements { if eff == nil { continue } @@ -351,7 +351,7 @@ end } foundA := false - for sym, eff := range sess.Store.InterprocPrev.Effects { + for sym, eff := range sess.Store.InterprocPrev.Refinements { if eff == nil { continue } @@ -372,14 +372,14 @@ end } // TestFixpointUnification_EffectRowLabels verifies that effect row labels -// are properly stored on FunctionEffect and survive the fixpoint swap. +// are properly stored on FunctionRefinement and survive the fixpoint swap. func TestFixpointUnification_EffectRowLabels(t *testing.T) { - // Verify that the Row field on FunctionEffect supports union and equality. + // Verify that the Row field on FunctionRefinement supports union and equality. row1 := effect.WithModuleLoad() row2 := effect.WithVariadicTransform() combined := effect.Union(row1, row2) - eff := &constraint.FunctionEffect{ + eff := &constraint.FunctionRefinement{ Row: combined, Terminates: false, } diff --git a/compiler/check/tests/inference/table_inference_test.go b/compiler/check/tests/inference/table_inference_test.go index 8e40703e..38cdd490 100644 --- a/compiler/check/tests/inference/table_inference_test.go +++ b/compiler/check/tests/inference/table_inference_test.go @@ -345,6 +345,16 @@ func TestTableInference_TypedRecords(t *testing.T) { WantError: false, Stdlib: true, }, + { + Name: "typed record optional field from nilable expression", + Code: ` + type Config = {name: string, debug: boolean?} + local maybe_debug: boolean? = nil + local c: Config = {name = "prod", debug = maybe_debug} + `, + WantError: false, + Stdlib: true, + }, { Name: "typed record with computed key rejects", Code: ` @@ -373,10 +383,10 @@ func TestTableInference_TypedRecords(t *testing.T) { func TestTableInference_MapTypes(t *testing.T) { tests := []testutil.Case{ { - Name: "string to number map", + Name: "string to number map lookup is optional", Code: ` local scores: {[string]: number} = {["alice"] = 100, ["bob"] = 90} - local s: number = scores["alice"] + local s: number? = scores["alice"] `, WantError: false, Stdlib: true, diff --git a/compiler/check/tests/modules/imported_self_method_store_test.go b/compiler/check/tests/modules/imported_self_method_store_test.go new file mode 100644 index 00000000..849a9af9 --- /dev/null +++ b/compiler/check/tests/modules/imported_self_method_store_test.go @@ -0,0 +1,73 @@ +package modules + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +func TestExportedSelfMethodStore_ConstructorReturnMatchesExportedType(t *testing.T) { + mod := testutil.CheckAndExport(` +type Store = { + cache: {[string]: string}, + get: (self: Store, key: string) -> string?, + put: (self: Store, key: string, value: string) -> Store, +} + +local Store = {} +Store.__index = Store + +local M = {} +M.Store = Store + +function M.new(): Store + local self: Store = { + cache = {}, + get = Store.get, + put = Store.put, + } + setmetatable(self, Store) + return self +end + +function Store:get(key: string): string? + return self.cache[key] +end + +function Store:put(key: string, value: string): Store + self.cache[key] = value + return self +end + +return M +`, "store") + if mod.HasError() { + t.Fatalf("module errors: %v", testutil.ErrorMessages(mod.Errors)) + } + if mod.Manifest == nil { + t.Fatal("expected manifest") + } + + storeType, ok := mod.Manifest.LookupType("Store") + if !ok || storeType == nil { + t.Fatal("expected exported Store type") + } + + newType, ok := mod.Manifest.LookupValue("new") + if !ok || newType == nil { + t.Fatal("expected exported new value type") + } + + fn := unwrap.Function(newType) + if fn == nil || len(fn.Returns) != 1 { + t.Fatalf("expected constructor function type, got %s", typ.FormatShort(newType)) + } + + got := fn.Returns[0] + if !subtype.IsSubtype(got, storeType) { + t.Fatalf("constructor return %s is not subtype of exported Store %s", typ.FormatShort(got), typ.FormatShort(storeType)) + } +} diff --git a/compiler/check/tests/modules/module_workflow_test.go b/compiler/check/tests/modules/module_workflow_test.go index 7299c9cb..878b0818 100644 --- a/compiler/check/tests/modules/module_workflow_test.go +++ b/compiler/check/tests/modules/module_workflow_test.go @@ -567,8 +567,8 @@ func TestModuleWorkflow_ExportTypeIncludesEffects(t *testing.T) { if fn.Refinement == nil { t.Fatalf("expected not_nil to carry refinement in export") } - if _, ok := fn.Refinement.(*constraint.FunctionEffect); !ok { - t.Fatalf("expected FunctionEffect refinement, got %T", fn.Refinement) + if _, ok := fn.Refinement.(*constraint.FunctionRefinement); !ok { + t.Fatalf("expected FunctionRefinement refinement, got %T", fn.Refinement) } } @@ -835,9 +835,9 @@ func TestE2E_KeysCollectorCrossModule(t *testing.T) { if fn.Refinement == nil { t.Fatal("expected refinement on sorted_keys") } - eff, ok := fn.Refinement.(*constraint.FunctionEffect) + eff, ok := fn.Refinement.(*constraint.FunctionRefinement) if !ok { - t.Fatalf("expected FunctionEffect, got %T", fn.Refinement) + t.Fatalf("expected FunctionRefinement, got %T", fn.Refinement) } if paramIdx := eff.KeysCollectorParamIndex(); paramIdx != 0 { t.Errorf("KeysCollectorParamIndex: got %d, want 0", paramIdx) @@ -869,7 +869,7 @@ func TestE2E_KeysCollectorCrossModule(t *testing.T) { if decodedFn.Refinement == nil { t.Fatal("decoded function missing refinement") } - decodedEff, _ := decodedFn.Refinement.(*constraint.FunctionEffect) + decodedEff, _ := decodedFn.Refinement.(*constraint.FunctionRefinement) if paramIdx := decodedEff.KeysCollectorParamIndex(); paramIdx != 0 { t.Errorf("decoded KeysCollectorParamIndex: got %d, want 0", paramIdx) } @@ -937,9 +937,9 @@ func TestE2E_KeysCollectorAutoExport(t *testing.T) { if fn.Refinement == nil { t.Fatal("expected refinement on sorted_keys") } - eff, ok := fn.Refinement.(*constraint.FunctionEffect) + eff, ok := fn.Refinement.(*constraint.FunctionRefinement) if !ok { - t.Fatalf("expected FunctionEffect, got %T", fn.Refinement) + t.Fatalf("expected FunctionRefinement, got %T", fn.Refinement) } if paramIdx := eff.KeysCollectorParamIndex(); paramIdx != 0 { t.Errorf("KeysCollectorParamIndex: got %d, want 0", paramIdx) @@ -1005,9 +1005,9 @@ func TestE2E_KeysCollectorAutoExport_MultiReturnKeySlot(t *testing.T) { if !ok { t.Fatal("expected Function type") } - eff, ok := fn.Refinement.(*constraint.FunctionEffect) + eff, ok := fn.Refinement.(*constraint.FunctionRefinement) if !ok { - t.Fatalf("expected FunctionEffect refinement, got %T", fn.Refinement) + t.Fatalf("expected FunctionRefinement refinement, got %T", fn.Refinement) } paramIdx, retIdx, ok := eff.KeysCollectorInfo() if !ok { diff --git a/compiler/check/tests/regression/alias_preserved_after_nested_mutation_test.go b/compiler/check/tests/regression/alias_preserved_after_nested_mutation_test.go new file mode 100644 index 00000000..964ef7bd --- /dev/null +++ b/compiler/check/tests/regression/alias_preserved_after_nested_mutation_test.go @@ -0,0 +1,27 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestAliasPreservedAfterNestedMutationReturn(t *testing.T) { + result := testutil.Check(` +type Builder = {_messages: {string}} + +local function new(): Builder + return {_messages = {}} +end + +local function clone(): Builder + local b = new() + local msg: string = "x" + table.insert(b._messages, msg) + return b +end +`, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected nested mutation to preserve alias return type, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/array_union_contextual_table_test.go b/compiler/check/tests/regression/array_union_contextual_table_test.go new file mode 100644 index 00000000..796a8a99 --- /dev/null +++ b/compiler/check/tests/regression/array_union_contextual_table_test.go @@ -0,0 +1,24 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestArrayUnionContextualTyping_UsesExpectedElementType(t *testing.T) { + result := testutil.Check(` +type ContentEvent = {type: "content", data: string} +type ToolCallEvent = {type: "tool_call", id: string, name: string, arguments: {[string]: any}} +type DoneEvent = {type: "done", reason: string?, usage: {input_tokens: number, output_tokens: number}?} +type StreamEvent = ContentEvent | ToolCallEvent | DoneEvent + +local events: {StreamEvent} = { + {type = "content", data = "Hello"}, + {type = "tool_call", id = "t1", name = "search", arguments = {query = "test"}}, +} +`, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors for union array element contextual typing, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/custom_error_return_assert_correlation_test.go b/compiler/check/tests/regression/custom_error_return_assert_correlation_test.go index 23a6dfa6..9f225eb2 100644 --- a/compiler/check/tests/regression/custom_error_return_assert_correlation_test.go +++ b/compiler/check/tests/regression/custom_error_return_assert_correlation_test.go @@ -12,7 +12,7 @@ import ( // Regression: local functions returning `(value?, custom_error_record?)` must // produce ErrorReturn correlation so assert-based nil checks narrow siblings. func TestRegression_CustomErrorReturnCorrelationFromBody(t *testing.T) { - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/regression/error_record_optional_field_after_nil_assert_test.go b/compiler/check/tests/regression/error_record_optional_field_after_nil_assert_test.go index 13fe6cd1..503bf923 100644 --- a/compiler/check/tests/regression/error_record_optional_field_after_nil_assert_test.go +++ b/compiler/check/tests/regression/error_record_optional_field_after_nil_assert_test.go @@ -13,7 +13,7 @@ import ( // additional fields on some paths, merged return slots should expose those // fields as optional (instead of producing "field missing on union member"). func TestRegression_ErrorRecordFieldsBecomeOptionalAcrossReturnPaths(t *testing.T) { - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/regression/error_return_sibling_correlation_test.go b/compiler/check/tests/regression/error_return_sibling_correlation_test.go index bd52479c..6cd2ce18 100644 --- a/compiler/check/tests/regression/error_return_sibling_correlation_test.go +++ b/compiler/check/tests/regression/error_return_sibling_correlation_test.go @@ -19,11 +19,11 @@ import ( // // and ensures the second assertion does not collapse err to never. func TestRegression_ErrorReturnSiblingCorrelationDirection(t *testing.T) { - notNilEffect := constraint.NewEffect( + notNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.NotNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/regression/error_return_union_errorlike_test.go b/compiler/check/tests/regression/error_return_union_errorlike_test.go new file mode 100644 index 00000000..4fce109c --- /dev/null +++ b/compiler/check/tests/regression/error_return_union_errorlike_test.go @@ -0,0 +1,65 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestErrorReturnConvention_UnionOfStringAndLuaErrorStillNarrowsValue(t *testing.T) { + result := testutil.Check(` +type GenError = { + message: string, +} + +local id_source = {} +function id_source.v7(): (string, GenError?) + return "id", nil +end + +type ActiveSession = { + pid: any, +} + +local active_sessions = {} :: {[string]: ActiveSession} + +local function create_session(payload_data) + if not payload_data then + return nil, "missing payload" + end + + local session_id = payload_data.session_id + if not session_id then + local id, err = id_source.v7() + if err then + return nil, err + end + session_id = id + end + + if not payload_data.start_token then + return nil, "missing token" + end + + return session_id, nil +end + +local function use_session(payload_data) + local created_session_id, err = create_session(payload_data) + if err then + return + end + + local recovered_session_info = active_sessions[created_session_id] + if recovered_session_info then + return recovered_session_info.pid + end +end + +return use_session +`, testutil.WithStdlib()) + + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/false_positives_unit_test.go b/compiler/check/tests/regression/false_positives_unit_test.go index da25ec24..c311045c 100644 --- a/compiler/check/tests/regression/false_positives_unit_test.go +++ b/compiler/check/tests/regression/false_positives_unit_test.go @@ -13,20 +13,20 @@ import ( // False Positive Reproductions from wippy lint // These tests document bugs that produce false positive errors -// 1. E0004: Field access on {} type with bracket notation -// Pattern: method_names["greet"] where method_names is a map/table - -func TestFalsePositive_BracketNotationOnMap(t *testing.T) { +// 1. Bracket notation on maps remains soundly optional until presence is proven. +func TestBracketNotationOnMap_GuardedAccess(t *testing.T) { source := ` local method_names: {[string]: string} = { greet = "hello", farewell = "goodbye" } + local maybe_name: string? = method_names["greet"] + assert(method_names["greet"]) local name: string = method_names["greet"] ` result := testutil.Check(source, testutil.WithStdlib()) if result.HasError() { - t.Errorf("expected no errors for bracket notation on map, got: %v", testutil.ErrorMessages(result.Diagnostics)) + t.Errorf("expected no errors for guarded bracket notation on map, got: %v", testutil.ErrorMessages(result.Diagnostics)) } } diff --git a/compiler/check/tests/regression/field_defined_wrapper_return_test.go b/compiler/check/tests/regression/field_defined_wrapper_return_test.go new file mode 100644 index 00000000..f6625e1d --- /dev/null +++ b/compiler/check/tests/regression/field_defined_wrapper_return_test.go @@ -0,0 +1,133 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +// Regression guard: field-defined wrapper functions must not freeze a stale +// return when they call through a mutable captured table path. +func TestFieldDefinedWrapperReturnTracksVisibleCapturedFieldWrite(t *testing.T) { + result := testutil.Check(` +local M = { + dep = { + get = function() + return nil + end, + }, +} + +function M.run() + return M.dep.get() +end + +M.dep = { + get = function() + return { answer = "ok" } + end, +} + +local res = M.run() +local answer: string = res.answer +return answer +`) + if result.HasError() { + t.Fatalf("expected field-defined wrapper to see the visible reassigned field result, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +// Soundness guard: the wrapper must stay nilable when the write is not +// dominating on all paths. +func TestFieldDefinedWrapperReturnRequiresDominatingVisibleWrite(t *testing.T) { + result := testutil.Check(` +local function run(flag: boolean) + local M = { + dep = { + get = function() + return nil + end, + }, + } + + function M.run() + return M.dep.get() + end + + if flag then + M.dep = { + get = function() + return { answer = "ok" } + end, + } + end + + local res = M.run() + local answer: string = res.answer + return answer +end +`) + if !result.HasError() { + t.Fatalf("expected non-dominating wrapper write to remain nilable after join") + } +} + +func TestFieldDefinedWrapperReturnPreservedThroughLocalAlias(t *testing.T) { + result := testutil.Check(` +type Res = { answer: string } + +local M = { + dep = { + get = function() + return nil + end, + }, +} + +function M.run() + return M.dep.get() +end + +M.dep = { + get = function() + return { answer = "ok" } + end, +} + +local f: fun(): Res = M.run +local res = f() +local answer: string = res.answer +return answer +`) + if result.HasError() { + t.Fatalf("expected aliased field-defined wrapper to preserve visible reassigned field result, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFieldDefinedWrapperReturnLocalAliasRespectsCurrentFunctionValue(t *testing.T) { + result := testutil.Check(` +type Res = { answer: string } + +local M = { + dep = { + get = function() + return nil + end, + }, +} + +function M.run() + return M.dep.get() +end + +M.run = function() + return nil +end + +local f: fun(): Res = M.run +return f +`) + if !result.HasError() { + t.Fatalf("expected aliased wrapper value typing to respect current reassigned function value") + } +} diff --git a/compiler/check/tests/regression/fresh_record_expected_literal_union_test.go b/compiler/check/tests/regression/fresh_record_expected_literal_union_test.go new file mode 100644 index 00000000..c5a371ee --- /dev/null +++ b/compiler/check/tests/regression/fresh_record_expected_literal_union_test.go @@ -0,0 +1,50 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestFreshRecordExpectedLiteralUnion_ReturnContext(t *testing.T) { + result := testutil.Check(` +type StatusCode = 200 | 201 | 400 +type Response = { + status: StatusCode, + body: any?, + headers: {[string]: string}, +} + +local function ok(body: any?): Response + return {status = 200, body = body, headers = {}} +end + +local function created(body: any?): Response + return {status = 201, body = body, headers = {}} +end +`, testutil.WithStdlib()) + + if result.HasError() { + t.Fatalf("expected no errors for fresh return literals under expected record type, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFreshRecordExpectedLiteralUnion_AssignmentContext(t *testing.T) { + result := testutil.Check(` +type StatusCode = 200 | 201 | 400 +type Response = { + status: StatusCode, + body: any?, + headers: {[string]: string}, +} + +local ok: Response = {status = 200, body = nil, headers = {}} +local created: Response = {status = 201, body = nil, headers = {}} + +return ok, created +`, testutil.WithStdlib()) + + if result.HasError() { + t.Fatalf("expected no errors for fresh assignment literals under expected record type, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/index_presence_law_test.go b/compiler/check/tests/regression/index_presence_law_test.go new file mode 100644 index 00000000..4d8f3fe5 --- /dev/null +++ b/compiler/check/tests/regression/index_presence_law_test.go @@ -0,0 +1,430 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestIndexPresence_TruthinessGuard_RepeatedLiteralLookup(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = { + ["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + }, + } + + if messages["root"] then + local topic: string = messages["root"]:topic() + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_NilCheck_RepeatedLiteralLookup(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = { + ["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + }, + } + + if messages["root"] ~= nil then + local topic: string = messages["root"]:topic() + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_Assert_RepeatedLiteralLookup(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = { + ["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + }, + } + + assert(messages["root"]) + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_DominatingLiteralAssignment_MakesLookupDefinite(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + messages["root"] = { + _topic = "installed", + topic = function(self: Message): string + return self._topic + end, + } + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_DominatingLiteralAssignment_SurvivesOtherKeys(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + messages["root"] = { + _topic = "installed", + topic = function(self: Message): string + return self._topic + end, + } + local other = "other" + messages[other] = { + _topic = "side", + topic = function(self: Message): string + return self._topic + end, + } + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_HybridFieldAndMap_LiteralFieldStaysExact(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local t = {} + t.root = { + _topic = "hybrid", + topic = function(self: Message): string + return self._topic + end, + } + local key: string = "other" + t[key] = { + _topic = "mapped", + topic = function(self: Message): string + return self._topic + end, + } + + local topic: string = t["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_ConstResolvedStringKey_UsesStaticPathLaw(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = { + ["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + }, + } + local key = "root" + + if messages[key] then + local topic: string = messages[key]:topic() + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_KeyOfPairsLoop_ProvesDynamicKeyPresence(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = { + ["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + }, + ["other"] = { + _topic = "world", + topic = function(self: Message): string + return self._topic + end, + }, + } + + for key, _ in pairs(messages) do + local topic: string = messages[key]:topic() + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_OverwriteWithNil_DirectLookupMustFail(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + messages["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + } + messages["root"] = nil + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatal("expected overwrite-with-nil direct lookup to require a nil check") + } +} + +func TestIndexPresence_OverwriteWithNil_GuardedLookupSucceeds(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + messages["root"] = { + _topic = "hello", + topic = function(self: Message): string + return self._topic + end, + } + messages["root"] = nil + + if messages["root"] then + local topic: string = messages["root"]:topic() + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_JoinedInstallOnBothBranches_IsDefiniteAfterJoin(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + local cond = true + if cond then + messages["root"] = { + _topic = "a", + topic = function(self: Message): string + return self._topic + end, + } + else + messages["root"] = { + _topic = "b", + topic = function(self: Message): string + return self._topic + end, + } + end + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_NotPresentGuardThenInstall_IsDefiniteAfterJoin(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + if not messages["root"] then + messages["root"] = { + _topic = "installed", + topic = function(self: Message): string + return self._topic + end, + } + end + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_JoinedInstallOnlyOnOneBranch_DirectLookupMustFail(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + local cond = true + if cond then + messages["root"] = { + _topic = "a", + topic = function(self: Message): string + return self._topic + end, + } + end + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatal("expected branch-local installation to require a nil check after join") + } +} + +func TestIndexPresence_JoinedInstallOnlyOnOneBranch_GuardedLookupSucceeds(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + local cond = true + if cond then + messages["root"] = { + _topic = "a", + topic = function(self: Message): string + return self._topic + end, + } + end + + if messages["root"] then + local topic: string = messages["root"]:topic() + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestIndexPresence_JoinedNilOnOneBranch_DirectLookupMustFail(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local messages: {[string]: Message} = {} + messages["root"] = { + _topic = "a", + topic = function(self: Message): string + return self._topic + end, + } + local cond = true + if cond then + messages["root"] = nil + end + + local topic: string = messages["root"]:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatal("expected niling a key on one branch to require a nil check after join") + } +} diff --git a/compiler/check/tests/regression/llm_error_record_optional_fields_test.go b/compiler/check/tests/regression/llm_error_record_optional_fields_test.go index 6cd3c9cb..7b10252f 100644 --- a/compiler/check/tests/regression/llm_error_record_optional_fields_test.go +++ b/compiler/check/tests/regression/llm_error_record_optional_fields_test.go @@ -10,7 +10,7 @@ import ( ) func TestRegression_LLMStyleErrorRecordOptionalFields(t *testing.T) { - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/regression/local_function_narrow_return_repair_test.go b/compiler/check/tests/regression/local_function_narrow_return_repair_test.go new file mode 100644 index 00000000..715b2bf2 --- /dev/null +++ b/compiler/check/tests/regression/local_function_narrow_return_repair_test.go @@ -0,0 +1,170 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/typ" +) + +func TestLocalFunctionNarrowReturnRepairsPreflowNeverSummary(t *testing.T) { + source := ` +local function f(blocks) + local tool_use_block = nil + for _, block in ipairs(blocks) do + if block.type == "tool_use" and block.name == "structured_output" then + tool_use_block = block + break + end + end + if not tool_use_block then + return { success = false, error = "missing" } + end + return { success = true, result = { data = tool_use_block.input } } +end + +return { f = f } +` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected clean check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } + + sess := result.Session + if sess == nil || sess.RootResult == nil || sess.RootResult.Graph == nil { + t.Fatal("missing root result") + } + + var sym cfg.SymbolID + sess.RootResult.Graph.EachAssign(func(_ cfg.Point, info *cfg.AssignInfo) { + if info == nil || !info.IsLocal { + return + } + info.EachTargetSource(func(_ int, target cfg.AssignTarget, source ast.Expr) { + if target.Kind == cfg.TargetIdent && target.Name == "f" { + if _, ok := source.(*ast.FunctionExpr); ok { + sym = target.Symbol + } + } + }) + }) + if sym == 0 { + t.Fatal("missing local function symbol for f") + } + + parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) + parent := sess.Store.Parents()[parentHash] + snap := sess.Store.GetInterprocFactsSnapshot(sess.RootResult.Graph, parent) + + if got := snap.ReturnSummaries[sym]; len(got) != 1 || containsNever(got[0]) { + t.Fatalf("summary contains never artifact: %v", got) + } + if got := snap.NarrowReturns[sym]; len(got) != 1 || containsNever(got[0]) { + t.Fatalf("narrow contains never artifact: %v", got) + } + if got := snap.FuncTypes[sym]; got == nil || containsNever(got) { + t.Fatalf("function fact contains never artifact: %v", got) + } + + mod := testutil.CheckAndExport(source, "mod", testutil.WithStdlib()) + if mod.HasError() { + t.Fatalf("expected clean export check, got: %v", mod.Errors) + } + wantExport := typ.NewRecord(). + Field("f", typ.Func(). + OptParam("blocks", typ.Any). + Returns( + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().Field("data", typ.Unknown).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ), + ). + Build()). + Build() + if mod.Manifest == nil || !typ.TypeEquals(mod.Manifest.Export, wantExport) { + t.Fatalf("export = %v, want %v", mod.Manifest.Export, wantExport) + } +} + +func containsNever(t typ.Type) bool { + if t == nil { + return false + } + if typ.IsNever(t) { + return true + } + return typ.Visit(t, typ.Visitor[bool]{ + Optional: func(o *typ.Optional) bool { + return containsNever(o.Inner) + }, + Union: func(u *typ.Union) bool { + for _, m := range u.Members { + if containsNever(m) { + return true + } + } + return false + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := range in.Members { + if containsNever(m) { + return true + } + } + return false + }, + Array: func(a *typ.Array) bool { + return containsNever(a.Element) + }, + Map: func(m *typ.Map) bool { + return containsNever(m.Key) || containsNever(m.Value) + }, + Tuple: func(tup *typ.Tuple) bool { + for _, e := range tup.Elements { + if containsNever(e) { + return true + } + } + return false + }, + Record: func(r *typ.Record) bool { + for _, f := range r.Fields { + if containsNever(f.Type) { + return true + } + } + if r.HasMapComponent() { + return containsNever(r.MapKey) || containsNever(r.MapValue) + } + return false + }, + Function: func(fn *typ.Function) bool { + for _, p := range fn.Params { + if containsNever(p.Type) { + return true + } + } + if fn.Variadic != nil && containsNever(fn.Variadic) { + return true + } + for _, ret := range fn.Returns { + if containsNever(ret) { + return true + } + } + return false + }, + Default: func(typ.Type) bool { + return false + }, + }) +} diff --git a/compiler/check/tests/regression/metatable_shared_self_test.go b/compiler/check/tests/regression/metatable_shared_self_test.go new file mode 100644 index 00000000..45ae248b --- /dev/null +++ b/compiler/check/tests/regression/metatable_shared_self_test.go @@ -0,0 +1,147 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +// Shared-self law: setmetatable-backed constructors produce instance values, +// not class-table values, and methods returning self preserve that instance type. +func TestMetatableSharedSelf_LocalConstructorAndMethods(t *testing.T) { + source := ` + type Counter = { + value: number, + inc: (self: Counter) -> Counter, + get: (self: Counter) -> number, + } + + local Counter = {} + Counter.__index = Counter + + function Counter.new(): Counter + local self: Counter = { + value = 0, + inc = Counter.inc, + get = Counter.get, + } + setmetatable(self, Counter) + return self + end + + function Counter:inc(): Counter + self.value = self.value + 1 + return self + end + + function Counter:get(): number + return self.value + end + + local c: Counter = Counter.new() + local next_counter: Counter = c:inc() + local value: number = next_counter:get() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +// Cross-module law: exported constructors must preserve the instance/self contract +// so methods returning self remain callable downstream. +func TestMetatableSharedSelf_CrossModuleConstructorExport(t *testing.T) { + builderSource := ` + type Builder = { + _name: string, + rename: (self: Builder, name: string) -> Builder, + name: (self: Builder) -> string, + } + + local Builder = {} + Builder.__index = Builder + + function Builder.new(name: string): Builder + local self: Builder = { + _name = name, + rename = Builder.rename, + name = Builder.name, + } + setmetatable(self, Builder) + return self + end + + function Builder:rename(name: string): Builder + self._name = name + return self + end + + function Builder:name(): string + return self._name + end + + local M = {} + M.new = Builder.new + return M + ` + + builderModule := testutil.CheckAndExport(builderSource, "builder", testutil.WithStdlib()) + if builderModule.HasError() { + t.Fatalf("builder module should export cleanly, got: %v", testutil.ErrorMessages(builderModule.Errors)) + } + + consumerSource := ` + local builder = require("builder") + + local b = builder.new("first") + local renamed = b:rename("second") + local name: string = renamed:name() + ` + + result := testutil.Check( + consumerSource, + testutil.WithStdlib(), + testutil.WithModule("builder", builderModule), + ) + if result.HasError() { + t.Fatalf("expected downstream metatable-backed methods to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +// Receiver law: once an instance is known to be Session, metatable-backed +// methods must see the instance fields promised by that type. +func TestMetatableSharedSelf_MethodSeesInstanceFields(t *testing.T) { + source := ` + type Session = { + session_id: string, + user_id: string, + describe: (self: Session) -> string, + } + + local Session = {} + Session.__index = Session + + function Session:describe(): string + return self.session_id .. ":" .. self.user_id + end + + function Session.new(session_id: string, user_id: string): Session + local self: Session = { + session_id = session_id, + user_id = user_id, + describe = Session.describe, + } + setmetatable(self, Session) + return self + end + + local s: Session = Session.new("s1", "u1") + local label: string = s:describe() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/module_error_record_union_coalesce_test.go b/compiler/check/tests/regression/module_error_record_union_coalesce_test.go index ec1bf8dd..508b966e 100644 --- a/compiler/check/tests/regression/module_error_record_union_coalesce_test.go +++ b/compiler/check/tests/regression/module_error_record_union_coalesce_test.go @@ -13,7 +13,7 @@ import ( // error record members must be coalesced before export/import so assert-based // nil checks can correlate sibling returns without spurious field-missing errors. func TestRegression_ModuleErrorReturnUnionCoalescesAcrossBoundary(t *testing.T) { - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/regression/module_function_alias_return_test.go b/compiler/check/tests/regression/module_function_alias_return_test.go new file mode 100644 index 00000000..e6dbdb29 --- /dev/null +++ b/compiler/check/tests/regression/module_function_alias_return_test.go @@ -0,0 +1,68 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/io" +) + +// Regression guard: imported function type aliases must stay callable even +// after local aliasing through a guarded map lookup, and their returned union +// should remain precise enough for discriminant narrowing. +func TestModuleFunctionAliasReturnPreservedAcrossGuardedMapLookup(t *testing.T) { + producerModule := testutil.CheckAndExport(` +type Payload = {id: string} +type Outcome = {ok: true, value: Payload} | {ok: false, error: {message: string}} +type Handler = (string) -> Outcome + +local M = {} +M.Payload = Payload +M.Outcome = Outcome +M.Handler = Handler + +return M +`, "producer", testutil.WithStdlib()) + if producerModule.HasError() { + t.Fatalf("unexpected producer errors: %v", testutil.ErrorMessages(producerModule.Errors)) + } + + encoded, err := io.EncodeManifest(producerModule.Manifest) + if err != nil { + t.Fatalf("EncodeManifest failed: %v", err) + } + decoded, err := io.DecodeManifest(encoded) + if err != nil { + t.Fatalf("DecodeManifest failed: %v", err) + } + + result := testutil.Check(` +local producer = require("producer") + +type Handler = producer.Handler + +local handlers: {[string]: Handler} = { + ["a"] = function(s: string) + return {ok = true, value = {id = s}} + end, +} + +local h = handlers["a"] +if not h then + return nil +end + +local out = h("x") +if out.ok then + local id: string = out.value.id +else + local msg: string = out.error.message +end +`, + testutil.WithStdlib(), + testutil.WithManifest("producer", decoded), + ) + if result.HasError() { + t.Fatalf("expected imported function alias call to preserve return union, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/module_map_alias_return_test.go b/compiler/check/tests/regression/module_map_alias_return_test.go new file mode 100644 index 00000000..1c917094 --- /dev/null +++ b/compiler/check/tests/regression/module_map_alias_return_test.go @@ -0,0 +1,42 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +// Regression guard: map-shaped type aliases exported from a module must survive +// import on function returns and through `or` fallback. +func TestModuleMapAliasReturnPreservedAcrossImport(t *testing.T) { + contextModule := testutil.CheckAndExport(` +type Context = {[string]: any} + +local M = {} +M.Context = Context + +function M.empty(): Context + return {} +end + +return M +`, "context", testutil.WithStdlib()) + if contextModule.HasError() { + t.Fatalf("unexpected context module errors: %v", testutil.ErrorMessages(contextModule.Errors)) + } + + result := testutil.Check(` +local context = require("context") + +function with_default(initial: context.Context?): context.Context + local ctx = initial or context.empty() + return ctx +end +`, + testutil.WithStdlib(), + testutil.WithModule("context", contextModule), + ) + if result.HasError() { + t.Fatalf("expected imported map alias return to survive fallback, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/nested_field_type_guard_arithmetic_test.go b/compiler/check/tests/regression/nested_field_type_guard_arithmetic_test.go new file mode 100644 index 00000000..070c989a --- /dev/null +++ b/compiler/check/tests/regression/nested_field_type_guard_arithmetic_test.go @@ -0,0 +1,119 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/contract" + "github.com/wippyai/go-lua/types/io" + "github.com/wippyai/go-lua/types/typ" +) + +// Regression guard: nested field type() checks must make the guarded field +// usable in arithmetic, not collapse it to never. +func TestNestedFieldTypeGuardPreservesArithmetic(t *testing.T) { + result := testutil.Check(` +type PayloadCarrier = { + data: fun(self: PayloadCarrier): any, +} + +local function bump(carrier: PayloadCarrier?) + local data = carrier and carrier:data() or nil + if type(data) ~= "table" or type(data.amount) ~= "number" then + return nil + end + + local incremented = data.amount + 1 + local exact: number = data.amount + return incremented, exact +end + +return bump +`, testutil.WithStdlib()) + + if result.HasError() { + t.Fatalf("expected nested field type guard to preserve arithmetic, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestNestedFieldTypeGuardPreservesArithmeticInTemporalLoopShape(t *testing.T) { + messageType := typ.NewInterface("process.Message", []typ.Method{ + {Name: "from", Type: typ.Func().Param("self", typ.Self).Returns(typ.String).Build()}, + {Name: "payload", Type: typ.Func().Param("self", typ.Self).Returns(typ.Any).Build()}, + }) + + chManifest := testutil.ChannelManifest() + channelGen, _ := chManifest.LookupType("Channel") + channelGeneric := channelGen.(*typ.Generic) + + messageChannelType := typ.Instantiate(channelGeneric, messageType) + rawChannelType := typ.Instantiate(channelGeneric, typ.Any) + + listenSpec := contract.NewSpec().WithReturnCase( + constraint.FromConjunction(constraint.NewConjunction(constraint.FieldEquals{ + Target: constraint.ParamPath(1), + Field: "message", + Value: typ.True, + })), + messageChannelType, + ) + + processModule := typ.NewInterface("process", []typ.Method{ + {Name: "listen", Type: typ.Func(). + Param("topic", typ.String). + OptParam("options", typ.Any). + Returns(rawChannelType, typ.NewOptional(typ.LuaError)). + Spec(listenSpec). + Build()}, + {Name: "send", Type: typ.Func(). + Param("pid", typ.String). + Param("topic", typ.String). + Variadic(typ.Any). + Returns(typ.Boolean, typ.NewOptional(typ.String)). + Build()}, + }) + + processManifest := io.NewManifest("process") + processManifest.SetExport(processModule) + processManifest.DefineType("Message", messageType) + + result := testutil.Check(` +local counter = 0 +local done = false + +coroutine.spawn(function() + local ch = process.listen("increment", {message = true}) + while not done do + local msg, ok = ch:receive() + if not ok then + break + end + + local p = msg:payload() + local data = p and p:data() or nil + local reply_to = msg:from() + + if type(data) ~= "table" or type(data.amount) ~= "number" then + process.send(reply_to, "nak", "amount must be a number") + else + process.send(reply_to, "ack") + local amount_sanity = data.amount + 1 + counter = counter + data.amount + counter = amount_sanity - 1 + end + end +end) +`, + testutil.WithStdlib(), + testutil.WithManifest("channel", chManifest), + testutil.WithManifest("process", processManifest), + ) + + if result.HasError() { + for _, d := range result.Diagnostics { + t.Logf("diagnostic at %d:%d: %s", d.Position.Line, d.Position.Column, d.Message) + } + t.Fatalf("expected temporal loop shape to preserve nested-field arithmetic, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/optional_discriminant_variant_binding_test.go b/compiler/check/tests/regression/optional_discriminant_variant_binding_test.go new file mode 100644 index 00000000..9011f2e9 --- /dev/null +++ b/compiler/check/tests/regression/optional_discriminant_variant_binding_test.go @@ -0,0 +1,44 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +// Regression guard: after a local optional union is proven non-nil, a later +// discriminant check must still narrow the whole value to the matching variant. +func TestOptionalDiscriminantNarrowingSupportsVariantBinding(t *testing.T) { + result := testutil.Check(` +type Allow = {kind: "allow", reason: string} +type Deny = {kind: "deny", reason: string} +type Defer = {kind: "defer", queue: string} +type Decision = Allow | Deny | Defer +type Outcome = {ok: true, value: Decision?} | {ok: false, error: string} + +local outcome: Outcome = { + ok = true, + value = { + kind = "defer", + queue = "review", + }, +} + +if not outcome.ok then + return +end + +local decision = outcome.value +if not decision then + return +end + +if decision.kind == "defer" then + local deferred: Defer = decision + local queue: string = decision.queue + end +`, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected optional discriminant binding to narrow variant, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/reassigned_field_call_assignment_test.go b/compiler/check/tests/regression/reassigned_field_call_assignment_test.go new file mode 100644 index 00000000..8a5347ae --- /dev/null +++ b/compiler/check/tests/regression/reassigned_field_call_assignment_test.go @@ -0,0 +1,65 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +// Regression guard: when a reassigned field call synthesizes a precise +// single-value result, local assignment must preserve that result type. +func TestReassignedFieldCallAssignmentPreservesDirectResult(t *testing.T) { + result := testutil.Check(` +local M = { + dep = { + get = function() + return nil + end, + }, +} + +M.dep = { + get = function() + return { answer = "ok" } + end, +} + +local res = M.dep.get() +local answer: string = res.answer +return answer +`) + if result.HasError() { + t.Fatalf("expected direct call assignment to preserve reassigned field result, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +// Soundness guard: a field write from a non-dominating branch must not be +// treated as definitely visible after the join. +func TestReassignedFieldCallAssignmentRequiresDominatingWrite(t *testing.T) { + result := testutil.Check(` +local function run(flag: boolean) + local M = { + dep = { + get = function() + return nil + end, + }, + } + + if flag then + M.dep = { + get = function() + return { answer = "ok" } + end, + } + end + + local res = M.dep.get() + local answer: string = res.answer + return answer +end +`) + if !result.HasError() { + t.Fatalf("expected non-dominating branch write to remain nilable after join") + } +} diff --git a/compiler/check/tests/regression/recursive_alias_method_chain_test.go b/compiler/check/tests/regression/recursive_alias_method_chain_test.go new file mode 100644 index 00000000..3384e72f --- /dev/null +++ b/compiler/check/tests/regression/recursive_alias_method_chain_test.go @@ -0,0 +1,34 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +// Recursive named builder aliases must preserve their receiver type through +// chained method returns. +func TestRecursiveAliasMethodChain(t *testing.T) { + source := ` + type Builder = { + f: (self: Builder) -> Builder, + g: (self: Builder) -> number, + } + + local b: Builder = { + f = function(self: Builder): Builder + return self + end, + g = function(self: Builder): number + return 1 + end, + } + + local n: number = b:f():g() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/recursive_alias_return_literal_test.go b/compiler/check/tests/regression/recursive_alias_return_literal_test.go new file mode 100644 index 00000000..5a896fc0 --- /dev/null +++ b/compiler/check/tests/regression/recursive_alias_return_literal_test.go @@ -0,0 +1,67 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestRecursiveAliasReturnLiteralWithSelfMethod(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + local function make(): Message + return { + _topic = "test", + topic = function(s: Message): string + return s._topic + end, + } + end + + local msg = make() + local topic: string = msg:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRecursiveAliasReturnLiteralNestedInRecordField(t *testing.T) { + source := ` + type Message = { + _topic: string, + topic: (self: Message) -> string, + } + + type MsgCh = {__tag: "msg"} + type Result = {channel: MsgCh, value: Message, ok: boolean} + + local function select_fn(msg_ch: MsgCh): Result + return { + channel = msg_ch, + value = { + _topic = "test", + topic = function(s: Message): string + return s._topic + end, + }, + ok = true, + } + end + + local msg_ch: MsgCh = {__tag = "msg"} + local result = select_fn(msg_ch) + local topic: string = result.value:topic() + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/session_plugin_soundness_test.go b/compiler/check/tests/regression/session_plugin_soundness_test.go new file mode 100644 index 00000000..4fe50071 --- /dev/null +++ b/compiler/check/tests/regression/session_plugin_soundness_test.go @@ -0,0 +1,56 @@ +package regression + +import ( + "strings" + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestSessionPlugin_UntypedSessionIDGuardStillRejectsStringAPI(t *testing.T) { + result := testutil.Check(` +type ActiveSession = { + pid: any, +} + +local active_sessions = {} :: {[string]: ActiveSession} + +local function graceful_terminate_session(session_id: string, session_info: ActiveSession, reason: string) + return +end + +local function handle_session_close(payload_data) + if not payload_data then + return + end + + local session_id = payload_data.session_id + if not session_id then + return + end + + local session_info = active_sessions[session_id] + if session_info then + graceful_terminate_session(session_id, session_info, "user_closed") + end +end + +return handle_session_close +`, testutil.WithStdlib()) + + if !result.HasError() { + t.Fatalf("expected error, got none") + } + + msgs := testutil.ErrorMessages(result.Diagnostics) + found := false + for _, msg := range msgs { + if strings.Contains(msg, "expected string, got any") { + found = true + break + } + } + if !found { + t.Fatalf("expected string/any diagnostic, got %v", msgs) + } +} diff --git a/compiler/check/tests/regression/string_unpack_literal_format_test.go b/compiler/check/tests/regression/string_unpack_literal_format_test.go new file mode 100644 index 00000000..33ae4a83 --- /dev/null +++ b/compiler/check/tests/regression/string_unpack_literal_format_test.go @@ -0,0 +1,27 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestFP_StringUnpackLiteralFormatsProduceIntegers(t *testing.T) { + source := ` + local function parse(buf: string) + local total_length: integer = string.unpack(">I4", buf, 1) + local fmt = ">I4" + local headers_length: integer = string.unpack(fmt, buf, 5) + + local payload_offset: integer = 13 + headers_length + local payload_length: integer = total_length - 12 - headers_length - 4 + local payload = buf:sub(payload_offset, payload_offset + payload_length - 1) + return payload + end + ` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors for string.unpack literal formats, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/table_field_error_return_correlation_test.go b/compiler/check/tests/regression/table_field_error_return_correlation_test.go index fef1f593..06ef99c7 100644 --- a/compiler/check/tests/regression/table_field_error_return_correlation_test.go +++ b/compiler/check/tests/regression/table_field_error_return_correlation_test.go @@ -10,7 +10,7 @@ import ( ) func TestRegression_TableFieldFunctionErrorReturnCorrelation(t *testing.T) { - isNilEffect := constraint.NewEffect( + isNilEffect := constraint.NewRefinement( []constraint.Constraint{constraint.IsNil{Path: constraint.Path{Root: "$0"}}}, nil, nil, ) diff --git a/compiler/check/tests/regression/tool_handler_registry_return_test.go b/compiler/check/tests/regression/tool_handler_registry_return_test.go new file mode 100644 index 00000000..96481db1 --- /dev/null +++ b/compiler/check/tests/regression/tool_handler_registry_return_test.go @@ -0,0 +1,468 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/io" +) + +const toolHandlerProtocolModule = ` +type AppError = { + code: string, + message: string, + retryable: boolean, +} + +type ToolCallMessage = { + tool: string, + arguments: {[string]: any}, +} + +type ToolResult = { + tool: string, + content: string, + cached: boolean, +} + +type SessionState = { + flags: {[string]: boolean}, +} + +type ToolResultResult = {ok: true, value: ToolResult} | {ok: false, error: AppError} +type ToolHandler = (SessionState, ToolCallMessage) -> ToolResultResult + +local M = {} +M.AppError = AppError +M.ToolCallMessage = ToolCallMessage +M.ToolResult = ToolResult +M.SessionState = SessionState +M.ToolResultResult = ToolResultResult +M.ToolHandler = ToolHandler + +return M +` + +const toolHandlerBuilderModule = ` +local protocol = require("protocol") + +type ToolResultResult = protocol.ToolResultResult + +local M = {} + +function M.build(): protocol.ToolHandler + return function(state: protocol.SessionState, msg: protocol.ToolCallMessage): ToolResultResult + local value = msg.arguments["value"] + if type(value) ~= "string" then + return { + ok = false, + error = { + code = "invalid", + message = "value must be string", + retryable = false, + }, + } + end + + if state.flags["flagged"] then + return { + ok = true, + value = { + tool = msg.tool, + content = "flagged:" .. value, + cached = false, + }, + } + end + + return { + ok = true, + value = { + tool = msg.tool, + content = value, + cached = false, + }, + } + end +end + +return M +` + +func exportModule(t *testing.T, name, source string, opts ...testutil.Option) *io.Manifest { + t.Helper() + + baseOpts := []testutil.Option{testutil.WithStdlib()} + baseOpts = append(baseOpts, opts...) + result := testutil.CheckAndExport(source, name, baseOpts...) + if result.HasError() { + t.Fatalf("%s export failed: %v", name, result.Errors) + } + encoded, err := io.EncodeManifest(result.Manifest) + if err != nil { + t.Fatalf("EncodeManifest(%s) failed: %v", name, err) + } + decoded, err := io.DecodeManifest(encoded) + if err != nil { + t.Fatalf("DecodeManifest(%s) failed: %v", name, err) + } + return decoded +} + +func TestToolHandlerReturnPrecisionAcrossRegistryLayers(t *testing.T) { + protocol := exportModule(t, "protocol", toolHandlerProtocolModule) + builder := exportModule(t, "builder", toolHandlerBuilderModule, testutil.WithManifest("protocol", protocol)) + + opts := []testutil.Option{ + testutil.WithStdlib(), + testutil.WithManifest("protocol", protocol), + testutil.WithManifest("builder", builder), + } + + t.Run("direct builder call", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +local handler: protocol.ToolHandler = builder.build() +local out = handler({flags = {}}, {tool = "search", arguments = {value = "x"}}) + +if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content +else + local code: string = out.error.code + local retryable: boolean = out.error.retryable +end +`, opts...) + if result.HasError() { + t.Fatalf("direct builder call lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("map lookup call", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +local handlers: {[string]: protocol.ToolHandler} = { + search = builder.build(), +} + +local handler = handlers["search"] +if not handler then + return nil +end + +local out = handler({flags = {}}, {tool = "search", arguments = {value = "x"}}) +if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content +else + local code: string = out.error.code + local retryable: boolean = out.error.retryable +end +`, opts...) + if result.HasError() { + t.Fatalf("map lookup call lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("receiver field map lookup call", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Engine = { + handlers: {[string]: protocol.ToolHandler}, + run: (self: Engine, name: string) -> (), +} + +local Engine = {} +Engine.__index = Engine + +function Engine:run(name: string) + local handler = self.handlers[name] + if not handler then + return nil + end + + local out = handler({flags = {}}, {tool = name, arguments = {value = "x"}}) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +local e: Engine = { + handlers = {search = builder.build()}, + run = Engine.run, +} + +setmetatable(e, Engine) +e:run("search") +`, opts...) + if result.HasError() { + t.Fatalf("receiver field map lookup call lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("receiver field map lookup with narrowed message", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +type Store = { + state: protocol.SessionState, +} + +type Engine = { + handlers: {[string]: protocol.ToolHandler}, + run: (self: Engine, store: Store, msg: Message) -> (), +} + +local Engine = {} +Engine.__index = Engine + +function Engine:run(store: Store, msg: Message) + if msg.kind ~= "tool_call" then + return nil + end + + local handler = self.handlers[msg.tool] + if not handler then + return nil + end + + local out = handler(store.state, msg) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +local e: Engine = { + handlers = {search = builder.build()}, + run = Engine.run, +} + +setmetatable(e, Engine) +e:run({state = {flags = {}}}, {kind = "tool_call", tool = "search", arguments = {value = "x"}}) +`, opts...) + if result.HasError() { + t.Fatalf("receiver field map lookup with narrowed message lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("split receiver field then map lookup with narrowed message", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +type Store = { + state: protocol.SessionState, +} + +type Engine = { + handlers: {[string]: protocol.ToolHandler}, + run: (self: Engine, store: Store, msg: Message) -> (), +} + +local Engine = {} +Engine.__index = Engine + +function Engine:run(store: Store, msg: Message) + if msg.kind ~= "tool_call" then + return nil + end + + local handlers = self.handlers + local handler = handlers[msg.tool] + if not handler then + return nil + end + + local out = handler(store.state, msg) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +local e: Engine = { + handlers = {search = builder.build()}, + run = Engine.run, +} + +setmetatable(e, Engine) +e:run({state = {flags = {}}}, {kind = "tool_call", tool = "search", arguments = {value = "x"}}) +`, opts...) + if result.HasError() { + t.Fatalf("split receiver field then map lookup with narrowed message lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("direct builder call with narrowed message", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +local handler: protocol.ToolHandler = builder.build() +local msg: Message = {kind = "tool_call", tool = "search", arguments = {value = "x"}} + +if msg.kind ~= "tool_call" then + return nil +end + +local out = handler({flags = {}}, msg) +if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content +else + local code: string = out.error.code + local retryable: boolean = out.error.retryable +end +`, opts...) + if result.HasError() { + t.Fatalf("direct builder call with narrowed message lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("map lookup call with narrowed message", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +local handlers: {[string]: protocol.ToolHandler} = { + search = builder.build(), +} + +local handler = handlers["search"] +local msg: Message = {kind = "tool_call", tool = "search", arguments = {value = "x"}} +if not handler or msg.kind ~= "tool_call" then + return nil +end + +local out = handler({flags = {}}, msg) +if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content +else + local code: string = out.error.code + local retryable: boolean = out.error.retryable +end +`, opts...) + if result.HasError() { + t.Fatalf("map lookup call with narrowed message lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("function param map lookup with narrowed message", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +local function run(handlers: {[string]: protocol.ToolHandler}, state: protocol.SessionState, msg: Message) + if msg.kind ~= "tool_call" then + return nil + end + + local key: string = msg.tool + local handler = handlers[msg.tool] + if not handler then + return nil + end + + local out = handler(state, msg) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +run({search = builder.build()}, {flags = {}}, {kind = "tool_call", tool = "search", arguments = {value = "x"}}) +`, opts...) + if result.HasError() { + t.Fatalf("function param map lookup with narrowed message lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("function param direct handler with narrowed message", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +local function run(handler: protocol.ToolHandler, state: protocol.SessionState, msg: Message) + if msg.kind ~= "tool_call" then + return nil + end + + local out = handler(state, msg) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +run(builder.build(), {flags = {}}, {kind = "tool_call", tool = "search", arguments = {value = "x"}}) +`, opts...) + if result.HasError() { + t.Fatalf("function param direct handler with narrowed message lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) + + t.Run("function param map lookup with string key", func(t *testing.T) { + result := testutil.Check(` +local protocol = require("protocol") +local builder = require("builder") + +local function run(handlers: {[string]: protocol.ToolHandler}, state: protocol.SessionState, key: string) + local handler = handlers[key] + if not handler then + return nil + end + + local out = handler(state, {tool = key, arguments = {value = "x"}}) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +run({search = builder.build()}, {flags = {}}, "search") +`, opts...) + if result.HasError() { + t.Fatalf("function param map lookup with string key lost precision: %v", testutil.ErrorMessages(result.Diagnostics)) + } + }) +} diff --git a/compiler/check/tests/regression/untyped_callback_or_fallback_test.go b/compiler/check/tests/regression/untyped_callback_or_fallback_test.go new file mode 100644 index 00000000..750c17ce --- /dev/null +++ b/compiler/check/tests/regression/untyped_callback_or_fallback_test.go @@ -0,0 +1,26 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestFP_UntypedCallbackFallbackDoesNotCollapseToZeroArity(t *testing.T) { + source := ` + local function run(callbacks) + callbacks = callbacks or {} + local on_content = callbacks.on_content or function() end + local on_error = callbacks.on_error or function() end + local on_done = callbacks.on_done or function() end + + on_content("x") + on_error({ message = "boom" }) + on_done({ content = "ok" }) + end + ` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors for untyped callback or-fallback, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/parse/parser.go b/compiler/parse/parser.go index 01a1a184..df39bd03 100644 --- a/compiler/parse/parser.go +++ b/compiler/parse/parser.go @@ -219,7 +219,7 @@ const yyEofCode = 1 const yyErrCode = 2 const yyInitialStackSize = 16 -//line parser.go.y:1287 +//line parser.go.y:1303 // nameWithPos holds a name with its token position type nameWithPos struct { @@ -305,13 +305,13 @@ var yyExca = [...]int16{ const yyPrivate = 57344 -const yyLast = 1808 +const yyLast = 1912 var yyAct = [...]int16{ 425, 424, 317, 205, 55, 26, 288, 277, 262, 208, 67, 109, 260, 185, 252, 202, 103, 58, 255, 257, 66, 254, 274, 289, 248, 27, 249, 43, 44, 253, - 52, 248, 255, 249, 256, 150, 151, 57, 459, 59, + 52, 248, 255, 249, 256, 150, 151, 57, 461, 59, 429, 36, 255, 415, 248, 324, 249, 432, 401, 35, 72, 53, 122, 389, 51, 49, 48, 393, 141, 54, 374, 323, 50, 99, 100, 101, 102, 248, 296, 249, @@ -320,8 +320,8 @@ var yyAct = [...]int16{ 256, 444, 159, 325, 213, 155, 212, 184, 201, 295, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, - 182, 183, 248, 248, 249, 249, 373, 483, 74, 481, - 478, 291, 372, 476, 451, 434, 474, 215, 469, 466, + 182, 183, 248, 248, 249, 249, 373, 489, 74, 486, + 483, 291, 372, 481, 451, 434, 478, 215, 471, 468, 449, 445, 388, 390, 303, 246, 371, 147, 225, 450, 227, 450, 371, 229, 25, 450, 242, 243, 450, 240, 450, 154, 450, 74, 328, 241, 53, 322, 385, 239, @@ -330,35 +330,35 @@ var yyAct = [...]int16{ 328, 244, 224, 328, 258, 280, 217, 218, 219, 220, 221, 222, 223, 153, 154, 148, 322, 189, 46, 47, 74, 214, 146, 187, 74, 105, 106, 367, 248, 188, - 249, 96, 292, 297, 196, 197, 473, 198, 156, 199, - 200, 329, 195, 251, 97, 484, 248, 112, 249, 468, + 249, 96, 292, 297, 196, 197, 476, 198, 156, 199, + 200, 329, 195, 251, 97, 490, 248, 112, 249, 470, 298, 192, 191, 190, 193, 304, 194, 248, 98, 249, - 426, 248, 386, 249, 314, 455, 396, 391, 318, 410, + 426, 248, 386, 249, 314, 456, 396, 391, 318, 410, 320, 89, 309, 310, 311, 248, 248, 249, 249, 248, - 104, 249, 380, 313, 312, 68, 486, 280, 482, 332, - 106, 442, 280, 356, 344, 250, 157, 349, 287, 319, - 318, 354, 352, 351, 350, 336, 346, 359, 161, 330, - 357, 248, 241, 249, 334, 358, 251, 204, 365, 204, - 68, 348, 68, 369, 45, 366, 293, 347, 203, 24, - 96, 286, 285, 281, 376, 377, 378, 379, 364, 280, - 384, 375, 206, 97, 387, 283, 381, 34, 43, 44, - 9, 52, 282, 209, 95, 65, 392, 98, 394, 42, - 355, 284, 21, 68, 400, 397, 86, 87, 88, 316, - 89, 315, 403, 402, 29, 308, 41, 405, 411, 247, - 28, 38, 414, 68, 238, 413, 30, 68, 230, 152, - 71, 421, 70, 422, 69, 124, 404, 428, 64, 406, - 430, 416, 417, 418, 419, 32, 60, 123, 45, 31, - 43, 44, 435, 24, 412, 144, 305, 436, 485, 307, - 463, 23, 440, 461, 443, 40, 446, 438, 37, 362, - 363, 361, 447, 437, 433, 409, 398, 453, 300, 454, - 456, 39, 76, 458, 441, 318, 73, 460, 312, 312, - 312, 312, 142, 370, 462, 467, 75, 56, 1, 465, - 158, 470, 471, 306, 472, 245, 207, 96, 186, 475, - 477, 264, 110, 211, 81, 82, 80, 79, 83, 479, - 97, 108, 33, 480, 22, 61, 8, 63, 62, 93, - 94, 95, 76, 3, 98, 301, 77, 78, 91, 92, - 90, 84, 85, 86, 87, 88, 75, 89, 4, 2, - 0, 0, 0, 0, 299, 0, 0, 96, 232, 233, + 104, 249, 487, 313, 312, 68, 329, 280, 380, 332, + 106, 442, 280, 492, 344, 250, 157, 349, 488, 319, + 318, 354, 352, 351, 350, 336, 346, 359, 356, 330, + 357, 248, 241, 249, 334, 358, 251, 204, 365, 477, + 68, 348, 204, 329, 287, 68, 293, 161, 45, 369, + 96, 203, 286, 24, 376, 377, 378, 379, 364, 280, + 384, 375, 285, 97, 387, 281, 381, 206, 43, 44, + 34, 52, 283, 9, 95, 282, 392, 98, 394, 42, + 355, 366, 21, 68, 400, 397, 86, 87, 88, 209, + 89, 284, 403, 402, 29, 316, 41, 405, 411, 315, + 28, 38, 414, 65, 308, 413, 30, 247, 68, 238, + 230, 421, 152, 422, 71, 70, 404, 428, 124, 406, + 430, 416, 417, 418, 419, 32, 69, 123, 45, 31, + 43, 44, 435, 24, 412, 68, 64, 436, 60, 347, + 144, 23, 440, 491, 443, 40, 446, 305, 37, 465, + 307, 463, 447, 362, 363, 361, 438, 453, 437, 455, + 457, 39, 76, 460, 441, 318, 73, 462, 312, 312, + 312, 312, 433, 409, 464, 469, 75, 398, 300, 467, + 142, 472, 473, 370, 474, 56, 1, 96, 158, 306, + 480, 245, 482, 207, 81, 82, 80, 79, 83, 186, + 97, 484, 264, 110, 211, 485, 108, 33, 22, 93, + 94, 95, 76, 61, 98, 8, 77, 78, 91, 92, + 90, 84, 85, 86, 87, 88, 75, 89, 63, 62, + 3, 301, 4, 2, 299, 0, 0, 96, 232, 233, 234, 235, 236, 237, 81, 82, 80, 79, 83, 0, 97, 0, 0, 0, 0, 0, 231, 0, 0, 93, 94, 95, 0, 0, 98, 0, 77, 78, 91, 92, @@ -383,7 +383,7 @@ var yyAct = [...]int16{ 91, 92, 90, 84, 85, 86, 87, 88, 189, 89, 0, 0, 302, 0, 187, 0, 0, 0, 0, 0, 188, 0, 0, 0, 270, 265, 266, 271, 267, 272, - 268, 269, 273, 195, 76, 0, 464, 0, 0, 0, + 268, 269, 273, 195, 76, 0, 466, 0, 0, 0, 0, 0, 263, 191, 190, 193, 261, 194, 75, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 81, 82, 80, 79, @@ -449,106 +449,118 @@ var yyAct = [...]int16{ 87, 88, 187, 89, 0, 0, 0, 0, 188, 0, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 278, 187, 0, - 279, 191, 190, 193, 188, 276, 275, 0, 0, 196, + 279, 191, 190, 193, 188, 276, 479, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 278, 187, 0, 279, 191, 190, 193, - 188, 276, 448, 0, 0, 196, 197, 0, 198, 0, + 188, 276, 475, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 278, - 187, 0, 279, 191, 190, 193, 188, 276, 423, 0, + 187, 0, 279, 191, 190, 193, 188, 276, 459, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 278, 187, 0, 279, 191, - 190, 193, 188, 276, 399, 0, 0, 196, 197, 0, + 190, 193, 188, 276, 454, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, - 0, 278, 187, 0, 279, 191, 190, 193, 188, 194, - 335, 0, 0, 196, 197, 0, 198, 0, 199, 200, + 0, 278, 187, 0, 279, 191, 190, 193, 188, 276, + 448, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 278, 187, 0, - 279, 191, 190, 193, 188, 276, 331, 0, 0, 196, + 279, 191, 190, 193, 188, 276, 423, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, - 0, 0, 0, 382, 187, 0, 279, 191, 190, 193, + 0, 0, 0, 278, 187, 0, 279, 191, 190, 193, + 188, 276, 399, 0, 0, 196, 197, 0, 198, 0, + 199, 200, 0, 195, 189, 0, 0, 0, 0, 278, + 187, 0, 279, 191, 190, 193, 188, 194, 335, 0, + 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, + 189, 0, 0, 0, 0, 278, 187, 0, 279, 191, + 190, 193, 188, 276, 331, 0, 0, 196, 197, 0, + 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, + 0, 278, 187, 0, 279, 191, 190, 193, 188, 276, + 275, 0, 0, 196, 197, 0, 198, 0, 199, 200, + 0, 195, 189, 0, 0, 0, 0, 382, 187, 0, + 279, 191, 190, 193, 188, 194, 0, 0, 0, 196, + 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, + 0, 0, 0, 0, 187, 0, 192, 191, 190, 193, 188, 194, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 0, - 187, 0, 192, 191, 190, 193, 188, 194, 0, 0, + 187, 0, 192, 191, 190, 193, 188, 458, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 0, 187, 0, 192, 191, - 190, 193, 188, 457, 0, 0, 0, 196, 197, 0, + 190, 193, 188, 452, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, - 0, 0, 187, 0, 192, 191, 190, 193, 188, 452, + 0, 0, 187, 0, 192, 191, 190, 193, 188, 431, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 0, 187, 0, - 192, 191, 190, 193, 188, 431, 0, 0, 0, 196, + 192, 191, 190, 193, 188, 427, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 0, 187, 0, 192, 191, 190, 193, - 188, 427, 0, 0, 0, 196, 197, 0, 198, 0, + 188, 420, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, 189, 0, 0, 0, 0, 0, - 187, 0, 192, 191, 190, 193, 188, 420, 0, 0, + 187, 0, 192, 191, 190, 193, 188, 383, 0, 0, 0, 196, 197, 0, 198, 0, 199, 200, 0, 195, - 189, 0, 0, 0, 0, 0, 187, 0, 192, 191, - 190, 193, 188, 383, 0, 0, 0, 196, 197, 0, - 198, 0, 199, 200, 0, 195, 0, 0, 0, 0, - 0, 0, 0, 0, 192, 191, 190, 193, 0, 353, - 270, 339, 340, 271, 341, 272, 343, 342, 273, 270, - 339, 340, 271, 341, 272, 343, 342, 273, 338, 0, - 0, 0, 337, 0, 0, 0, 0, 338, + 0, 0, 0, 0, 0, 0, 0, 0, 192, 191, + 190, 193, 0, 353, 270, 339, 340, 271, 341, 272, + 343, 342, 273, 270, 339, 340, 271, 341, 272, 343, + 342, 273, 338, 0, 0, 0, 337, 0, 0, 0, + 0, 338, } var yyPact = [...]int16{ -1000, -1000, 1175, 84, -1000, -1000, 1162, -1000, 137, -19, - -1000, 1162, -1000, 1162, 352, 344, 333, -1000, 340, 338, - 336, -1000, -1000, -1000, 1162, -1000, 56, 1049, -1000, -1000, + -1000, 1162, -1000, 1162, 364, 362, 361, -1000, 352, 341, + 340, -1000, -1000, -1000, 1162, -1000, 56, 1049, -1000, -1000, -1000, -1000, -1000, -1000, -19, -1000, -1000, 1162, 1162, 1162, - 1162, 221, -1000, -1000, 536, -1000, 1162, 270, 1162, 657, - -1000, 545, 1121, -1000, -1000, 433, -1000, 1010, 382, 971, - 141, 133, 221, -38, -1000, 335, 132, -1000, 22, -1000, - 157, 19, 918, 248, 1162, 1162, 1162, 1162, 1162, 1162, + 1162, 221, -1000, -1000, 536, -1000, 1162, 274, 1162, 657, + -1000, 545, 1121, -1000, -1000, 441, -1000, 1010, 387, 971, + 141, 133, 221, -38, -1000, 338, 132, -1000, 22, -1000, + 157, 19, 918, 267, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, 1162, - 1162, 1162, 1162, 1162, 1162, 1162, 1548, 1548, -1000, 192, - 192, 192, 192, -1000, 268, 283, 299, -1000, 24, -1000, + 1162, 1162, 1162, 1162, 1162, 1162, 1652, 1652, -1000, 192, + 192, 192, 192, -1000, 271, 288, 315, -1000, 24, -1000, 140, 1162, 1049, -1000, -1000, -1000, -1000, -1000, -1000, -1000, -1000, -1000, 56, -1000, -19, 478, -1000, -1000, -1000, -1000, -1000, -1000, -1000, 292, 292, 292, 292, 292, 292, 292, - -1000, 142, -1000, -1000, 1162, -1000, 1162, 1162, 334, -1000, - 482, 330, 221, 1162, 329, 1548, 1548, 120, 73, 325, + -1000, 142, -1000, -1000, 1162, -1000, 1162, 1162, 336, -1000, + 482, 335, 221, 1162, 334, 1652, 1652, 120, 73, 333, -1000, -1000, 1049, 1088, 1191, 1219, 1219, 1219, 1219, 1219, 1219, 1314, 291, 291, 192, 192, 192, 192, 1294, 1247, 1275, 1314, 1314, 192, -1000, 251, -46, -1000, -1000, -1000, - -1000, -1000, -40, 718, 1366, 274, 295, 288, 307, 273, - 272, -1000, 238, -50, 13, 59, 266, 27, -1000, -5, + -1000, -1000, -40, 718, 1600, 286, 298, 295, 317, 283, + 273, -1000, 264, -50, 13, 59, 266, 27, -1000, -5, -1000, 576, -1000, -1000, 1162, 428, -1000, -1000, -1000, -1000, - -1000, -1000, -1000, -1000, -1000, 419, 1049, -1000, 660, 138, + -1000, -1000, -1000, -1000, -1000, 439, 1049, -1000, 660, 138, -1000, -1000, -1000, -1000, -1000, -1000, -1000, -1000, -1000, -1000, - 56, -1000, 158, 158, 1548, 397, 321, -1000, 1548, 1548, - 1548, -1000, -60, 197, -1000, 317, 315, 1548, 241, 1548, - 134, -1000, -1000, 14, 20, 295, 288, 307, 273, 272, - -1000, -1000, -1000, -1000, 121, 177, 1496, -1000, 1548, 16, - 158, 1470, 1754, 897, 286, 1162, 1548, -50, -1000, 1730, - 1548, 309, 233, -50, -1000, 299, 1548, -1000, 1049, 102, - -1000, 412, 1162, -1000, 158, -1000, -1000, 271, -1000, 178, - 178, 162, -1000, -1000, 7, 264, -1000, 74, 158, -1000, - -16, -1000, 1763, 1548, 1548, 1548, 1548, 218, 1522, 1704, - 118, 198, 158, 1548, 92, -20, 95, -1000, -12, -1000, - -1000, -1000, -1000, -1000, 209, 1548, 9, 1548, 839, 206, - -1000, 417, 10, 1444, 158, -25, -50, -1000, -1000, 158, - 1162, -1000, -1000, 1162, 610, 416, 210, 1548, -1000, 354, - -1000, 1548, -1000, -1000, -30, -1000, -36, -36, -36, -36, - 1678, -1000, 1548, 1418, 158, 196, 1652, 158, -33, 1626, - -1000, -1000, -29, -1000, 158, -1000, -1000, 415, -1000, 177, - 63, 1548, -1000, 414, 1049, 408, 800, -1000, 1162, -1000, - 231, 18, -1000, 91, 158, 1548, -60, -60, -60, -60, - 1392, 158, 158, 177, 90, 62, 1600, 1418, 205, 1574, - 158, 1366, -35, -1000, 1548, 158, 404, -1000, -1000, -1000, - 401, 750, -50, 89, 1548, -1000, 191, 88, 177, -1000, - 1548, 1548, 1392, 176, 86, -1000, 158, 1366, 83, 1548, - 80, -1000, -1000, -1000, -1000, -1000, -50, 158, -1000, -1000, - 158, 158, 79, -1000, 228, 77, -1000, 187, -1000, 399, - -1000, 226, -1000, -1000, -1000, -1000, -1000, + 56, -1000, 158, 158, 1652, 408, 330, -1000, 1652, 1652, + 1652, -1000, -60, 197, -1000, 325, 321, 1652, 241, 1652, + 134, -1000, -1000, 14, 20, 298, 295, 317, 283, 273, + -1000, -1000, -1000, -1000, 121, 177, 1574, -1000, 1652, 16, + 158, 1548, 1858, 897, 378, 1162, 1652, -50, -1000, 1834, + 1652, 309, 248, -50, -1000, 315, 1652, -1000, 1049, 102, + -1000, 416, 1162, -1000, 158, -1000, -1000, 307, -1000, 178, + 178, 162, -1000, -1000, 7, 270, -1000, 74, 158, -1000, + -16, -1000, 1867, 1652, 1652, 1652, 1652, 224, 1626, 1808, + 118, 198, 158, 1652, 92, -20, 95, -1000, -12, -1000, + -1000, -1000, -1000, -1000, 209, 1652, 9, 1652, 839, 206, + -1000, 438, 10, 1522, 158, -25, -50, -1000, -1000, 158, + 1162, -1000, -1000, 1162, 610, 434, 210, 1652, -1000, 354, + -1000, 1652, -1000, -1000, -30, -1000, -36, -36, -36, -36, + 1782, -1000, 1652, 1496, 158, 196, 1756, 158, -33, 1730, + -1000, -1000, -29, -1000, 158, -1000, -1000, 433, -1000, 177, + 63, 1652, -1000, 419, 1049, 417, 800, -1000, 1162, -1000, + 231, 18, -1000, 91, 158, 1652, -60, -60, -60, -60, + 1470, 158, 158, 177, 90, 62, 1704, 1444, 205, 1678, + 158, 1418, -35, -1000, 1652, 158, 412, -1000, -1000, -1000, + 410, 750, -50, 89, 1652, -1000, 191, 88, 177, -1000, + 1652, 1652, 1392, 176, 259, 86, -1000, 158, 1366, 177, + 83, 1652, 80, -1000, -1000, -1000, -1000, -1000, -50, 158, + -1000, -1000, 158, 158, 79, 222, -1000, -1000, 238, 177, + 77, -1000, 187, -1000, 404, -1000, 233, -1000, -1000, -1000, + -1000, -1000, -1000, } var yyPgo = [...]int16{ - 0, 447, 499, 4, 498, 485, 483, 478, 477, 476, - 349, 475, 5, 25, 49, 337, 411, 474, 62, 472, - 16, 15, 41, 471, 11, 463, 462, 461, 0, 13, - 458, 2, 1, 6, 215, 456, 9, 12, 8, 3, - 10, 455, 453, 450, 7, 22, 21, 14, 443, + 0, 455, 503, 4, 502, 501, 500, 499, 498, 485, + 349, 483, 5, 25, 49, 340, 411, 478, 62, 477, + 16, 15, 41, 476, 11, 474, 473, 472, 0, 13, + 469, 2, 1, 6, 215, 463, 9, 12, 8, 3, + 10, 461, 459, 458, 7, 22, 21, 14, 453, } var yyR1 = [...]int8{ @@ -572,11 +584,11 @@ var yyR1 = [...]int8{ 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, - 31, 31, 32, 32, 48, 48, 44, 44, 45, 45, - 45, 45, 37, 37, 37, 38, 38, 38, 38, 38, - 38, 38, 38, 47, 47, 46, 46, 46, 34, 35, - 35, 36, 36, 39, 39, 40, 40, 43, 43, 43, - 41, 41, 42, 42, + 30, 30, 30, 30, 31, 31, 32, 32, 48, 48, + 44, 44, 45, 45, 45, 45, 37, 37, 37, 38, + 38, 38, 38, 38, 38, 38, 38, 47, 47, 46, + 46, 46, 34, 35, 35, 36, 36, 39, 39, 40, + 40, 43, 43, 43, 41, 41, 42, 42, } var yyR2 = [...]int8{ @@ -598,13 +610,13 @@ var yyR2 = [...]int8{ 0, 2, 3, 6, 1, 3, 3, 7, 1, 2, 2, 1, 1, 1, 1, 1, 1, 3, 4, 3, 7, 3, 7, 6, 5, 6, 4, 5, 7, 6, - 9, 8, 6, 8, 5, 7, 4, 3, 4, 3, - 3, 4, 8, 4, 2, 2, 4, 4, 4, 4, - 1, 3, 3, 3, 1, 1, 3, 1, 1, 3, - 2, 4, 1, 3, 2, 3, 4, 3, 4, 3, - 4, 3, 4, 1, 2, 2, 4, 5, 3, 1, - 3, 1, 3, 1, 3, 1, 3, 0, 2, 3, - 0, 2, 5, 6, + 9, 8, 8, 7, 6, 8, 5, 7, 7, 6, + 4, 3, 4, 3, 3, 4, 8, 4, 2, 2, + 4, 4, 4, 4, 1, 3, 3, 3, 1, 1, + 3, 1, 1, 3, 2, 4, 1, 3, 2, 3, + 4, 3, 4, 3, 4, 3, 4, 1, 2, 2, + 4, 5, 3, 1, 3, 1, 3, 1, 3, 1, + 3, 0, 2, 3, 0, 2, 5, 6, } var yyChk = [...]int16{ @@ -653,10 +665,11 @@ var yyChk = [...]int16{ 49, -28, -28, 50, -32, -28, 54, 49, -28, 73, -28, 49, 76, 9, 72, -28, -3, 9, 9, 21, -3, -13, 50, -39, 73, 50, -28, -32, 50, 50, - 72, 72, 49, -28, -32, 50, -28, 49, -32, 73, - -31, 9, -3, 9, 6, -33, 50, -28, 48, 50, - -28, -28, -32, 50, 50, -32, 50, -28, 50, -3, - -33, 50, 50, 50, 48, 9, 50, + 72, 72, 49, -28, 50, -32, 50, -28, 49, 50, + -32, 73, -31, 9, -3, 9, 6, -33, 50, -28, + 48, 50, -28, -28, -32, 50, 50, 50, 50, 50, + -32, 50, -28, 50, -3, -33, 50, 50, 50, 50, + 48, 9, 50, } var yyDef = [...]int16{ @@ -666,8 +679,8 @@ var yyDef = [...]int16{ 59, 60, 61, 62, 63, 64, 65, 0, 0, 0, 0, 0, 95, 94, 0, 44, 0, 0, 0, 0, 100, 0, 0, 110, 111, 0, 7, 0, 0, 0, - 53, 0, 0, 32, 40, 0, 21, 233, 235, 23, - 0, 237, 0, 97, 0, 0, 0, 0, 0, 0, + 53, 0, 0, 32, 40, 0, 21, 237, 239, 23, + 0, 241, 0, 97, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 93, 87, 88, 89, 90, 112, 0, 0, 0, 122, 0, 124, @@ -675,40 +688,41 @@ var yyDef = [...]int16{ 137, 138, 8, -2, 0, 0, 46, 47, 48, 49, 50, 51, 52, 0, 0, 0, 0, 0, 0, 0, 108, 0, 10, 4, 0, 4, 0, 0, 0, 18, - 0, 0, 0, 0, 0, 0, 0, 0, 240, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 244, 0, 98, 99, 56, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 91, 154, 158, 161, 162, 163, 164, 165, 166, 0, 0, 0, 0, 0, 0, 0, - 0, 92, 0, 150, 117, 119, 0, 0, 229, 231, + 0, 92, 0, 150, 117, 119, 0, 0, 233, 235, 123, 126, 148, 149, 0, 0, 45, 101, 102, 103, 104, 105, 106, 107, 109, 0, 12, 27, 0, 0, 54, 33, 34, 35, 36, 37, 38, 39, 41, 19, - 20, 234, 236, 24, 0, 0, 0, 238, 0, 0, - 0, 160, 159, 0, 223, 0, 0, 0, 0, 0, - 0, 194, 212, 166, 0, 140, 141, 143, 147, 145, - 139, 142, 144, 146, 0, 0, 0, 208, 0, 166, - 207, 0, 0, 0, 195, 0, 0, 150, 4, 0, - 0, 0, 0, 150, 228, 0, 0, 125, 127, 0, - 11, 0, 0, 4, 25, 26, 241, 0, 239, 155, - 156, 0, 224, 190, 0, 225, 167, 0, 200, 169, - 0, 171, 214, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 210, 0, 0, 187, 0, 189, 0, 140, + 20, 238, 240, 24, 0, 0, 0, 242, 0, 0, + 0, 160, 159, 0, 227, 0, 0, 0, 0, 0, + 0, 198, 216, 166, 0, 140, 141, 143, 147, 145, + 139, 142, 144, 146, 0, 0, 0, 212, 0, 166, + 211, 0, 0, 0, 199, 0, 0, 150, 4, 0, + 0, 0, 0, 150, 232, 0, 0, 125, 127, 0, + 11, 0, 0, 4, 25, 26, 245, 0, 243, 155, + 156, 0, 228, 194, 0, 229, 167, 0, 204, 169, + 0, 171, 218, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 214, 0, 0, 191, 0, 193, 0, 140, 141, 143, 145, 147, 0, 0, 0, 0, 0, 0, - 4, 0, 151, 0, 118, 120, 150, 4, 230, 232, - 0, 13, 4, 0, 0, 0, 0, 160, 199, 0, - 168, 0, 204, 205, 0, 213, 215, 217, 219, 221, - 0, 209, 0, 0, 176, 0, 0, 206, 186, 0, - 188, 191, 0, 193, 196, 197, 198, 0, 114, 152, - 207, 0, 4, 0, 128, 0, 0, 4, 0, 17, - 0, 0, 226, 0, 201, 0, 216, 218, 220, 222, - 0, 174, 211, 177, 0, 207, 0, 0, 176, 0, - 184, 0, 0, 113, 0, 121, 0, 116, 14, 4, - 0, 0, 150, 0, 0, 227, 0, 0, 173, 175, - 0, 0, 0, 174, 0, 179, 182, 0, 0, 0, - 0, 115, 28, 15, 4, 242, 150, 157, 170, 172, - 203, 202, 0, 178, 175, 0, 185, 0, 153, 0, - 243, 172, 181, 183, 192, 16, 180, + 4, 0, 151, 0, 118, 120, 150, 4, 234, 236, + 0, 13, 4, 0, 0, 0, 0, 160, 203, 0, + 168, 0, 208, 209, 0, 217, 219, 221, 223, 225, + 0, 213, 0, 0, 176, 0, 0, 210, 190, 0, + 192, 195, 0, 197, 200, 201, 202, 0, 114, 152, + 211, 0, 4, 0, 128, 0, 0, 4, 0, 17, + 0, 0, 230, 0, 205, 0, 220, 222, 224, 226, + 0, 174, 215, 177, 0, 211, 0, 0, 176, 0, + 186, 0, 0, 113, 0, 121, 0, 116, 14, 4, + 0, 0, 150, 0, 0, 231, 0, 0, 173, 175, + 0, 0, 0, 174, 177, 0, 179, 184, 0, 189, + 0, 0, 0, 115, 28, 15, 4, 246, 150, 157, + 170, 172, 207, 206, 0, 173, 178, 183, 175, 188, + 0, 187, 0, 153, 0, 247, 172, 181, 182, 185, + 196, 16, 180, } var yyTok1 = [...]int8{ @@ -2422,169 +2436,197 @@ yydefault: yyDollar = yyS[yypt-8 : yypt+1] //line parser.go.y:1033 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: yyDollar[6].typeexprlist} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: []ast.TypeExpr{}} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 182: - yyDollar = yyS[yypt-6 : yypt+1] + yyDollar = yyS[yypt-8 : yypt+1] //line parser.go.y:1037 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: []ast.TypeExpr{yyDollar[6].typeexpr}} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: yyDollar[6].typeexprlist} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 183: - yyDollar = yyS[yypt-8 : yypt+1] + yyDollar = yyS[yypt-7 : yypt+1] //line parser.go.y:1041 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: yyDollar[7].typeexprlist} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: []ast.TypeExpr{}} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 184: - yyDollar = yyS[yypt-5 : yypt+1] + yyDollar = yyS[yypt-6 : yypt+1] //line parser.go.y:1045 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: []ast.TypeExpr{yyDollar[5].typeexpr}} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: []ast.TypeExpr{yyDollar[6].typeexpr}} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 185: - yyDollar = yyS[yypt-7 : yypt+1] + yyDollar = yyS[yypt-8 : yypt+1] //line parser.go.y:1049 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: yyDollar[6].typeexprlist} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: yyDollar[7].typeexprlist} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 186: - yyDollar = yyS[yypt-4 : yypt+1] + yyDollar = yyS[yypt-5 : yypt+1] //line parser.go.y:1053 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: nil} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: []ast.TypeExpr{yyDollar[5].typeexpr}} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 187: - yyDollar = yyS[yypt-3 : yypt+1] + yyDollar = yyS[yypt-7 : yypt+1] //line parser.go.y:1057 { - yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: nil} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: yyDollar[6].typeexprlist} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 188: - yyDollar = yyS[yypt-4 : yypt+1] + yyDollar = yyS[yypt-7 : yypt+1] //line parser.go.y:1061 { - yyVAL.typeexpr = &ast.RecordTypeExpr{Fields: yyDollar[3].recordfields} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: []ast.TypeExpr{}} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 189: - yyDollar = yyS[yypt-3 : yypt+1] + yyDollar = yyS[yypt-6 : yypt+1] //line parser.go.y:1065 { - yyVAL.typeexpr = &ast.RecordTypeExpr{Fields: nil} + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: []ast.TypeExpr{}} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } case 190: - yyDollar = yyS[yypt-3 : yypt+1] + yyDollar = yyS[yypt-4 : yypt+1] //line parser.go.y:1069 + { + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: yyDollar[3].funcparams, Returns: nil} + yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) + } + case 191: + yyDollar = yyS[yypt-3 : yypt+1] +//line parser.go.y:1073 + { + yyVAL.typeexpr = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: nil} + yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) + } + case 192: + yyDollar = yyS[yypt-4 : yypt+1] +//line parser.go.y:1077 + { + yyVAL.typeexpr = &ast.RecordTypeExpr{Fields: yyDollar[3].recordfields} + yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) + } + case 193: + yyDollar = yyS[yypt-3 : yypt+1] +//line parser.go.y:1081 + { + yyVAL.typeexpr = &ast.RecordTypeExpr{Fields: nil} + yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) + } + case 194: + yyDollar = yyS[yypt-3 : yypt+1] +//line parser.go.y:1085 { yyVAL.typeexpr = &ast.ArrayTypeExpr{Element: yyDollar[1].typeexpr} yyVAL.typeexpr.CopyPos(yyDollar[1].typeexpr) } - case 191: + case 195: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1073 +//line parser.go.y:1089 { arr := &ast.ArrayTypeExpr{Element: yyDollar[3].typeexpr, Readonly: true} arr.SetPosFromToken(yyDollar[1].token.Pos) yyVAL.typeexpr = arr } - case 192: + case 196: yyDollar = yyS[yypt-8 : yypt+1] -//line parser.go.y:1078 +//line parser.go.y:1094 { m := &ast.MapTypeExpr{Key: yyDollar[4].typeexpr, Value: yyDollar[7].typeexpr, Readonly: true} m.SetPosFromToken(yyDollar[1].token.Pos) yyVAL.typeexpr = m } - case 193: + case 197: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1083 +//line parser.go.y:1099 { yyVAL.typeexpr = &ast.RecordTypeExpr{Fields: yyDollar[3].recordfields, Readonly: true} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } - case 194: + case 198: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1087 +//line parser.go.y:1103 { yyVAL.typeexpr = &ast.RecordTypeExpr{Fields: nil} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } - case 195: + case 199: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1091 +//line parser.go.y:1107 { yyVAL.typeexpr = &ast.AssertsTypeExpr{ParamName: yyDollar[2].token.Str, NarrowTo: nil} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } - case 196: + case 200: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1095 +//line parser.go.y:1111 { yyVAL.typeexpr = &ast.AssertsTypeExpr{ParamName: yyDollar[2].token.Str, NarrowTo: yyDollar[4].typeexpr} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } - case 197: + case 201: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1099 +//line parser.go.y:1115 { yyVAL.typeexpr = &ast.TypeOfExpr{Expr: yyDollar[3].expr} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } - case 198: + case 202: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1103 +//line parser.go.y:1119 { yyVAL.typeexpr = &ast.KeyOfExpr{Inner: yyDollar[3].typeexpr} yyVAL.typeexpr.SetPosFromToken(yyDollar[1].token.Pos) } - case 199: + case 203: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1107 +//line parser.go.y:1123 { yyVAL.typeexpr = &ast.IndexAccessExpr{Object: yyDollar[1].typeexpr, Index: yyDollar[3].typeexpr} yyVAL.typeexpr.CopyPos(yyDollar[1].typeexpr) } - case 200: + case 204: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1113 +//line parser.go.y:1129 { yyVAL.typeexprlist = []ast.TypeExpr{yyDollar[1].typeexpr} } - case 201: + case 205: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1116 +//line parser.go.y:1132 { yyVAL.typeexprlist = append(yyDollar[1].typeexprlist, yyDollar[3].typeexpr) } - case 202: + case 206: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1121 +//line parser.go.y:1137 { yyVAL.typeexprlist = []ast.TypeExpr{yyDollar[1].typeexpr, yyDollar[3].typeexpr} } - case 203: + case 207: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1124 +//line parser.go.y:1140 { yyVAL.typeexprlist = append(yyDollar[1].typeexprlist, yyDollar[3].typeexpr) } - case 204: + case 208: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1129 +//line parser.go.y:1145 { } - case 205: + case 209: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1131 +//line parser.go.y:1147 { yylex.(*Lexer).PendingGT = &ast.Token{ Type: '>', @@ -2592,234 +2634,234 @@ yydefault: Pos: ast.Position{Source: yyDollar[1].token.Pos.Source, Line: yyDollar[1].token.Pos.Line, Column: yyDollar[1].token.Pos.Column + 1}, } } - case 206: + case 210: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1140 +//line parser.go.y:1156 { yyVAL.funcparam = ast.FunctionParamExpr{Name: yyDollar[1].token.Str, Type: yyDollar[3].typeexpr} } - case 207: + case 211: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1143 +//line parser.go.y:1159 { yyVAL.funcparam = ast.FunctionParamExpr{Name: "", Type: yyDollar[1].typeexpr} } - case 208: + case 212: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1148 +//line parser.go.y:1164 { yyVAL.funcparams = []ast.FunctionParamExpr{yyDollar[1].funcparam} } - case 209: + case 213: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1151 +//line parser.go.y:1167 { yyVAL.funcparams = append(yyDollar[1].funcparams, yyDollar[3].funcparam) } - case 210: + case 214: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1154 +//line parser.go.y:1170 { yyVAL.funcparams = []ast.FunctionParamExpr{{Name: "...", Type: yyDollar[2].typeexpr}} } - case 211: + case 215: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1157 +//line parser.go.y:1173 { yyVAL.funcparams = append(yyDollar[1].funcparams, ast.FunctionParamExpr{Name: "...", Type: yyDollar[4].typeexpr}) } - case 212: + case 216: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1162 +//line parser.go.y:1178 { yyVAL.recordfields = []ast.RecordFieldExpr{yyDollar[1].recordfield} } - case 213: + case 217: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1165 +//line parser.go.y:1181 { yyVAL.recordfields = append(yyDollar[1].recordfields, yyDollar[3].recordfield) } - case 214: + case 218: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1168 +//line parser.go.y:1184 { yyVAL.recordfields = yyDollar[1].recordfields } - case 215: + case 219: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1173 +//line parser.go.y:1189 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].token.Str, Type: yyDollar[3].typeexpr, Optional: false} } - case 216: + case 220: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1176 +//line parser.go.y:1192 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].token.Str, Type: yyDollar[3].typeexpr, Optional: false, Annotations: yyDollar[4].annotations} } - case 217: + case 221: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1179 +//line parser.go.y:1195 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].token.Str, Type: yyDollar[3].typeexpr, Optional: true} } - case 218: + case 222: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1182 +//line parser.go.y:1198 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].token.Str, Type: yyDollar[3].typeexpr, Optional: true, Annotations: yyDollar[4].annotations} } - case 219: + case 223: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1185 +//line parser.go.y:1201 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].fieldname, Type: yyDollar[3].typeexpr, Optional: false} } - case 220: + case 224: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1188 +//line parser.go.y:1204 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].fieldname, Type: yyDollar[3].typeexpr, Optional: false, Annotations: yyDollar[4].annotations} } - case 221: + case 225: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1191 +//line parser.go.y:1207 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].fieldname, Type: yyDollar[3].typeexpr, Optional: true} } - case 222: + case 226: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1194 +//line parser.go.y:1210 { yyVAL.recordfield = ast.RecordFieldExpr{Name: yyDollar[1].fieldname, Type: yyDollar[3].typeexpr, Optional: true, Annotations: yyDollar[4].annotations} } - case 223: + case 227: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1199 +//line parser.go.y:1215 { yyVAL.annotations = []ast.AnnotationExpr{yyDollar[1].annotation} } - case 224: + case 228: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1202 +//line parser.go.y:1218 { yyVAL.annotations = append(yyDollar[1].annotations, yyDollar[2].annotation) } - case 225: + case 229: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1207 +//line parser.go.y:1223 { yyVAL.annotation = ast.AnnotationExpr{Name: yyDollar[2].token.Str, Args: nil} } - case 226: + case 230: yyDollar = yyS[yypt-4 : yypt+1] -//line parser.go.y:1210 +//line parser.go.y:1226 { yyVAL.annotation = ast.AnnotationExpr{Name: yyDollar[2].token.Str, Args: nil} } - case 227: + case 231: yyDollar = yyS[yypt-5 : yypt+1] -//line parser.go.y:1213 +//line parser.go.y:1229 { yyVAL.annotation = ast.AnnotationExpr{Name: yyDollar[2].token.Str, Args: yyDollar[4].exprlist} } - case 228: + case 232: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1218 +//line parser.go.y:1234 { yyVAL.typeparams = yyDollar[2].typeparams } - case 229: + case 233: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1223 +//line parser.go.y:1239 { yyVAL.typeparams = []ast.TypeParamExpr{yyDollar[1].typeparam} } - case 230: + case 234: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1226 +//line parser.go.y:1242 { yyVAL.typeparams = append(yyDollar[1].typeparams, yyDollar[3].typeparam) } - case 231: + case 235: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1231 +//line parser.go.y:1247 { yyVAL.typeparam = ast.TypeParamExpr{Name: yyDollar[1].token.Str, Constraint: nil} } - case 232: + case 236: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1234 +//line parser.go.y:1250 { yyVAL.typeparam = ast.TypeParamExpr{Name: yyDollar[1].token.Str, Constraint: yyDollar[3].typeexpr} } - case 233: + case 237: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1239 +//line parser.go.y:1255 { yyVAL.typednames = []typedNameEntry{yyDollar[1].typedname} } - case 234: + case 238: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1242 +//line parser.go.y:1258 { yyVAL.typednames = append(yyDollar[1].typednames, yyDollar[3].typedname) } - case 235: + case 239: yyDollar = yyS[yypt-1 : yypt+1] -//line parser.go.y:1247 +//line parser.go.y:1263 { yyVAL.typedname = typedNameEntry{Name: yyDollar[1].token.Str, Pos: yyDollar[1].token.Pos, Type: nil} } - case 236: + case 240: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1250 +//line parser.go.y:1266 { yyVAL.typedname = typedNameEntry{Name: yyDollar[1].token.Str, Pos: yyDollar[1].token.Pos, Type: yyDollar[3].typeexpr} } - case 237: + case 241: yyDollar = yyS[yypt-0 : yypt+1] -//line parser.go.y:1255 +//line parser.go.y:1271 { yyVAL.typereflist = nil } - case 238: + case 242: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1258 +//line parser.go.y:1274 { yyVAL.typereflist = []*ast.TypeRefExpr{{Path: []string{yyDollar[2].token.Str}}} } - case 239: + case 243: yyDollar = yyS[yypt-3 : yypt+1] -//line parser.go.y:1261 +//line parser.go.y:1277 { yyVAL.typereflist = append(yyDollar[1].typereflist, &ast.TypeRefExpr{Path: []string{yyDollar[3].token.Str}}) } - case 240: + case 244: yyDollar = yyS[yypt-0 : yypt+1] -//line parser.go.y:1266 +//line parser.go.y:1282 { yyVAL.ifacemethods = nil } - case 241: + case 245: yyDollar = yyS[yypt-2 : yypt+1] -//line parser.go.y:1269 +//line parser.go.y:1285 { yyVAL.ifacemethods = append(yyDollar[1].ifacemethods, yyDollar[2].ifacemethod) } - case 242: + case 246: yyDollar = yyS[yypt-5 : yypt+1] -//line parser.go.y:1274 +//line parser.go.y:1290 { yyVAL.ifacemethod = ast.InterfaceMethodExpr{ Name: yyDollar[2].token.Str, Type: &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: yyDollar[5].typeexprlist}, } } - case 243: + case 247: yyDollar = yyS[yypt-6 : yypt+1] -//line parser.go.y:1280 +//line parser.go.y:1296 { yyVAL.ifacemethod = ast.InterfaceMethodExpr{ Name: yyDollar[2].token.Str, diff --git a/compiler/parse/parser.go.y b/compiler/parse/parser.go.y index 47c86bc8..67cfcdd3 100644 --- a/compiler/parse/parser.go.y +++ b/compiler/parse/parser.go.y @@ -1030,10 +1030,18 @@ primarytypeexpr: $$ = &ast.FunctionTypeExpr{Params: $3, Returns: $7} $$.SetPosFromToken($1.Pos) } | + '(' '(' funcparamlist ')' TArrow '(' ')' ')' { + $$ = &ast.FunctionTypeExpr{Params: $3, Returns: []ast.TypeExpr{}} + $$.SetPosFromToken($1.Pos) + } | '(' '(' ')' TArrow '(' typeexprlist2 ')' ')' { $$ = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: $6} $$.SetPosFromToken($1.Pos) } | + '(' '(' ')' TArrow '(' ')' ')' { + $$ = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: []ast.TypeExpr{}} + $$.SetPosFromToken($1.Pos) + } | TFun '(' funcparamlist ')' ':' typeexpr { $$ = &ast.FunctionTypeExpr{Params: $3, Returns: []ast.TypeExpr{$6}} $$.SetPosFromToken($1.Pos) @@ -1050,6 +1058,14 @@ primarytypeexpr: $$ = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: $6} $$.SetPosFromToken($1.Pos) } | + TFun '(' funcparamlist ')' ':' '(' ')' { + $$ = &ast.FunctionTypeExpr{Params: $3, Returns: []ast.TypeExpr{}} + $$.SetPosFromToken($1.Pos) + } | + TFun '(' ')' ':' '(' ')' { + $$ = &ast.FunctionTypeExpr{Params: []ast.FunctionParamExpr{}, Returns: []ast.TypeExpr{}} + $$.SetPosFromToken($1.Pos) + } | TFun '(' funcparamlist ')' { $$ = &ast.FunctionTypeExpr{Params: $3, Returns: nil} $$.SetPosFromToken($1.Pos) diff --git a/compiler/parse/parser_test.go b/compiler/parse/parser_test.go index 9065e988..574acfaa 100644 --- a/compiler/parse/parser_test.go +++ b/compiler/parse/parser_test.go @@ -1980,3 +1980,42 @@ func TestParseVoidReturnInRecordField(t *testing.T) { }) } } + +func TestParseVoidReturnAllForms(t *testing.T) { + tests := []struct { + name string + input string + }{ + // Parenthesized forms for optional/union wrapping + {name: "paren params void", input: `type T = {f: ((x: number) -> ())?}`}, + {name: "paren empty void", input: `type T = {f: (() -> ())?}`}, + {name: "paren params single", input: `type T = {f: ((x: number) -> string)?}`}, + {name: "paren empty single", input: `type T = {f: (() -> string)?}`}, + {name: "paren params multi", input: `type T = {f: ((x: number) -> (string, number))?}`}, + {name: "paren empty multi", input: `type T = {f: (() -> (string, number))?}`}, + {name: "paren void in union", input: `type T = ((x: number) -> ()) | nil`}, + // fun keyword with void return + {name: "fun params void", input: `type T = {f: fun(x: number): ()}`}, + {name: "fun empty void", input: `type T = {f: fun(): ()}`}, + // fun keyword regression + {name: "fun params single", input: `type T = {f: fun(x: number): string}`}, + {name: "fun empty single", input: `type T = {f: fun(): string}`}, + {name: "fun params multi", input: `type T = {f: fun(x: number): (string, number)}`}, + {name: "fun empty multi", input: `type T = {f: fun(): (string, number)}`}, + {name: "fun params no return", input: `type T = {f: fun(x: number)}`}, + {name: "fun empty no return", input: `type T = {f: fun()}`}, + // Bare arrow forms + {name: "bare params void", input: `type T = {f: (x: number) -> ()}`}, + {name: "bare empty void", input: `type T = {f: () -> ()}`}, + {name: "bare params single", input: `type T = {f: (x: number) -> string}`}, + {name: "bare empty single", input: `type T = {f: () -> string}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseString(tt.input, "test") + if err != nil { + t.Errorf("expected no parse error, got: %v", err) + } + }) + } +} diff --git a/compiler/parse/y.output b/compiler/parse/y.output index 326d05a6..5998c43f 100644 --- a/compiler/parse/y.output +++ b/compiler/parse/y.output @@ -1017,17 +1017,17 @@ state 66 state 67 - typednamelist: typedname. (233) + typednamelist: typedname. (237) - . reduce 233 (src line 1238) + . reduce 237 (src line 1254) state 68 - typedname: TIdent. (235) + typedname: TIdent. (239) typedname: TIdent.':' typeexpr ':' shift 155 - . reduce 235 (src line 1246) + . reduce 239 (src line 1262) state 69 @@ -1048,10 +1048,10 @@ state 70 state 71 stat: TInterface TIdent.interfaceextends interfacebody TEnd - interfaceextends: . (237) + interfaceextends: . (241) ':' shift 159 - . reduce 237 (src line 1254) + . reduce 241 (src line 1270) interfaceextends goto 158 @@ -2626,10 +2626,10 @@ state 157 state 158 stat: TInterface TIdent interfaceextends.interfacebody TEnd interfaceextends: interfaceextends.',' TIdent - interfacebody: . (240) + interfacebody: . (244) ',' shift 246 - . reduce 240 (src line 1265) + . reduce 244 (src line 1281) interfacebody goto 245 @@ -3709,7 +3709,9 @@ state 194 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' TFalse shift 189 TNil shift 187 @@ -3740,6 +3742,8 @@ state 195 primarytypeexpr: TFun.'(' funcparamlist ')' ':' '(' typeexprlist2 ')' primarytypeexpr: TFun.'(' ')' ':' typeexpr primarytypeexpr: TFun.'(' ')' ':' '(' typeexprlist2 ')' + primarytypeexpr: TFun.'(' funcparamlist ')' ':' '(' ')' + primarytypeexpr: TFun.'(' ')' ':' '(' ')' primarytypeexpr: TFun.'(' funcparamlist ')' primarytypeexpr: TFun.'(' ')' @@ -3851,17 +3855,17 @@ state 207 state 208 - typeparamlist: typeparam. (229) + typeparamlist: typeparam. (233) - . reduce 229 (src line 1222) + . reduce 233 (src line 1238) state 209 - typeparam: TIdent. (231) + typeparam: TIdent. (235) typeparam: TIdent.':' typeexpr ':' shift 296 - . reduce 231 (src line 1230) + . reduce 235 (src line 1246) state 210 @@ -4263,19 +4267,19 @@ state 240 state 241 - typednamelist: typednamelist ',' typedname. (234) + typednamelist: typednamelist ',' typedname. (238) - . reduce 234 (src line 1241) + . reduce 238 (src line 1257) state 242 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typedname: TIdent ':' typeexpr. (236) + typedname: TIdent ':' typeexpr. (240) '|' shift 248 '&' shift 249 - . reduce 236 (src line 1249) + . reduce 240 (src line 1265) state 243 @@ -4329,9 +4333,9 @@ state 246 state 247 - interfaceextends: ':' TIdent. (238) + interfaceextends: ':' TIdent. (242) - . reduce 238 (src line 1257) + . reduce 242 (src line 1273) state 248 @@ -4442,9 +4446,9 @@ state 253 primarytypeexpr goto 186 state 254 - annotations: annotation. (223) + annotations: annotation. (227) - . reduce 223 (src line 1198) + . reduce 227 (src line 1214) state 255 @@ -4532,15 +4536,15 @@ state 260 state 261 - primarytypeexpr: '{' '}'. (194) + primarytypeexpr: '{' '}'. (198) - . reduce 194 (src line 1086) + . reduce 198 (src line 1102) state 262 - typefieldlist: typefield. (212) + typefieldlist: typefield. (216) - . reduce 212 (src line 1161) + . reduce 216 (src line 1177) state 263 @@ -4672,8 +4676,12 @@ state 276 primarytypeexpr: '(' '('.')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' primarytypeexpr: '(' '('.funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' + primarytypeexpr: '(' '('.funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' primarytypeexpr: '(' '('.')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' + primarytypeexpr: '(' '('.')' TArrow '(' ')' ')' TFalse shift 189 TNil shift 187 @@ -4700,9 +4708,9 @@ state 276 funcparamlist goto 330 state 277 - funcparamlist: funcparam. (208) + funcparamlist: funcparam. (212) - . reduce 208 (src line 1147) + . reduce 212 (src line 1163) state 278 @@ -4743,11 +4751,11 @@ state 279 state 280 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - funcparam: typeexpr. (207) + funcparam: typeexpr. (211) '|' shift 248 '&' shift 249 - . reduce 207 (src line 1142) + . reduce 211 (src line 1158) state 281 @@ -4755,6 +4763,8 @@ state 281 primarytypeexpr: TFun '('.funcparamlist ')' ':' '(' typeexprlist2 ')' primarytypeexpr: TFun '('.')' ':' typeexpr primarytypeexpr: TFun '('.')' ':' '(' typeexprlist2 ')' + primarytypeexpr: TFun '('.funcparamlist ')' ':' '(' ')' + primarytypeexpr: TFun '('.')' ':' '(' ')' primarytypeexpr: TFun '('.funcparamlist ')' primarytypeexpr: TFun '('.')' @@ -4837,11 +4847,11 @@ state 283 typefield goto 262 state 284 - primarytypeexpr: TAsserts TIdent. (195) + primarytypeexpr: TAsserts TIdent. (199) primarytypeexpr: TAsserts TIdent.TIs typeexpr TIs shift 347 - . reduce 195 (src line 1090) + . reduce 199 (src line 1106) state 285 @@ -4991,9 +5001,9 @@ state 293 returntypeannot goto 357 state 294 - typeparams: '<' typeparamlist '>'. (228) + typeparams: '<' typeparamlist '>'. (232) - . reduce 228 (src line 1217) + . reduce 232 (src line 1233) state 295 @@ -5167,9 +5177,9 @@ state 305 state 306 - interfacebody: interfacebody interfacemethod. (241) + interfacebody: interfacebody interfacemethod. (245) - . reduce 241 (src line 1268) + . reduce 245 (src line 1284) state 307 @@ -5181,9 +5191,9 @@ state 307 state 308 - interfaceextends: interfaceextends ',' TIdent. (239) + interfaceextends: interfaceextends ',' TIdent. (243) - . reduce 239 (src line 1260) + . reduce 243 (src line 1276) 309: shift/reduce conflict (shift 251(0), red'n 155(4)) on TQuestion @@ -5213,15 +5223,15 @@ state 311 state 312 - annotations: annotations annotation. (224) + annotations: annotations annotation. (228) - . reduce 224 (src line 1201) + . reduce 228 (src line 1217) state 313 - primarytypeexpr: primarytypeexpr '[' ']'. (190) + primarytypeexpr: primarytypeexpr '[' ']'. (194) - . reduce 190 (src line 1068) + . reduce 194 (src line 1084) state 314 @@ -5235,14 +5245,14 @@ state 314 . error -315: shift/reduce conflict (shift 369(0), red'n 225(0)) on '(' +315: shift/reduce conflict (shift 369(0), red'n 229(0)) on '(' state 315 - annotation: '@' TIdent. (225) + annotation: '@' TIdent. (229) annotation: '@' TIdent.'(' ')' annotation: '@' TIdent.'(' exprlist ')' '(' shift 369 - . reduce 225 (src line 1206) + . reduce 229 (src line 1222) state 316 @@ -5265,11 +5275,11 @@ state 317 state 318 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typeexprlist: typeexpr. (200) + typeexprlist: typeexpr. (204) '|' shift 248 '&' shift 249 - . reduce 200 (src line 1112) + . reduce 204 (src line 1128) state 319 @@ -5297,7 +5307,7 @@ state 321 state 322 typefieldlist: typefieldlist ','.typefield - typefieldlist: typefieldlist ','. (214) + typefieldlist: typefieldlist ','. (218) TType shift 270 TInterface shift 339 @@ -5309,7 +5319,7 @@ state 322 TKeyof shift 342 TExtends shift 273 TIdent shift 338 - . reduce 214 (src line 1167) + . reduce 218 (src line 1183) typefieldname goto 264 typefield goto 375 @@ -5476,6 +5486,7 @@ state 330 primarytypeexpr: '(' funcparamlist.')' TArrow typeexpr primarytypeexpr: '(' '(' funcparamlist.')' TArrow typeexpr ')' primarytypeexpr: '(' '(' funcparamlist.')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '(' '(' funcparamlist.')' TArrow '(' ')' ')' funcparamlist: funcparamlist.',' funcparam funcparamlist: funcparamlist.',' T3Comma typeexpr @@ -5490,6 +5501,7 @@ state 331 primarytypeexpr: '(' ')'.TArrow '(' ')' primarytypeexpr: '(' '(' ')'.TArrow typeexpr ')' primarytypeexpr: '(' '(' ')'.TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '(' '(' ')'.TArrow '(' ')' ')' TArrow shift 386 . error @@ -5498,11 +5510,11 @@ state 331 state 332 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - funcparamlist: T3Comma typeexpr. (210) + funcparamlist: T3Comma typeexpr. (214) '|' shift 248 '&' shift 249 - . reduce 210 (src line 1153) + . reduce 214 (src line 1169) state 333 @@ -5531,6 +5543,7 @@ state 333 state 334 primarytypeexpr: TFun '(' funcparamlist.')' ':' typeexpr primarytypeexpr: TFun '(' funcparamlist.')' ':' '(' typeexprlist2 ')' + primarytypeexpr: TFun '(' funcparamlist.')' ':' '(' ')' primarytypeexpr: TFun '(' funcparamlist.')' funcparamlist: funcparamlist.',' funcparam funcparamlist: funcparamlist.',' T3Comma typeexpr @@ -5540,14 +5553,15 @@ state 334 . error -335: shift/reduce conflict (shift 389(0), red'n 187(0)) on ':' +335: shift/reduce conflict (shift 389(0), red'n 191(0)) on ':' state 335 primarytypeexpr: TFun '(' ')'.':' typeexpr primarytypeexpr: TFun '(' ')'.':' '(' typeexprlist2 ')' - primarytypeexpr: TFun '(' ')'. (187) + primarytypeexpr: TFun '(' ')'.':' '(' ')' + primarytypeexpr: TFun '(' ')'. (191) ':' shift 389 - . reduce 187 (src line 1056) + . reduce 191 (src line 1072) state 336 @@ -5561,9 +5575,9 @@ state 336 state 337 - primarytypeexpr: TInterface '{' '}'. (189) + primarytypeexpr: TInterface '{' '}'. (193) - . reduce 189 (src line 1064) + . reduce 193 (src line 1080) state 338 @@ -5777,7 +5791,9 @@ state 353 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' TFalse shift 189 TNil shift 187 @@ -5841,19 +5857,19 @@ state 357 block goto 403 state 358 - typeparamlist: typeparamlist ',' typeparam. (230) + typeparamlist: typeparamlist ',' typeparam. (234) - . reduce 230 (src line 1225) + . reduce 234 (src line 1241) state 359 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typeparam: TIdent ':' typeexpr. (232) + typeparam: TIdent ':' typeexpr. (236) '|' shift 248 '&' shift 249 - . reduce 232 (src line 1233) + . reduce 236 (src line 1249) state 360 @@ -6025,9 +6041,9 @@ state 367 primarytypeexpr goto 186 state 368 - primarytypeexpr: primarytypeexpr '[' typeexpr ']'. (199) + primarytypeexpr: primarytypeexpr '[' typeexpr ']'. (203) - . reduce 199 (src line 1106) + . reduce 203 (src line 1122) state 369 @@ -6091,15 +6107,15 @@ state 371 primarytypeexpr goto 186 state 372 - closegt: '>'. (204) + closegt: '>'. (208) - . reduce 204 (src line 1128) + . reduce 208 (src line 1144) state 373 - closegt: TShr. (205) + closegt: TShr. (209) - . reduce 205 (src line 1130) + . reduce 209 (src line 1146) state 374 @@ -6110,21 +6126,21 @@ state 374 state 375 - typefieldlist: typefieldlist ',' typefield. (213) + typefieldlist: typefieldlist ',' typefield. (217) - . reduce 213 (src line 1164) + . reduce 217 (src line 1180) state 376 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typefield: TIdent ':' typeexpr. (215) + typefield: TIdent ':' typeexpr. (219) typefield: TIdent ':' typeexpr.annotations '|' shift 248 '&' shift 249 '@' shift 255 - . reduce 215 (src line 1172) + . reduce 219 (src line 1188) annotation goto 254 annotations goto 416 @@ -6132,13 +6148,13 @@ state 376 state 377 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typefield: TIdent TQuestionColon typeexpr. (217) + typefield: TIdent TQuestionColon typeexpr. (221) typefield: TIdent TQuestionColon typeexpr.annotations '|' shift 248 '&' shift 249 '@' shift 255 - . reduce 217 (src line 1178) + . reduce 221 (src line 1194) annotation goto 254 annotations goto 417 @@ -6146,13 +6162,13 @@ state 377 state 378 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typefield: typefieldname ':' typeexpr. (219) + typefield: typefieldname ':' typeexpr. (223) typefield: typefieldname ':' typeexpr.annotations '|' shift 248 '&' shift 249 '@' shift 255 - . reduce 219 (src line 1184) + . reduce 223 (src line 1200) annotation goto 254 annotations goto 418 @@ -6160,13 +6176,13 @@ state 378 state 379 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typefield: typefieldname TQuestionColon typeexpr. (221) + typefield: typefieldname TQuestionColon typeexpr. (225) typefield: typefieldname TQuestionColon typeexpr.annotations '|' shift 248 '&' shift 249 '@' shift 255 - . reduce 221 (src line 1190) + . reduce 225 (src line 1206) annotation goto 254 annotations goto 419 @@ -6197,9 +6213,9 @@ state 380 primarytypeexpr goto 186 state 381 - funcparamlist: funcparamlist ',' funcparam. (209) + funcparamlist: funcparamlist ',' funcparam. (213) - . reduce 209 (src line 1150) + . reduce 213 (src line 1166) state 382 @@ -6237,7 +6253,9 @@ state 383 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' TFalse shift 189 TNil shift 187 @@ -6282,6 +6300,7 @@ state 385 primarytypeexpr: '(' funcparamlist ')'.TArrow typeexpr primarytypeexpr: '(' '(' funcparamlist ')'.TArrow typeexpr ')' primarytypeexpr: '(' '(' funcparamlist ')'.TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '(' '(' funcparamlist ')'.TArrow '(' ')' ')' TArrow shift 426 . error @@ -6293,6 +6312,7 @@ state 386 primarytypeexpr: '(' ')' TArrow.'(' ')' primarytypeexpr: '(' '(' ')' TArrow.typeexpr ')' primarytypeexpr: '(' '(' ')' TArrow.'(' typeexprlist2 ')' ')' + primarytypeexpr: '(' '(' ')' TArrow.'(' ')' ')' TFalse shift 189 TNil shift 187 @@ -6317,26 +6337,28 @@ state 386 state 387 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - funcparam: TIdent ':' typeexpr. (206) + funcparam: TIdent ':' typeexpr. (210) '|' shift 248 '&' shift 249 - . reduce 206 (src line 1139) + . reduce 210 (src line 1155) -388: shift/reduce conflict (shift 429(0), red'n 186(0)) on ':' +388: shift/reduce conflict (shift 429(0), red'n 190(0)) on ':' state 388 primarytypeexpr: TFun '(' funcparamlist ')'.':' typeexpr primarytypeexpr: TFun '(' funcparamlist ')'.':' '(' typeexprlist2 ')' - primarytypeexpr: TFun '(' funcparamlist ')'. (186) + primarytypeexpr: TFun '(' funcparamlist ')'.':' '(' ')' + primarytypeexpr: TFun '(' funcparamlist ')'. (190) ':' shift 429 - . reduce 186 (src line 1052) + . reduce 190 (src line 1068) state 389 primarytypeexpr: TFun '(' ')' ':'.typeexpr primarytypeexpr: TFun '(' ')' ':'.'(' typeexprlist2 ')' + primarytypeexpr: TFun '(' ')' ':'.'(' ')' TFalse shift 189 TNil shift 187 @@ -6359,15 +6381,15 @@ state 389 primarytypeexpr goto 186 state 390 - primarytypeexpr: TInterface '{' typefieldlist '}'. (188) + primarytypeexpr: TInterface '{' typefieldlist '}'. (192) - . reduce 188 (src line 1060) + . reduce 192 (src line 1076) state 391 - primarytypeexpr: TReadonly '{' typeexpr '}'. (191) + primarytypeexpr: TReadonly '{' typeexpr '}'. (195) - . reduce 191 (src line 1072) + . reduce 195 (src line 1088) state 392 @@ -6382,33 +6404,33 @@ state 392 state 393 - primarytypeexpr: TReadonly '{' typefieldlist '}'. (193) + primarytypeexpr: TReadonly '{' typefieldlist '}'. (197) - . reduce 193 (src line 1082) + . reduce 197 (src line 1098) -394: shift/reduce conflict (shift 248(4), red'n 196(0)) on '|' -394: shift/reduce conflict (shift 249(6), red'n 196(0)) on '&' +394: shift/reduce conflict (shift 248(4), red'n 200(0)) on '|' +394: shift/reduce conflict (shift 249(6), red'n 200(0)) on '&' state 394 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - primarytypeexpr: TAsserts TIdent TIs typeexpr. (196) + primarytypeexpr: TAsserts TIdent TIs typeexpr. (200) '|' shift 248 '&' shift 249 - . reduce 196 (src line 1094) + . reduce 200 (src line 1110) state 395 - primarytypeexpr: TTypeof '(' expr ')'. (197) + primarytypeexpr: TTypeof '(' expr ')'. (201) - . reduce 197 (src line 1098) + . reduce 201 (src line 1114) state 396 - primarytypeexpr: TKeyof '(' typeexpr ')'. (198) + primarytypeexpr: TKeyof '(' typeexpr ')'. (202) - . reduce 198 (src line 1102) + . reduce 202 (src line 1118) state 397 @@ -6434,17 +6456,17 @@ state 399 . reduce 152 (src line 896) -400: shift/reduce conflict (shift 434(0), red'n 207(0)) on ',' +400: shift/reduce conflict (shift 434(0), red'n 211(0)) on ',' state 400 returntypeannot: ':' '(' typeexpr.',' typeexprlist ')' typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - funcparam: typeexpr. (207) + funcparam: typeexpr. (211) '|' shift 248 '&' shift 249 ',' shift 434 - . reduce 207 (src line 1142) + . reduce 211 (src line 1158) state 401 @@ -6670,9 +6692,9 @@ state 411 state 412 - annotation: '@' TIdent '(' ')'. (226) + annotation: '@' TIdent '(' ')'. (230) - . reduce 226 (src line 1209) + . reduce 230 (src line 1225) state 413 @@ -6687,11 +6709,11 @@ state 413 state 414 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typeexprlist: typeexprlist ',' typeexpr. (201) + typeexprlist: typeexprlist ',' typeexpr. (205) '|' shift 248 '&' shift 249 - . reduce 201 (src line 1115) + . reduce 205 (src line 1131) state 415 @@ -6718,38 +6740,38 @@ state 415 primarytypeexpr goto 186 state 416 - typefield: TIdent ':' typeexpr annotations. (216) + typefield: TIdent ':' typeexpr annotations. (220) annotations: annotations.annotation '@' shift 255 - . reduce 216 (src line 1175) + . reduce 220 (src line 1191) annotation goto 312 state 417 - typefield: TIdent TQuestionColon typeexpr annotations. (218) + typefield: TIdent TQuestionColon typeexpr annotations. (222) annotations: annotations.annotation '@' shift 255 - . reduce 218 (src line 1181) + . reduce 222 (src line 1197) annotation goto 312 state 418 - typefield: typefieldname ':' typeexpr annotations. (220) + typefield: typefieldname ':' typeexpr annotations. (224) annotations: annotations.annotation '@' shift 255 - . reduce 220 (src line 1187) + . reduce 224 (src line 1203) annotation goto 312 state 419 - typefield: typefieldname TQuestionColon typeexpr annotations. (222) + typefield: typefieldname TQuestionColon typeexpr annotations. (226) annotations: annotations.annotation '@' shift 255 - . reduce 222 (src line 1193) + . reduce 226 (src line 1209) annotation goto 312 @@ -6765,7 +6787,9 @@ state 420 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' TFalse shift 189 TNil shift 187 @@ -6807,11 +6831,11 @@ state 421 state 422 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - funcparamlist: funcparamlist ',' T3Comma typeexpr. (211) + funcparamlist: funcparamlist ',' T3Comma typeexpr. (215) '|' shift 248 '&' shift 249 - . reduce 211 (src line 1156) + . reduce 215 (src line 1172) state 423 @@ -6833,17 +6857,17 @@ state 424 . error -425: shift/reduce conflict (shift 451(0), red'n 207(0)) on ',' +425: shift/reduce conflict (shift 451(0), red'n 211(0)) on ',' state 425 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr typeexprlist2: typeexpr.',' typeexpr - funcparam: typeexpr. (207) + funcparam: typeexpr. (211) '|' shift 248 '&' shift 249 ',' shift 451 - . reduce 207 (src line 1142) + . reduce 211 (src line 1158) state 426 @@ -6852,6 +6876,7 @@ state 426 primarytypeexpr: '(' funcparamlist ')' TArrow.typeexpr primarytypeexpr: '(' '(' funcparamlist ')' TArrow.typeexpr ')' primarytypeexpr: '(' '(' funcparamlist ')' TArrow.'(' typeexprlist2 ')' ')' + primarytypeexpr: '(' '(' funcparamlist ')' TArrow.'(' ')' ')' TFalse shift 189 TNil shift 187 @@ -6885,8 +6910,11 @@ state 427 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' primarytypeexpr: '(' '(' ')' TArrow '('.typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' + primarytypeexpr: '(' '(' ')' TArrow '('.')' ')' TFalse shift 189 TNil shift 187 @@ -6903,17 +6931,17 @@ state 427 TString shift 190 '{' shift 193 '(' shift 276 - ')' shift 423 + ')' shift 454 . error typeexpr goto 425 simpletypeexpr goto 185 primarytypeexpr goto 186 - typeexprlist2 goto 454 + typeexprlist2 goto 455 funcparam goto 277 funcparamlist goto 274 -428: shift/reduce conflict (shift 455(0), red'n 176(0)) on ')' +428: shift/reduce conflict (shift 456(0), red'n 176(0)) on ')' 428: shift/reduce conflict (shift 248(4), red'n 176(0)) on '|' 428: shift/reduce conflict (shift 249(6), red'n 176(0)) on '&' state 428 @@ -6922,7 +6950,7 @@ state 428 primarytypeexpr: '(' ')' TArrow typeexpr. (176) primarytypeexpr: '(' '(' ')' TArrow typeexpr.')' - ')' shift 455 + ')' shift 456 '|' shift 248 '&' shift 249 . reduce 176 (src line 1012) @@ -6931,6 +6959,7 @@ state 428 state 429 primarytypeexpr: TFun '(' funcparamlist ')' ':'.typeexpr primarytypeexpr: TFun '(' funcparamlist ')' ':'.'(' typeexprlist2 ')' + primarytypeexpr: TFun '(' funcparamlist ')' ':'.'(' ')' TFalse shift 189 TNil shift 187 @@ -6945,23 +6974,23 @@ state 429 TNumber shift 191 TString shift 190 '{' shift 193 - '(' shift 457 + '(' shift 458 . error - typeexpr goto 456 + typeexpr goto 457 simpletypeexpr goto 185 primarytypeexpr goto 186 -430: shift/reduce conflict (shift 248(4), red'n 184(0)) on '|' -430: shift/reduce conflict (shift 249(6), red'n 184(0)) on '&' +430: shift/reduce conflict (shift 248(4), red'n 186(0)) on '|' +430: shift/reduce conflict (shift 249(6), red'n 186(0)) on '&' state 430 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - primarytypeexpr: TFun '(' ')' ':' typeexpr. (184) + primarytypeexpr: TFun '(' ')' ':' typeexpr. (186) '|' shift 248 '&' shift 249 - . reduce 184 (src line 1044) + . reduce 186 (src line 1052) state 431 @@ -6974,8 +7003,11 @@ state 431 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' primarytypeexpr: TFun '(' ')' ':' '('.typeexprlist2 ')' + primarytypeexpr: TFun '(' ')' ':' '('.')' TFalse shift 189 TNil shift 187 @@ -6992,20 +7024,20 @@ state 431 TString shift 190 '{' shift 193 '(' shift 276 - ')' shift 275 + ')' shift 459 . error typeexpr goto 425 simpletypeexpr goto 185 primarytypeexpr goto 186 - typeexprlist2 goto 458 + typeexprlist2 goto 460 funcparam goto 277 funcparamlist goto 274 state 432 primarytypeexpr: TReadonly '{' '[' typeexpr ']'.':' typeexpr '}' - ':' shift 459 + ':' shift 461 . error @@ -7037,7 +7069,7 @@ state 434 typeexpr goto 318 simpletypeexpr goto 185 primarytypeexpr goto 186 - typeexprlist goto 460 + typeexprlist goto 462 state 435 parlist: typednamelist ',' T3Comma ':' typeexpr. (121) @@ -7052,7 +7084,7 @@ state 435 state 436 funcbody: typeparams '(' parlist ')' returntypeannot block.TEnd - TEnd shift 461 + TEnd shift 463 . error @@ -7076,12 +7108,12 @@ state 439 chunk goto 56 chunk1 goto 2 - block goto 462 + block goto 464 state 440 stat: TFor TIdent '=' expr ',' expr TDo block.TEnd - TEnd shift 463 + TEnd shift 465 . error @@ -7113,7 +7145,7 @@ state 441 expr: expr.TBang TAnd shift 76 - TDo shift 464 + TDo shift 466 TOr shift 75 TAs shift 96 TEqeq shift 81 @@ -7147,13 +7179,13 @@ state 442 ':' shift 289 . reduce 150 (src line 890) - returntypeannot goto 465 + returntypeannot goto 467 state 443 typednamelist: typednamelist.',' typedname interfacemethod: TFunction TIdent '(' typednamelist.')' returntypeannot - ')' shift 466 + ')' shift 468 ',' shift 154 . error @@ -7177,14 +7209,14 @@ state 444 '(' shift 194 . error - typeexpr goto 467 + typeexpr goto 469 simpletypeexpr goto 185 primarytypeexpr goto 186 state 445 - annotation: '@' TIdent '(' exprlist ')'. (227) + annotation: '@' TIdent '(' exprlist ')'. (231) - . reduce 227 (src line 1212) + . reduce 231 (src line 1228) state 446 @@ -7192,7 +7224,7 @@ state 446 typeexpr: typeexpr.'&' simpletypeexpr primarytypeexpr: '{' '[' typeexpr ']' ':' typeexpr.'}' - '}' shift 468 + '}' shift 470 '|' shift 248 '&' shift 249 . error @@ -7202,7 +7234,7 @@ state 447 primarytypeexpr: '(' funcparamlist ')' TArrow '(' typeexprlist2.')' typeexprlist2: typeexprlist2.',' typeexpr - ')' shift 469 + ')' shift 471 ',' shift 450 . error @@ -7242,7 +7274,7 @@ state 450 '(' shift 194 . error - typeexpr goto 470 + typeexpr goto 472 simpletypeexpr goto 185 primarytypeexpr goto 186 @@ -7265,7 +7297,7 @@ state 451 '(' shift 194 . error - typeexpr goto 471 + typeexpr goto 473 simpletypeexpr goto 185 primarytypeexpr goto 186 @@ -7282,7 +7314,10 @@ state 452 primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' primarytypeexpr: '(' '(' funcparamlist ')' TArrow '('.typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' + primarytypeexpr: '(' '(' funcparamlist ')' TArrow '('.')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' TFalse shift 189 TNil shift 187 @@ -7299,17 +7334,17 @@ state 452 TString shift 190 '{' shift 193 '(' shift 276 - ')' shift 448 + ')' shift 475 . error typeexpr goto 425 simpletypeexpr goto 185 primarytypeexpr goto 186 - typeexprlist2 goto 472 + typeexprlist2 goto 474 funcparam goto 277 funcparamlist goto 274 -453: shift/reduce conflict (shift 473(0), red'n 174(0)) on ')' +453: shift/reduce conflict (shift 476(0), red'n 174(0)) on ')' 453: shift/reduce conflict (shift 248(4), red'n 174(0)) on '|' 453: shift/reduce conflict (shift 249(6), red'n 174(0)) on '&' state 453 @@ -7318,41 +7353,54 @@ state 453 primarytypeexpr: '(' funcparamlist ')' TArrow typeexpr. (174) primarytypeexpr: '(' '(' funcparamlist ')' TArrow typeexpr.')' - ')' shift 473 + ')' shift 476 '|' shift 248 '&' shift 249 . reduce 174 (src line 1004) +454: shift/reduce conflict (shift 477(0), red'n 177(0)) on ')' state 454 + primarytypeexpr: '(' ')'.TArrow '(' typeexprlist2 ')' + primarytypeexpr: '(' ')'.TArrow typeexpr + primarytypeexpr: '(' ')'.TArrow '(' ')' + primarytypeexpr: '(' ')' TArrow '(' ')'. (177) + primarytypeexpr: '(' '(' ')' TArrow '(' ')'.')' + + ')' shift 477 + TArrow shift 329 + . reduce 177 (src line 1016) + + +state 455 primarytypeexpr: '(' ')' TArrow '(' typeexprlist2.')' primarytypeexpr: '(' '(' ')' TArrow '(' typeexprlist2.')' ')' typeexprlist2: typeexprlist2.',' typeexpr - ')' shift 474 + ')' shift 478 ',' shift 450 . error -state 455 +state 456 primarytypeexpr: '(' '(' ')' TArrow typeexpr ')'. (179) . reduce 179 (src line 1024) -456: shift/reduce conflict (shift 248(4), red'n 182(0)) on '|' -456: shift/reduce conflict (shift 249(6), red'n 182(0)) on '&' -state 456 +457: shift/reduce conflict (shift 248(4), red'n 184(0)) on '|' +457: shift/reduce conflict (shift 249(6), red'n 184(0)) on '&' +state 457 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - primarytypeexpr: TFun '(' funcparamlist ')' ':' typeexpr. (182) + primarytypeexpr: TFun '(' funcparamlist ')' ':' typeexpr. (184) '|' shift 248 '&' shift 249 - . reduce 182 (src line 1036) + . reduce 184 (src line 1044) -state 457 +state 458 primarytypeexpr: '('.funcparamlist ')' TArrow '(' typeexprlist2 ')' primarytypeexpr: '('.funcparamlist ')' TArrow '(' ')' primarytypeexpr: '('.funcparamlist ')' TArrow typeexpr @@ -7362,8 +7410,11 @@ state 457 primarytypeexpr: '('.'(' funcparamlist ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' ')' TArrow typeexpr ')' primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' funcparamlist ')' TArrow '(' ')' ')' primarytypeexpr: '('.'(' ')' TArrow '(' typeexprlist2 ')' ')' + primarytypeexpr: '('.'(' ')' TArrow '(' ')' ')' primarytypeexpr: TFun '(' funcparamlist ')' ':' '('.typeexprlist2 ')' + primarytypeexpr: TFun '(' funcparamlist ')' ':' '('.')' TFalse shift 189 TNil shift 187 @@ -7380,26 +7431,36 @@ state 457 TString shift 190 '{' shift 193 '(' shift 276 - ')' shift 275 + ')' shift 479 . error typeexpr goto 425 simpletypeexpr goto 185 primarytypeexpr goto 186 - typeexprlist2 goto 475 + typeexprlist2 goto 480 funcparam goto 277 funcparamlist goto 274 -state 458 +state 459 + primarytypeexpr: '(' ')'.TArrow '(' typeexprlist2 ')' + primarytypeexpr: '(' ')'.TArrow typeexpr + primarytypeexpr: '(' ')'.TArrow '(' ')' + primarytypeexpr: TFun '(' ')' ':' '(' ')'. (189) + + TArrow shift 329 + . reduce 189 (src line 1064) + + +state 460 primarytypeexpr: TFun '(' ')' ':' '(' typeexprlist2.')' typeexprlist2: typeexprlist2.',' typeexpr - ')' shift 476 + ')' shift 481 ',' shift 450 . error -state 459 +state 461 primarytypeexpr: TReadonly '{' '[' typeexpr ']' ':'.typeexpr '}' TFalse shift 189 @@ -7418,38 +7479,38 @@ state 459 '(' shift 194 . error - typeexpr goto 477 + typeexpr goto 482 simpletypeexpr goto 185 primarytypeexpr goto 186 -state 460 +state 462 returntypeannot: ':' '(' typeexpr ',' typeexprlist.')' typeexprlist: typeexprlist.',' typeexpr - ')' shift 478 + ')' shift 483 ',' shift 371 . error -state 461 +state 463 funcbody: typeparams '(' parlist ')' returntypeannot block TEnd. (115) . reduce 115 (src line 753) -state 462 +state 464 elseifs: elseifs TElseIf expr TThen block. (28) . reduce 28 (src line 318) -state 463 +state 465 stat: TFor TIdent '=' expr ',' expr TDo block TEnd. (15) . reduce 15 (src line 250) -state 464 +state 466 stat: TFor TIdent '=' expr ',' expr ',' expr TDo.block TEnd chunk1: . (4) @@ -7457,26 +7518,26 @@ state 464 chunk goto 56 chunk1 goto 2 - block goto 479 + block goto 484 -state 465 - interfacemethod: TFunction TIdent '(' ')' returntypeannot. (242) +state 467 + interfacemethod: TFunction TIdent '(' ')' returntypeannot. (246) - . reduce 242 (src line 1273) + . reduce 246 (src line 1289) -state 466 +state 468 interfacemethod: TFunction TIdent '(' typednamelist ')'.returntypeannot returntypeannot: . (150) ':' shift 289 . reduce 150 (src line 890) - returntypeannot goto 480 + returntypeannot goto 485 -467: shift/reduce conflict (shift 248(4), red'n 157(0)) on '|' -467: shift/reduce conflict (shift 249(6), red'n 157(0)) on '&' -state 467 +469: shift/reduce conflict (shift 248(4), red'n 157(0)) on '|' +469: shift/reduce conflict (shift 249(6), red'n 157(0)) on '&' +state 469 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr typeexpr: simpletypeexpr TExtends simpletypeexpr TQuestion typeexpr ':' typeexpr. (157) @@ -7486,156 +7547,191 @@ state 467 . reduce 157 (src line 923) -state 468 +state 470 primarytypeexpr: '{' '[' typeexpr ']' ':' typeexpr '}'. (170) . reduce 170 (src line 988) -state 469 +state 471 primarytypeexpr: '(' funcparamlist ')' TArrow '(' typeexprlist2 ')'. (172) . reduce 172 (src line 996) -state 470 +state 472 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typeexprlist2: typeexprlist2 ',' typeexpr. (203) + typeexprlist2: typeexprlist2 ',' typeexpr. (207) '|' shift 248 '&' shift 249 - . reduce 203 (src line 1123) + . reduce 207 (src line 1139) -state 471 +state 473 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr - typeexprlist2: typeexpr ',' typeexpr. (202) + typeexprlist2: typeexpr ',' typeexpr. (206) '|' shift 248 '&' shift 249 - . reduce 202 (src line 1120) + . reduce 206 (src line 1136) -state 472 +state 474 primarytypeexpr: '(' funcparamlist ')' TArrow '(' typeexprlist2.')' primarytypeexpr: '(' '(' funcparamlist ')' TArrow '(' typeexprlist2.')' ')' typeexprlist2: typeexprlist2.',' typeexpr - ')' shift 481 + ')' shift 486 ',' shift 450 . error -state 473 +475: shift/reduce conflict (shift 487(0), red'n 173(0)) on ')' +state 475 + primarytypeexpr: '(' funcparamlist ')' TArrow '(' ')'. (173) + primarytypeexpr: '(' ')'.TArrow '(' typeexprlist2 ')' + primarytypeexpr: '(' ')'.TArrow typeexpr + primarytypeexpr: '(' ')'.TArrow '(' ')' + primarytypeexpr: '(' '(' funcparamlist ')' TArrow '(' ')'.')' + + ')' shift 487 + TArrow shift 329 + . reduce 173 (src line 1000) + + +state 476 primarytypeexpr: '(' '(' funcparamlist ')' TArrow typeexpr ')'. (178) . reduce 178 (src line 1020) -474: shift/reduce conflict (shift 482(0), red'n 175(0)) on ')' -state 474 +state 477 + primarytypeexpr: '(' '(' ')' TArrow '(' ')' ')'. (183) + + . reduce 183 (src line 1040) + + +478: shift/reduce conflict (shift 488(0), red'n 175(0)) on ')' +state 478 primarytypeexpr: '(' ')' TArrow '(' typeexprlist2 ')'. (175) primarytypeexpr: '(' '(' ')' TArrow '(' typeexprlist2 ')'.')' - ')' shift 482 + ')' shift 488 . reduce 175 (src line 1008) -state 475 +state 479 + primarytypeexpr: '(' ')'.TArrow '(' typeexprlist2 ')' + primarytypeexpr: '(' ')'.TArrow typeexpr + primarytypeexpr: '(' ')'.TArrow '(' ')' + primarytypeexpr: TFun '(' funcparamlist ')' ':' '(' ')'. (188) + + TArrow shift 329 + . reduce 188 (src line 1060) + + +state 480 primarytypeexpr: TFun '(' funcparamlist ')' ':' '(' typeexprlist2.')' typeexprlist2: typeexprlist2.',' typeexpr - ')' shift 483 + ')' shift 489 ',' shift 450 . error -state 476 - primarytypeexpr: TFun '(' ')' ':' '(' typeexprlist2 ')'. (185) +state 481 + primarytypeexpr: TFun '(' ')' ':' '(' typeexprlist2 ')'. (187) - . reduce 185 (src line 1048) + . reduce 187 (src line 1056) -state 477 +state 482 typeexpr: typeexpr.'|' simpletypeexpr typeexpr: typeexpr.'&' simpletypeexpr primarytypeexpr: TReadonly '{' '[' typeexpr ']' ':' typeexpr.'}' - '}' shift 484 + '}' shift 490 '|' shift 248 '&' shift 249 . error -state 478 +state 483 returntypeannot: ':' '(' typeexpr ',' typeexprlist ')'. (153) . reduce 153 (src line 899) -state 479 +state 484 stat: TFor TIdent '=' expr ',' expr ',' expr TDo block.TEnd - TEnd shift 485 + TEnd shift 491 . error -state 480 - interfacemethod: TFunction TIdent '(' typednamelist ')' returntypeannot. (243) +state 485 + interfacemethod: TFunction TIdent '(' typednamelist ')' returntypeannot. (247) - . reduce 243 (src line 1279) + . reduce 247 (src line 1295) -481: shift/reduce conflict (shift 486(0), red'n 172(0)) on ')' -state 481 +486: shift/reduce conflict (shift 492(0), red'n 172(0)) on ')' +state 486 primarytypeexpr: '(' funcparamlist ')' TArrow '(' typeexprlist2 ')'. (172) primarytypeexpr: '(' '(' funcparamlist ')' TArrow '(' typeexprlist2 ')'.')' - ')' shift 486 + ')' shift 492 . reduce 172 (src line 996) -state 482 - primarytypeexpr: '(' '(' ')' TArrow '(' typeexprlist2 ')' ')'. (181) +state 487 + primarytypeexpr: '(' '(' funcparamlist ')' TArrow '(' ')' ')'. (181) . reduce 181 (src line 1032) -state 483 - primarytypeexpr: TFun '(' funcparamlist ')' ':' '(' typeexprlist2 ')'. (183) +state 488 + primarytypeexpr: '(' '(' ')' TArrow '(' typeexprlist2 ')' ')'. (182) - . reduce 183 (src line 1040) + . reduce 182 (src line 1036) -state 484 - primarytypeexpr: TReadonly '{' '[' typeexpr ']' ':' typeexpr '}'. (192) +state 489 + primarytypeexpr: TFun '(' funcparamlist ')' ':' '(' typeexprlist2 ')'. (185) - . reduce 192 (src line 1077) + . reduce 185 (src line 1048) -state 485 +state 490 + primarytypeexpr: TReadonly '{' '[' typeexpr ']' ':' typeexpr '}'. (196) + + . reduce 196 (src line 1093) + + +state 491 stat: TFor TIdent '=' expr ',' expr ',' expr TDo block TEnd. (16) . reduce 16 (src line 255) -state 486 +state 492 primarytypeexpr: '(' '(' funcparamlist ')' TArrow '(' typeexprlist2 ')' ')'. (180) . reduce 180 (src line 1028) 78 terminals, 49 nonterminals -244 grammar rules, 487/16000 states -37 shift/reduce, 0 reduce/reduce conflicts reported +248 grammar rules, 493/16000 states +39 shift/reduce, 0 reduce/reduce conflicts reported 98 working sets used memory: parser 806/240000 -355 extra closures -2472 shift entries, 6 exceptions +361 extra closures +2478 shift entries, 6 exceptions 222 goto entries 493 entries saved by goto default -Optimizer space used: output 1808/240000 -1808 table entries, 550 zero -maximum spread: 78, maximum offset: 466 +Optimizer space used: output 1912/240000 +1912 table entries, 586 zero +maximum spread: 78, maximum offset: 468 diff --git a/compiler/stdlib/globals.go b/compiler/stdlib/globals.go index 5504015d..533b676a 100644 --- a/compiler/stdlib/globals.go +++ b/compiler/stdlib/globals.go @@ -39,7 +39,7 @@ var ( OptParam("message", typ.String). Returns(typ.Any). Effects(effect.Throws()). - WithRefinement(&constraint.FunctionEffect{ + WithRefinement(&constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints( constraint.Truthy{Path: constraint.ParamPath(0)}, ), diff --git a/compiler/stdlib/string.go b/compiler/stdlib/string.go index 7acbbe4e..052ffc30 100644 --- a/compiler/stdlib/string.go +++ b/compiler/stdlib/string.go @@ -6,6 +6,13 @@ import ( "github.com/wippyai/go-lua/types/typ" ) +var stringUnpackSpec = contract.NewSpec().WithEffects( + effect.Return{ + ReturnIndex: 0, + Transform: effect.StringUnpackValue{Format: effect.ParamRef{Index: 0}}, + }, +) + var stringMethods = typ.NewRecord(). Field("byte", typ.Func(). Param("s", typ.String). @@ -79,7 +86,9 @@ var stringMethods = typ.NewRecord(). Param("fmt", typ.String). Param("s", typ.String). OptParam("pos", typ.Integer). - Returns(typ.Any).Build()). + Returns(typ.Any). + Spec(stringUnpackSpec). + Build()). Field("upper", typ.Func(). Param("s", typ.String). Returns(typ.String).Build()). diff --git a/fixture_harness_test.go b/fixture_harness_test.go index 3b92a2a7..f4f45f5a 100644 --- a/fixture_harness_test.go +++ b/fixture_harness_test.go @@ -14,6 +14,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/tests/testutil" "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/io" + "github.com/wippyai/go-lua/types/typ" ) // Suite describes a fixture suite loaded from manifest.json. @@ -393,11 +394,37 @@ func resolvePackageManifest(name string) *io.Manifest { return testutil.ChannelManifest() case "funcs": return testutil.FuncsManifest() + case "time": + return fixtureTimeManifest() default: return nil } } +func fixtureTimeManifest() *io.Manifest { + m := io.NewManifest("time") + + durationType := typ.NewInterface("time.Duration", []typ.Method{ + {Name: "seconds", Type: typ.Func().Param("self", typ.Self).Returns(typ.Number).Build()}, + }) + + timeType := typ.NewInterface("time.Time", []typ.Method{ + {Name: "sub", Type: typ.Func().Param("self", typ.Self).Param("t", typ.Self).Returns(durationType).Build()}, + {Name: "add", Type: typ.Func().Param("self", typ.Self).Param("d", durationType).Returns(typ.Self).Build()}, + {Name: "unix", Type: typ.Func().Param("self", typ.Self).Returns(typ.Integer).Build()}, + }) + + m.DefineType("Time", timeType) + m.DefineType("Duration", durationType) + + moduleType := typ.NewInterface("time", []typ.Method{ + {Name: "now", Type: typ.Func().Returns(timeType).Build()}, + }) + m.SetExport(moduleType) + + return m +} + // installRequire sets up a require() global that loads modules from the given source map. // Modules are compiled, executed, cached, and returned — matching standard Lua require semantics. func installRequire(L *LState, sources map[string]string) { diff --git a/testdata/fixtures/flow/active-session-typed-map-time-sub/main.lua b/testdata/fixtures/flow/active-session-typed-map-time-sub/main.lua new file mode 100644 index 00000000..64df1276 --- /dev/null +++ b/testdata/fixtures/flow/active-session-typed-map-time-sub/main.lua @@ -0,0 +1,25 @@ +local time = require("time") + +type ActiveSession = { + created_at: time.Time, + last_activity: time.Time?, +} + +local state = { + active_sessions = {} :: {[string]: ActiveSession}, +} + +local now = time.now() + +state.active_sessions["s1"] = { + created_at = now, + last_activity = now, +} + +for _, session_info in pairs(state.active_sessions) do + local last_activity = session_info.last_activity or session_info.created_at + local elapsed = now:sub(last_activity) + return elapsed:seconds() +end + +return 0 diff --git a/testdata/fixtures/flow/active-session-typed-map-time-sub/manifest.json b/testdata/fixtures/flow/active-session-typed-map-time-sub/manifest.json new file mode 100644 index 00000000..636568e3 --- /dev/null +++ b/testdata/fixtures/flow/active-session-typed-map-time-sub/manifest.json @@ -0,0 +1,5 @@ +{ + "description": "Annotating the active_sessions map preserves ActiveSession values through pairs iteration", + "packages": ["time"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/flow/active-session-untyped-map-time-sub-soundness/main.lua b/testdata/fixtures/flow/active-session-untyped-map-time-sub-soundness/main.lua new file mode 100644 index 00000000..2c1c0660 --- /dev/null +++ b/testdata/fixtures/flow/active-session-untyped-map-time-sub-soundness/main.lua @@ -0,0 +1,25 @@ +local time = require("time") + +type ActiveSession = { + created_at: time.Time, + last_activity: time.Time?, +} + +local state = { + active_sessions = {}, +} + +local now = time.now() + +state.active_sessions["s1"] = { + created_at = now, + last_activity = now, +} + +for _, session_info in pairs(state.active_sessions) do + local last_activity = session_info.last_activity or session_info.created_at + local elapsed = now:sub(last_activity) + return elapsed:seconds() +end + +return 0 diff --git a/testdata/fixtures/flow/active-session-untyped-map-time-sub-soundness/manifest.json b/testdata/fixtures/flow/active-session-untyped-map-time-sub-soundness/manifest.json new file mode 100644 index 00000000..7f8847e0 --- /dev/null +++ b/testdata/fixtures/flow/active-session-untyped-map-time-sub-soundness/manifest.json @@ -0,0 +1,5 @@ +{ + "description": "Iterating an untyped active_sessions map yields any-typed values, so time.Time:sub must reject them", + "packages": ["time"], + "check": {"errors": 1} +} diff --git a/testdata/fixtures/modules/active-session-any-time-sub-soundness/main.lua b/testdata/fixtures/modules/active-session-any-time-sub-soundness/main.lua new file mode 100644 index 00000000..2f4f52d3 --- /dev/null +++ b/testdata/fixtures/modules/active-session-any-time-sub-soundness/main.lua @@ -0,0 +1,9 @@ +local time = require("time") +local session_state = require("session_state") + +local now = time.now() +local session_info = session_state.new() +local last_activity = session_info.last_activity or session_info.created_at +local elapsed = now:sub(last_activity) + +return elapsed:seconds() diff --git a/testdata/fixtures/modules/active-session-any-time-sub-soundness/manifest.json b/testdata/fixtures/modules/active-session-any-time-sub-soundness/manifest.json new file mode 100644 index 00000000..dbdb27f3 --- /dev/null +++ b/testdata/fixtures/modules/active-session-any-time-sub-soundness/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Using any-typed session timestamps in time.Time:sub is a sound checker error", + "packages": ["time"], + "check": {"errors": 1}, + "files": [ + "session_state.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/active-session-any-time-sub-soundness/session_state.lua b/testdata/fixtures/modules/active-session-any-time-sub-soundness/session_state.lua new file mode 100644 index 00000000..f251509b --- /dev/null +++ b/testdata/fixtures/modules/active-session-any-time-sub-soundness/session_state.lua @@ -0,0 +1,18 @@ +local time = require("time") + +type ActiveSession = { + created_at: any, + last_activity: any, +} + +local M = {} + +function M.new(): ActiveSession + local now = time.now() + return { + created_at = now, + last_activity = now, + } +end + +return M diff --git a/testdata/fixtures/modules/active-session-typed-time-sub/main.lua b/testdata/fixtures/modules/active-session-typed-time-sub/main.lua new file mode 100644 index 00000000..2f4f52d3 --- /dev/null +++ b/testdata/fixtures/modules/active-session-typed-time-sub/main.lua @@ -0,0 +1,9 @@ +local time = require("time") +local session_state = require("session_state") + +local now = time.now() +local session_info = session_state.new() +local last_activity = session_info.last_activity or session_info.created_at +local elapsed = now:sub(last_activity) + +return elapsed:seconds() diff --git a/testdata/fixtures/modules/active-session-typed-time-sub/manifest.json b/testdata/fixtures/modules/active-session-typed-time-sub/manifest.json new file mode 100644 index 00000000..716426b7 --- /dev/null +++ b/testdata/fixtures/modules/active-session-typed-time-sub/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Annotating session timestamps as time.Time preserves sound time arithmetic", + "packages": ["time"], + "check": {"errors": 0}, + "files": [ + "session_state.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/active-session-typed-time-sub/session_state.lua b/testdata/fixtures/modules/active-session-typed-time-sub/session_state.lua new file mode 100644 index 00000000..5fc34231 --- /dev/null +++ b/testdata/fixtures/modules/active-session-typed-time-sub/session_state.lua @@ -0,0 +1,18 @@ +local time = require("time") + +type ActiveSession = { + created_at: time.Time, + last_activity: time.Time?, +} + +local M = {} + +function M.new(): ActiveSession + local now = time.now() + return { + created_at = now, + last_activity = now, + } +end + +return M diff --git a/testdata/fixtures/modules/google-client-metadata-regression/client.lua b/testdata/fixtures/modules/google-client-metadata-regression/client.lua new file mode 100644 index 00000000..72b7aff7 --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/client.lua @@ -0,0 +1,102 @@ +local json = require("json") +local http_client = require("http_client") +local output = require("output") + +type StreamInput = { + stream: any, + metadata: table?, +} + +local client = { + _http_client = http_client +} + +local function extract_response_metadata(response_body: any) + if not response_body then + return {} + end + + local metadata = {} + metadata.model_version = response_body.modelVersion + metadata.response_id = response_body.responseId + metadata.create_time = response_body.createTime + + return metadata +end + +local function parse_error_response(http_response) + local error_info = { + status_code = http_response.status_code, + message = "Google API error" + } + + if http_response.body then + local parsed, decode_err = json.decode(http_response.body) + if not decode_err and parsed then + error_info.metadata = extract_response_metadata(parsed) + end + end + + return error_info +end + +function client.process_stream(stream_response: StreamInput, callbacks) + return nil, "stream not used", { + content = "", + tool_calls = {}, + metadata = stream_response.metadata or {}, + } +end + +function client.request(method, url, http_options) + http_options = http_options or {} + + local response, err + if method == "GET" then + response, err = client._http_client.get(url, http_options) + elseif method == "PUT" then + response, err = client._http_client.put(url, http_options) + elseif method == "PATCH" then + response, err = client._http_client.patch(url, http_options) + else + response, err = client._http_client.post(url, http_options) + end + + if not response then + return nil, { + status_code = 0, + message = "Connection failed: " .. tostring(err) + } + end + + if response.status_code < 200 or response.status_code >= 300 then + local parsed_error = parse_error_response(response) + return nil, parsed_error + end + + if http_options.stream and response.stream then + return { + stream = response.stream, + status_code = response.status_code, + headers = response.headers, + metadata = extract_response_metadata(response) + } + end + + local parsed, parse_err = json.decode(response.body) + if parse_err then + local parse_error = { + status_code = response.status_code, + message = "Failed to parse Google response: " .. parse_err, + metadata = {} + } + return nil, parse_error + end + + parsed.metadata = extract_response_metadata(parsed) + parsed.status_code = response.status_code + + return parsed +end + +return client diff --git a/testdata/fixtures/modules/google-client-metadata-regression/http_client.lua b/testdata/fixtures/modules/google-client-metadata-regression/http_client.lua new file mode 100644 index 00000000..9169b46c --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/http_client.lua @@ -0,0 +1,30 @@ +local http_client = {} + +type StreamReader = { + read: (self: any) -> (string?, string?), +} + +type Response = { + status_code: number, + body: string?, + stream: StreamReader?, + headers: {[string]: string}?, +} + +function http_client.get(url: string, options: any?): (Response?, string?) + return nil, "not implemented" +end + +function http_client.post(url: string, options: any?): (Response?, string?) + return nil, "not implemented" +end + +function http_client.put(url: string, options: any?): (Response?, string?) + return nil, "not implemented" +end + +function http_client.patch(url: string, options: any?): (Response?, string?) + return nil, "not implemented" +end + +return http_client diff --git a/testdata/fixtures/modules/google-client-metadata-regression/json.lua b/testdata/fixtures/modules/google-client-metadata-regression/json.lua new file mode 100644 index 00000000..9b3b9a57 --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/json.lua @@ -0,0 +1,16 @@ +local json = {} + +function json.encode(value: any): string + return "encoded" +end + +function json.decode(source: string): (any, string?) + return { + data = "test", + modelVersion = "gemini-2.5-pro-001", + responseId = "resp-123", + createTime = "2024-01-15T10:30:00Z", + }, nil +end + +return json diff --git a/testdata/fixtures/modules/google-client-metadata-regression/main.lua b/testdata/fixtures/modules/google-client-metadata-regression/main.lua new file mode 100644 index 00000000..46a455df --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/main.lua @@ -0,0 +1,28 @@ +local client = require("client") +local json = require("json") +local tests = require("tests") + +client._http_client = { + get = function(url, options) + return { + status_code = 200, + body = json.encode({ + data = "test", + modelVersion = "gemini-2.5-pro-001", + responseId = "resp-123", + createTime = "2024-01-15T10:30:00Z", + }) + } + end +} + +local response, err = client.request("GET", "https://test.googleapis.com/v1/test", { + headers = {} +}) + +tests.is_nil(err) +tests.eq(response.metadata.model_version, "gemini-2.5-pro-001") +tests.eq(response.metadata.response_id, "resp-123") +tests.eq(response.metadata.create_time, "2024-01-15T10:30:00Z") + +return response.data diff --git a/testdata/fixtures/modules/google-client-metadata-regression/manifest.json b/testdata/fixtures/modules/google-client-metadata-regression/manifest.json new file mode 100644 index 00000000..9b4ebac8 --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/manifest.json @@ -0,0 +1,12 @@ +{ + "description": "Reduced google client path: decoding optional http response body without a guard is a sound checker error", + "check": {"errors": 1}, + "files": [ + "json.lua", + "http_client.lua", + "output.lua", + "tests.lua", + "client.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/google-client-metadata-regression/output.lua b/testdata/fixtures/modules/google-client-metadata-regression/output.lua new file mode 100644 index 00000000..8073b262 --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/output.lua @@ -0,0 +1,5 @@ +local output = { + TRUNCATION_MSG = "[truncated]" +} + +return output diff --git a/testdata/fixtures/modules/google-client-metadata-regression/tests.lua b/testdata/fixtures/modules/google-client-metadata-regression/tests.lua new file mode 100644 index 00000000..517d2f63 --- /dev/null +++ b/testdata/fixtures/modules/google-client-metadata-regression/tests.lua @@ -0,0 +1,15 @@ +local tests = {} + +function tests.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "expected nil", 2) + end +end + +function tests.eq(actual: any, expected: any, msg: string?) + if actual ~= expected then + error(msg or "not equal", 2) + end +end + +return tests diff --git a/testdata/fixtures/modules/imported-eq-typeof-table-len/main.lua b/testdata/fixtures/modules/imported-eq-typeof-table-len/main.lua new file mode 100644 index 00000000..9bea4adc --- /dev/null +++ b/testdata/fixtures/modules/imported-eq-typeof-table-len/main.lua @@ -0,0 +1,10 @@ +local test = require("test") + +local values: {string}? = { "alpha", "beta" } + +test.eq(type(values), "table", "values should be a table") + +local count: number = #values +local first: string = values[1] + +return count, first diff --git a/testdata/fixtures/modules/imported-eq-typeof-table-len/manifest.json b/testdata/fixtures/modules/imported-eq-typeof-table-len/manifest.json new file mode 100644 index 00000000..2d01b6ff --- /dev/null +++ b/testdata/fixtures/modules/imported-eq-typeof-table-len/manifest.json @@ -0,0 +1,7 @@ +{ + "description": "Imported eq summary narrows type(x) == 'table' for length and indexing", + "files": [ + "test.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/imported-eq-typeof-table-len/test.lua b/testdata/fixtures/modules/imported-eq-typeof-table-len/test.lua new file mode 100644 index 00000000..7f48f666 --- /dev/null +++ b/testdata/fixtures/modules/imported-eq-typeof-table-len/test.lua @@ -0,0 +1,9 @@ +local M = {} + +function M.eq(actual, expected, msg) + if actual ~= expected then + error(msg or "assertion failed", 2) + end +end + +return M diff --git a/testdata/fixtures/modules/imported-field-cast-expected-record/main.lua b/testdata/fixtures/modules/imported-field-cast-expected-record/main.lua new file mode 100644 index 00000000..4e5d25c1 --- /dev/null +++ b/testdata/fixtures/modules/imported-field-cast-expected-record/main.lua @@ -0,0 +1,16 @@ +local page_registry = require("page_registry") + +type PageResponse = { + id: string, + configOverrides: {[string]: any}?, +} + +local all_pages = page_registry.find_all() +local page = all_pages[1] + +local page_info: PageResponse = { + id = page.id, + configOverrides = page.config_overrides :: {[string]: any}?, +} + +local _ = page_info diff --git a/testdata/fixtures/modules/imported-field-cast-expected-record/manifest.json b/testdata/fixtures/modules/imported-field-cast-expected-record/manifest.json new file mode 100644 index 00000000..84b41d8a --- /dev/null +++ b/testdata/fixtures/modules/imported-field-cast-expected-record/manifest.json @@ -0,0 +1,10 @@ +{ + "description": "Explicit cast on an imported field should satisfy the expected record field type", + "files": [ + "page_registry.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-field-cast-expected-record/page_registry.lua b/testdata/fixtures/modules/imported-field-cast-expected-record/page_registry.lua new file mode 100644 index 00000000..201ab401 --- /dev/null +++ b/testdata/fixtures/modules/imported-field-cast-expected-record/page_registry.lua @@ -0,0 +1,19 @@ +type PageInfo = { + id: string, + config_overrides: any?, +} + +local M = {} + +function M.find_all() + return { + { + id = "p1", + config_overrides = { + theme = "dark", + }, + }, + } +end + +return M diff --git a/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/builder.lua b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/builder.lua new file mode 100644 index 00000000..a98a8e9c --- /dev/null +++ b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/builder.lua @@ -0,0 +1,43 @@ +local protocol = require("protocol") + +type ToolResultResult = protocol.ToolResultResult + +local M = {} + +function M.build(): protocol.ToolHandler + return function(state: protocol.SessionState, msg: protocol.ToolCallMessage): ToolResultResult + local value = msg.arguments["value"] + if type(value) ~= "string" then + return { + ok = false, + error = { + code = "invalid", + message = "value must be string", + retryable = false, + }, + } + end + + if state.flags["flagged"] then + return { + ok = true, + value = { + tool = msg.tool, + content = "flagged:" .. value, + cached = false, + }, + } + end + + return { + ok = true, + value = { + tool = msg.tool, + content = value, + cached = false, + }, + } + end +end + +return M diff --git a/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/main.lua b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/main.lua new file mode 100644 index 00000000..e92f1304 --- /dev/null +++ b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/main.lua @@ -0,0 +1,26 @@ +local protocol = require("protocol") +local builder = require("builder") + +type Message = {kind: "user", content: string} | protocol.ToolCallMessage + +local function run(handlers: {[string]: protocol.ToolHandler}, state: protocol.SessionState, msg: Message) + if msg.kind ~= "tool_call" then + return nil + end + + local handler = handlers[msg.tool] + if not handler then + return nil + end + + local out = handler(state, msg) + if out.ok then + local tool: string = out.value.tool + local content: string = out.value.content + else + local code: string = out.error.code + local retryable: boolean = out.error.retryable + end +end + +run({search = builder.build()}, {flags = {}}, {kind = "tool_call", tool = "search", arguments = {value = "x"}}) diff --git a/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/manifest.json b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/manifest.json new file mode 100644 index 00000000..e49ae4c3 --- /dev/null +++ b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/manifest.json @@ -0,0 +1,12 @@ +{ + "name": "imported-handler-map-lookup-after-discriminant", + "packages": [], + "files": [ + "protocol.lua", + "builder.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/protocol.lua b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/protocol.lua new file mode 100644 index 00000000..b5092964 --- /dev/null +++ b/testdata/fixtures/modules/imported-handler-map-lookup-after-discriminant/protocol.lua @@ -0,0 +1,34 @@ +type AppError = { + code: string, + message: string, + retryable: boolean, +} + +type ToolCallMessage = { + kind: "tool_call", + tool: string, + arguments: {[string]: any}, +} + +type ToolResult = { + tool: string, + content: string, + cached: boolean, +} + +type SessionState = { + flags: {[string]: boolean}, +} + +type ToolResultResult = {ok: true, value: ToolResult} | {ok: false, error: AppError} +type ToolHandler = (SessionState, ToolCallMessage) -> ToolResultResult + +local M = {} +M.AppError = AppError +M.ToolCallMessage = ToolCallMessage +M.ToolResult = ToolResult +M.SessionState = SessionState +M.ToolResultResult = ToolResultResult +M.ToolHandler = ToolHandler + +return M diff --git a/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/client.lua b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/client.lua new file mode 100644 index 00000000..fa9eb4cd --- /dev/null +++ b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/client.lua @@ -0,0 +1,21 @@ +local M = {} + +type Response = { + metadata: { + response_id: string, + }, +} + +function M.request(ok: boolean): (Response?, string?) + if ok then + return { + metadata = { + response_id = "resp-123", + }, + }, nil + end + + return nil, "failed" +end + +return M diff --git a/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/main.lua b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/main.lua new file mode 100644 index 00000000..9637d1b1 --- /dev/null +++ b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/main.lua @@ -0,0 +1,10 @@ +local test = require("test") +local client = require("client") + +local response, err = client.request(true) + +test.is_nil(err, "no error expected") + +local id: string = response.metadata.response_id + +return id diff --git a/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/manifest.json b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/manifest.json new file mode 100644 index 00000000..deade8f2 --- /dev/null +++ b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/manifest.json @@ -0,0 +1,11 @@ +{ + "description": "Imported inferred is_nil export must narrow sibling multi-return slots across module boundaries", + "files": [ + "test.lua", + "client.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/test.lua b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/test.lua new file mode 100644 index 00000000..84ee34da --- /dev/null +++ b/testdata/fixtures/modules/imported-inferred-is-nil-sibling-correlation/test.lua @@ -0,0 +1,9 @@ +local M = {} + +function M.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "expected nil", 2) + end +end + +return M diff --git a/testdata/fixtures/modules/imported-map-alias-or-local-return/context.lua b/testdata/fixtures/modules/imported-map-alias-or-local-return/context.lua new file mode 100644 index 00000000..ebbc7930 --- /dev/null +++ b/testdata/fixtures/modules/imported-map-alias-or-local-return/context.lua @@ -0,0 +1,10 @@ +type Context = {[string]: any} + +local M = {} +M.Context = Context + +function M.empty(): Context + return {} +end + +return M diff --git a/testdata/fixtures/modules/imported-map-alias-or-local-return/main.lua b/testdata/fixtures/modules/imported-map-alias-or-local-return/main.lua new file mode 100644 index 00000000..bbd2d3de --- /dev/null +++ b/testdata/fixtures/modules/imported-map-alias-or-local-return/main.lua @@ -0,0 +1,6 @@ +local context = require("context") + +function with_default(initial: context.Context?): context.Context + local ctx = initial or context.empty() + return ctx +end diff --git a/testdata/fixtures/modules/imported-map-alias-or-local-return/manifest.json b/testdata/fixtures/modules/imported-map-alias-or-local-return/manifest.json new file mode 100644 index 00000000..65fe8f73 --- /dev/null +++ b/testdata/fixtures/modules/imported-map-alias-or-local-return/manifest.json @@ -0,0 +1,7 @@ +{ + "description": "Imported map alias survives or-fallback when stored in an untyped local before return", + "files": [ + "context.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/imported-map-of-record-store/main.lua b/testdata/fixtures/modules/imported-map-of-record-store/main.lua new file mode 100644 index 00000000..9d8622e3 --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-record-store/main.lua @@ -0,0 +1,15 @@ +local protocol = require("protocol") +local store_mod = require("store") + +local store: store_mod.Store = store_mod.new() +store:put("s1", { + id = "s1", + last_value = nil, + flags = {}, +}) + +local snapshot = store:get("s1") +if snapshot then + local id: string = snapshot.id + local ready = snapshot.flags["ready"] +end diff --git a/testdata/fixtures/modules/imported-map-of-record-store/manifest.json b/testdata/fixtures/modules/imported-map-of-record-store/manifest.json new file mode 100644 index 00000000..43625f7e --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-record-store/manifest.json @@ -0,0 +1,11 @@ +{ + "description": "Imported record aliases used as map values inside exported self-method store objects should preserve value types on put/get.", + "files": [ + "protocol.lua", + "store.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-map-of-record-store/protocol.lua b/testdata/fixtures/modules/imported-map-of-record-store/protocol.lua new file mode 100644 index 00000000..4fe07be7 --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-record-store/protocol.lua @@ -0,0 +1,10 @@ +type Snapshot = { + id: string, + last_value: string?, + flags: {[string]: boolean}, +} + +local M = {} +M.Snapshot = Snapshot + +return M diff --git a/testdata/fixtures/modules/imported-map-of-record-store/store.lua b/testdata/fixtures/modules/imported-map-of-record-store/store.lua new file mode 100644 index 00000000..7be662ed --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-record-store/store.lua @@ -0,0 +1,33 @@ +local protocol = require("protocol") + +type Store = { + sessions: {[string]: protocol.Snapshot}, + put: (self: Store, id: string, snapshot: protocol.Snapshot) -> (), + get: (self: Store, id: string) -> protocol.Snapshot?, +} + +local Store = {} +Store.__index = Store + +local M = {} +M.Store = Store + +function M.new(): Store + local self: Store = { + sessions = {}, + put = Store.put, + get = Store.get, + } + setmetatable(self, Store) + return self +end + +function Store:put(id: string, snapshot: protocol.Snapshot) + self.sessions[id] = snapshot +end + +function Store:get(id: string): protocol.Snapshot? + return self.sessions[id] +end + +return M diff --git a/testdata/fixtures/modules/imported-map-of-time-record-store/main.lua b/testdata/fixtures/modules/imported-map-of-time-record-store/main.lua new file mode 100644 index 00000000..5be0bbce --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-time-record-store/main.lua @@ -0,0 +1,10 @@ +local time = require("time") +local store_mod = require("store") + +local now = time.now() +local store: store_mod.Store = store_mod.new() +local snapshot = store:open("s1", now) +local copy = store:get("s1") +if copy then + local id: string = copy.id +end diff --git a/testdata/fixtures/modules/imported-map-of-time-record-store/manifest.json b/testdata/fixtures/modules/imported-map-of-time-record-store/manifest.json new file mode 100644 index 00000000..bc77c45b --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-time-record-store/manifest.json @@ -0,0 +1,12 @@ +{ + "description": "Imported record aliases with package-backed time fields should stay stable as map values inside exported store objects.", + "files": [ + "protocol.lua", + "store.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-map-of-time-record-store/protocol.lua b/testdata/fixtures/modules/imported-map-of-time-record-store/protocol.lua new file mode 100644 index 00000000..23c08356 --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-time-record-store/protocol.lua @@ -0,0 +1,14 @@ +local time = require("time") + +type Snapshot = { + id: string, + opened_at: time.Time, + last_seen: time.Time, + last_value: string?, + flags: {[string]: boolean}, +} + +local M = {} +M.Snapshot = Snapshot + +return M diff --git a/testdata/fixtures/modules/imported-map-of-time-record-store/store.lua b/testdata/fixtures/modules/imported-map-of-time-record-store/store.lua new file mode 100644 index 00000000..2561c698 --- /dev/null +++ b/testdata/fixtures/modules/imported-map-of-time-record-store/store.lua @@ -0,0 +1,48 @@ +local time = require("time") +local protocol = require("protocol") + +type Store = { + sessions: {[string]: protocol.Snapshot}, + open: (self: Store, id: string, now: time.Time) -> protocol.Snapshot, + get: (self: Store, id: string) -> protocol.Snapshot?, +} + +local Store = {} +Store.__index = Store + +local M = {} +M.Store = Store + +function M.new(): Store + local self: Store = { + sessions = {}, + open = Store.open, + get = Store.get, + } + setmetatable(self, Store) + return self +end + +function Store:open(id: string, now: time.Time): protocol.Snapshot + local existing = self.sessions[id] + if existing then + existing.last_seen = now + return existing + end + + local created: protocol.Snapshot = { + id = id, + opened_at = now, + last_seen = now, + last_value = nil, + flags = {}, + } + self.sessions[id] = created + return created +end + +function Store:get(id: string): protocol.Snapshot? + return self.sessions[id] +end + +return M diff --git a/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/main.lua b/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/main.lua new file mode 100644 index 00000000..91c1109e --- /dev/null +++ b/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/main.lua @@ -0,0 +1,25 @@ +local test = require("test") + +type Response = { + result: { + data: { + departments: {string}?, + }, + }, +} + +local response: Response = { + result = { + data = { + departments = { "engineering", "product" }, + }, + }, +} + +test.not_nil(response.result.data.departments, "departments required") +test.eq(type(response.result.data.departments), "table", "departments should be a table") + +local count: number = #response.result.data.departments +local first: string = response.result.data.departments[1] + +return count, first diff --git a/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/manifest.json b/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/manifest.json new file mode 100644 index 00000000..5eeb3c08 --- /dev/null +++ b/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/manifest.json @@ -0,0 +1,10 @@ +{ + "description": "Imported not_nil on nested field path plus imported eq(type(x), 'table') preserves len and indexing", + "files": [ + "test.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/test.lua b/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/test.lua new file mode 100644 index 00000000..9dccb319 --- /dev/null +++ b/testdata/fixtures/modules/imported-not-nil-field-typeof-table-len/test.lua @@ -0,0 +1,16 @@ +local M = {} + +function M.eq(actual, expected, msg) + if actual ~= expected then + error(msg or "assertion failed", 2) + end +end + +function M.not_nil(val, msg) + if val == nil then + error(msg or "expected non-nil", 2) + end + return val +end + +return M diff --git a/testdata/fixtures/modules/imported-optional-method-zero-arg-read/http_client.lua b/testdata/fixtures/modules/imported-optional-method-zero-arg-read/http_client.lua new file mode 100644 index 00000000..df25c394 --- /dev/null +++ b/testdata/fixtures/modules/imported-optional-method-zero-arg-read/http_client.lua @@ -0,0 +1,26 @@ +type Stream = { + read: (self: Stream, n: number?) -> (string?, string?), +} + +type Response = { + status_code: number, + body: string?, + stream: Stream?, +} + +local http_client = {} + +function http_client.get(url: string): (Response?, string?) + local stream: Stream = { + read = function(self: Stream, n: number?) + return "chunk", nil + end, + } + + return { + status_code = 500, + stream = stream, + }, nil +end + +return http_client diff --git a/testdata/fixtures/modules/imported-optional-method-zero-arg-read/main.lua b/testdata/fixtures/modules/imported-optional-method-zero-arg-read/main.lua new file mode 100644 index 00000000..8d2c52ec --- /dev/null +++ b/testdata/fixtures/modules/imported-optional-method-zero-arg-read/main.lua @@ -0,0 +1,13 @@ +local http_client = require("http_client") + +local response, err = http_client.get("https://example.test") +if err or not response then + return nil, err +end + +if response.status_code >= 300 and response.stream and not response.body then + local body_data = response.stream:read() + response.body = body_data +end + +return response diff --git a/testdata/fixtures/modules/imported-optional-method-zero-arg-read/manifest.json b/testdata/fixtures/modules/imported-optional-method-zero-arg-read/manifest.json new file mode 100644 index 00000000..13558f8c --- /dev/null +++ b/testdata/fixtures/modules/imported-optional-method-zero-arg-read/manifest.json @@ -0,0 +1,7 @@ +{ + "description": "Imported module preserves optional method parameters across export/import for zero-arg method calls", + "files": [ + "http_client.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/imported-record-empty-map-field/main.lua b/testdata/fixtures/modules/imported-record-empty-map-field/main.lua new file mode 100644 index 00000000..0ac4ab5e --- /dev/null +++ b/testdata/fixtures/modules/imported-record-empty-map-field/main.lua @@ -0,0 +1,5 @@ +local protocol = require("protocol") +local store = require("store") + +local snapshot: protocol.Snapshot = store.make("s1") +snapshot.flags["ready"] = true diff --git a/testdata/fixtures/modules/imported-record-empty-map-field/manifest.json b/testdata/fixtures/modules/imported-record-empty-map-field/manifest.json new file mode 100644 index 00000000..2b8b7618 --- /dev/null +++ b/testdata/fixtures/modules/imported-record-empty-map-field/manifest.json @@ -0,0 +1,11 @@ +{ + "description": "Imported record aliases should contextually type empty map fields inside returned table literals.", + "files": [ + "protocol.lua", + "store.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-record-empty-map-field/protocol.lua b/testdata/fixtures/modules/imported-record-empty-map-field/protocol.lua new file mode 100644 index 00000000..b4ce99f2 --- /dev/null +++ b/testdata/fixtures/modules/imported-record-empty-map-field/protocol.lua @@ -0,0 +1,9 @@ +type Snapshot = { + id: string, + flags: {[string]: boolean}, +} + +local M = {} +M.Snapshot = Snapshot + +return M diff --git a/testdata/fixtures/modules/imported-record-empty-map-field/store.lua b/testdata/fixtures/modules/imported-record-empty-map-field/store.lua new file mode 100644 index 00000000..a0a55729 --- /dev/null +++ b/testdata/fixtures/modules/imported-record-empty-map-field/store.lua @@ -0,0 +1,12 @@ +local protocol = require("protocol") + +local M = {} + +function M.make(id: string): protocol.Snapshot + return { + id = id, + flags = {}, + } +end + +return M diff --git a/testdata/fixtures/modules/imported-record-return-literal/main.lua b/testdata/fixtures/modules/imported-record-return-literal/main.lua new file mode 100644 index 00000000..ee43acbc --- /dev/null +++ b/testdata/fixtures/modules/imported-record-return-literal/main.lua @@ -0,0 +1,15 @@ +local time = require("time") +local protocol = require("protocol") +local store = require("store") + +local now = time.now() +local snapshot: protocol.Snapshot = store.make("s1", now) + +local label: string = snapshot.id +local opened = snapshot.opened_at +local elapsed = now:sub(opened) +local seconds: number = elapsed:seconds() + +if snapshot.last_value then + local value: string = snapshot.last_value +end diff --git a/testdata/fixtures/modules/imported-record-return-literal/manifest.json b/testdata/fixtures/modules/imported-record-return-literal/manifest.json new file mode 100644 index 00000000..57065057 --- /dev/null +++ b/testdata/fixtures/modules/imported-record-return-literal/manifest.json @@ -0,0 +1,12 @@ +{ + "description": "Cross-module imported record alias should work as the expected type for returned table literals and constructor results.", + "files": [ + "protocol.lua", + "store.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-record-return-literal/protocol.lua b/testdata/fixtures/modules/imported-record-return-literal/protocol.lua new file mode 100644 index 00000000..71505321 --- /dev/null +++ b/testdata/fixtures/modules/imported-record-return-literal/protocol.lua @@ -0,0 +1,13 @@ +local time = require("time") + +type Snapshot = { + id: string, + opened_at: time.Time, + last_seen: time.Time, + last_value: string?, +} + +local M = {} +M.Snapshot = Snapshot + +return M diff --git a/testdata/fixtures/modules/imported-record-return-literal/store.lua b/testdata/fixtures/modules/imported-record-return-literal/store.lua new file mode 100644 index 00000000..0260361d --- /dev/null +++ b/testdata/fixtures/modules/imported-record-return-literal/store.lua @@ -0,0 +1,15 @@ +local time = require("time") +local protocol = require("protocol") + +local M = {} + +function M.make(id: string, now: time.Time): protocol.Snapshot + return { + id = id, + opened_at = now, + last_seen = now, + last_value = nil, + } +end + +return M diff --git a/testdata/fixtures/modules/imported-self-method-store/main.lua b/testdata/fixtures/modules/imported-self-method-store/main.lua new file mode 100644 index 00000000..4cffe0ab --- /dev/null +++ b/testdata/fixtures/modules/imported-self-method-store/main.lua @@ -0,0 +1,9 @@ +local store_mod = require("store") + +local store: store_mod.Store = store_mod.new() +store:put("name", "lua") + +local maybe_name = store:get("name") +if maybe_name then + local value: string = maybe_name +end diff --git a/testdata/fixtures/modules/imported-self-method-store/manifest.json b/testdata/fixtures/modules/imported-self-method-store/manifest.json new file mode 100644 index 00000000..a0759bcb --- /dev/null +++ b/testdata/fixtures/modules/imported-self-method-store/manifest.json @@ -0,0 +1,10 @@ +{ + "description": "Cross-module exported object types with self-referential methods should preserve identity for constructor results and method calls.", + "files": [ + "store.lua", + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/modules/imported-self-method-store/store.lua b/testdata/fixtures/modules/imported-self-method-store/store.lua new file mode 100644 index 00000000..977e1584 --- /dev/null +++ b/testdata/fixtures/modules/imported-self-method-store/store.lua @@ -0,0 +1,32 @@ +type Store = { + cache: {[string]: string}, + get: (self: Store, key: string) -> string?, + put: (self: Store, key: string, value: string) -> Store, +} + +local Store = {} +Store.__index = Store + +local M = {} +M.Store = Store + +function M.new(): Store + local self: Store = { + cache = {}, + get = Store.get, + put = Store.put, + } + setmetatable(self, Store) + return self +end + +function Store:get(key: string): string? + return self.cache[key] +end + +function Store:put(key: string, value: string): Store + self.cache[key] = value + return self +end + +return M diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua new file mode 100644 index 00000000..4dbbeb2e --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua @@ -0,0 +1,7 @@ +local contract = {} + +function contract.get(_id) + return nil, "not configured" +end + +return contract diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/main.lua b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/main.lua new file mode 100644 index 00000000..e94c1570 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/main.lua @@ -0,0 +1,57 @@ +local test = require("test") +local providers = require("providers") + +local captured_options = nil + +local provider_with_retry = { + id = "wippy.llm.provider:openai", + kind = "registry.entry", + meta = { type = "llm.provider", name = "openai", title = "OpenAI" }, + data = { + driver = { + id = "wippy.llm.binding:openai_driver", + options = { + api_key_env = "OPENAI_API_KEY", + retry = { max_attempts = 5, initial_delay = 200 }, + }, + }, + }, +} + +providers._registry = { + get = function(id) + if id == "wippy.llm.provider:openai" then + return provider_with_retry, nil + end + return nil, "not found" + end, +} + +providers._contract = { + get = function(_contract_id) + return { + with_context = function(self, _context) + return self + end, + with_options = function(self, opts) + captured_options = opts + return self + end, + open = function(self, binding_id) + return { _binding_id = binding_id }, nil + end, + }, nil + end, +} + +local instance, err = providers.open("wippy.llm.provider:openai") + +test.is_nil(err, "open should succeed") +assert(instance) +test.not_nil(captured_options, "captured options expected") +test.not_nil(captured_options.retry, "retry expected") + +local attempts: number = captured_options.retry.max_attempts +local delay: number = captured_options.retry.initial_delay + +return attempts, delay diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/manifest.json b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/manifest.json new file mode 100644 index 00000000..2d1c96e1 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/manifest.json @@ -0,0 +1,10 @@ +{ + "description": "Actual providers.open retry path preserves captured retry options under raw go-lua checking", + "files": [ + "test.lua", + "registry.lua", + "contract.lua", + "providers.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/providers.lua b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/providers.lua new file mode 100644 index 00000000..fbbe0c6c --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/providers.lua @@ -0,0 +1,70 @@ +local registry = require("registry") +local contract = require("contract") + +local providers = { + _registry = registry, + _contract = contract, +} + +local CONTRACT_ID = "wippy.llm:provider" + +function providers.open(provider_id, context_overrides) + if not provider_id then + return nil, "Provider ID is required" + end + + context_overrides = context_overrides or {} + + local provider_entry, err = providers._registry.get(provider_id) + if err then + return nil, "Registry error: " .. tostring(err) + end + + if not provider_entry then + return nil, "Provider not found: " .. provider_id + end + + if not provider_entry.meta or provider_entry.meta.type ~= "llm.provider" then + return nil, "Entry is not a provider: " .. provider_id + end + + if not provider_entry.data or not provider_entry.data.driver or not provider_entry.data.driver.id then + return nil, "Provider missing driver configuration: " .. provider_id + end + + local binding_id = provider_entry.data.driver.id + local base_options = provider_entry.data.driver.options or {} + + local final_context = {} + for k, v in pairs(base_options) do + final_context[k] = v + end + for k, v in pairs(context_overrides) do + final_context[k] = v + end + + local call_options = {} + if final_context.retry then + call_options.retry = final_context.retry + final_context.retry = nil + end + + local provider_contract, err = providers._contract.get(CONTRACT_ID) + if err then + return nil, "Failed to get provider contract: " .. tostring(err) + end + + local chain = provider_contract:with_context(final_context) + if next(call_options) then + chain = chain:with_options(call_options) + end + + local instance, open_err = chain:open(tostring(binding_id)) + if open_err then + return nil, "Failed to open provider binding: " .. tostring(open_err) + end + + return instance +end + +return providers diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/registry.lua b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/registry.lua new file mode 100644 index 00000000..37d2ba53 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/registry.lua @@ -0,0 +1,11 @@ +local registry = {} + +function registry.get(_id) + return nil, "not configured" +end + +function registry.find(_query) + return {}, nil +end + +return registry diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/test.lua b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/test.lua new file mode 100644 index 00000000..d40e9e98 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/test.lua @@ -0,0 +1,16 @@ +local M = {} + +function M.is_nil(val, msg) + if val ~= nil then + error(msg or "expected nil", 2) + end +end + +function M.not_nil(val, msg) + if val == nil then + error(msg or "expected non-nil", 2) + end + return val +end + +return M diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua b/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua new file mode 100644 index 00000000..4dbbeb2e --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua @@ -0,0 +1,7 @@ +local contract = {} + +function contract.get(_id) + return nil, "not configured" +end + +return contract diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options/main.lua b/testdata/fixtures/modules/providers-open-retry-captured-options/main.lua new file mode 100644 index 00000000..82dc27cb --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options/main.lua @@ -0,0 +1,35 @@ +local test = require("test") +local providers = require("providers") + +local captured_options = nil + +providers._contract = { + get = function(_contract_id) + return { + with_context = function(self, _context) + return self + end, + with_options = function(self, opts) + captured_options = opts + return self + end, + open = function(self, binding_id) + return { _binding_id = binding_id }, nil + end, + }, nil + end, +} + +local instance, err = providers.open("wippy.llm.provider:openai", { + retry = { max_attempts = 3, initial_delay = 100 }, +}) + +test.is_nil(err, "open should succeed") +assert(instance) +test.not_nil(captured_options, "captured options expected") +test.not_nil(captured_options.retry, "retry expected") + +local attempts: number = captured_options.retry.max_attempts +local delay: number = captured_options.retry.initial_delay + +return attempts, delay diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options/manifest.json b/testdata/fixtures/modules/providers-open-retry-captured-options/manifest.json new file mode 100644 index 00000000..f307a5d1 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Imported builder chain preserves captured retry options through not_nil guards", + "files": [ + "test.lua", + "contract.lua", + "providers.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options/providers.lua b/testdata/fixtures/modules/providers-open-retry-captured-options/providers.lua new file mode 100644 index 00000000..040b6f74 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options/providers.lua @@ -0,0 +1,23 @@ +local contract = require("contract") + +local providers = { + _contract = contract, +} + +function providers.open(provider_id, context_overrides) + context_overrides = context_overrides or {} + + local provider_contract, err = providers._contract.get("provider") + if err then + return nil, err + end + + local chain = provider_contract:with_context({}) + chain = chain:with_options({ + retry = context_overrides.retry, + }) + + return chain:open(provider_id) +end + +return providers diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options/test.lua b/testdata/fixtures/modules/providers-open-retry-captured-options/test.lua new file mode 100644 index 00000000..d40e9e98 --- /dev/null +++ b/testdata/fixtures/modules/providers-open-retry-captured-options/test.lua @@ -0,0 +1,16 @@ +local M = {} + +function M.is_nil(val, msg) + if val ~= nil then + error(msg or "expected nil", 2) + end +end + +function M.not_nil(val, msg) + if val == nil then + error(msg or "expected non-nil", 2) + end + return val +end + +return M diff --git a/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua new file mode 100644 index 00000000..3aeeb9e2 --- /dev/null +++ b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua @@ -0,0 +1,22 @@ +local page_registry = require("page_registry") + +local function takes_string(name: string) + return name +end + +local function get_page_data(page) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local name: string = page.data_func -- expect-error: cannot assign string | true to string + takes_string(page.data_func) -- expect-error: argument 1: expected string, got string | true + return {}, nil +end + +local page = page_registry.build_page({ + id = "demo", + data = { data_func = "load_data" }, +}) + +return get_page_data(page) -- expect-error: expected {data_func?: boolean | string diff --git a/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/manifest.json b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/manifest.json new file mode 100644 index 00000000..3f278de3 --- /dev/null +++ b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/manifest.json @@ -0,0 +1,7 @@ +{ + "description": "Dynamic registry data flowing through page_registry into renderer-style optional string guard", + "files": [ + "page_registry.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/page_registry.lua b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/page_registry.lua new file mode 100644 index 00000000..0bf1eace --- /dev/null +++ b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/page_registry.lua @@ -0,0 +1,23 @@ +type Entry = { + id: string, + data: {[string]: any}, +} + +local pages = {} + +local function qualify_id(entry_id, relative_id) + return entry_id .. ":" .. relative_id +end + +function pages.build_page(entry: Entry) + local data_func = entry.data.data_func + if data_func and data_func ~= "" then + data_func = qualify_id(entry.id, data_func) + end + + local page = {} + page.data_func = data_func + return page +end + +return pages diff --git a/testdata/fixtures/narrowing/incremental-record-optional-string-assignment/main.lua b/testdata/fixtures/narrowing/incremental-record-optional-string-assignment/main.lua new file mode 100644 index 00000000..dcb3cfad --- /dev/null +++ b/testdata/fixtures/narrowing/incremental-record-optional-string-assignment/main.lua @@ -0,0 +1,33 @@ +type PageInfo = { + id: string, + name: string, + secure: boolean, +} + +type PageDetail = PageInfo & { + data_func: string?, +} + +local function qualify_id(ns: string, relative_id: string): string + return ns .. relative_id +end + +local function build_page(raw: string?): PageDetail + local data_func = raw + if data_func and data_func ~= "" then + data_func = qualify_id("demo:", data_func) + end + + local page = { + id = "p", + name = "n", + secure = false, + } + page.data_func = data_func + + local typed: PageDetail = page + local maybe_name: string? = page.data_func + return typed +end + +return build_page diff --git a/testdata/fixtures/narrowing/nested-field-type-guard-arithmetic/main.lua b/testdata/fixtures/narrowing/nested-field-type-guard-arithmetic/main.lua new file mode 100644 index 00000000..cba22aa7 --- /dev/null +++ b/testdata/fixtures/narrowing/nested-field-type-guard-arithmetic/main.lua @@ -0,0 +1,16 @@ +type PayloadCarrier = { + data: fun(self: PayloadCarrier): any, +} + +local function bump(carrier: PayloadCarrier?) + local data = carrier and carrier:data() or nil + if type(data) ~= "table" or type(data.amount) ~= "number" then + return nil + end + + local next_amount = data.amount + 1 + local exact: number = data.amount + return next_amount, exact +end + +return bump diff --git a/testdata/fixtures/narrowing/optional-field-early-return-guard/main.lua b/testdata/fixtures/narrowing/optional-field-early-return-guard/main.lua new file mode 100644 index 00000000..aa9fb2e0 --- /dev/null +++ b/testdata/fixtures/narrowing/optional-field-early-return-guard/main.lua @@ -0,0 +1,19 @@ +type Page = { + data_func: string?, +} + +local function takes_string(name: string) + return name +end + +local function get_page_data(page: Page?) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local name: string = page.data_func + takes_string(page.data_func) + return {}, nil +end + +return get_page_data diff --git a/testdata/fixtures/narrowing/optional-field-early-return-inferred/main.lua b/testdata/fixtures/narrowing/optional-field-early-return-inferred/main.lua new file mode 100644 index 00000000..7c03a613 --- /dev/null +++ b/testdata/fixtures/narrowing/optional-field-early-return-inferred/main.lua @@ -0,0 +1,28 @@ +type Page = { + data_func: string?, +} + +local function load_page(): (Page?, string?) + return { data_func = "demo" }, nil +end + +local function takes_string(name: string) + return name +end + +local function get_page_data(page) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local name: string = page.data_func + takes_string(page.data_func) + return {}, nil +end + +local page, err = load_page() +if err then + return nil, err +end + +return get_page_data(page) diff --git a/testdata/fixtures/narrowing/optional-field-early-return-intersection/main.lua b/testdata/fixtures/narrowing/optional-field-early-return-intersection/main.lua new file mode 100644 index 00000000..7f5b9c28 --- /dev/null +++ b/testdata/fixtures/narrowing/optional-field-early-return-intersection/main.lua @@ -0,0 +1,26 @@ +type PageInfo = { + id: string, + name: string, + secure: boolean, +} + +type PageDetail = PageInfo & { + data_func: string?, + template_set: string?, +} + +local function takes_string(name: string) + return name +end + +local function get_page_data(page: PageDetail?) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local name: string = page.data_func + takes_string(page.data_func) + return {}, nil +end + +return get_page_data diff --git a/testdata/fixtures/narrowing/optional-string-nonempty-assignment/main.lua b/testdata/fixtures/narrowing/optional-string-nonempty-assignment/main.lua new file mode 100644 index 00000000..6c6413f8 --- /dev/null +++ b/testdata/fixtures/narrowing/optional-string-nonempty-assignment/main.lua @@ -0,0 +1,18 @@ +local function qualify_id(ns: string, relative_id: string): string + return ns .. relative_id +end + +local function build_page(raw: string?) + local data_func = raw + if data_func and data_func ~= "" then + data_func = qualify_id("demo:", data_func) + end + + local maybe_name: string? = data_func + local page: {data_func: string?} = { + data_func = data_func, + } + return page, maybe_name +end + +return build_page diff --git a/testdata/fixtures/narrowing/page-registry-renderer-guard/main.lua b/testdata/fixtures/narrowing/page-registry-renderer-guard/main.lua new file mode 100644 index 00000000..8a8622fa --- /dev/null +++ b/testdata/fixtures/narrowing/page-registry-renderer-guard/main.lua @@ -0,0 +1,22 @@ +local page_registry = require("page_registry") + +local function takes_string(name: string) + return name +end + +local function get_page_data(page) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local name: string = page.data_func + takes_string(page.data_func) + return {}, nil +end + +local page, err = page_registry.get("demo:home") +if err then + return nil, err +end + +return get_page_data(page) diff --git a/testdata/fixtures/narrowing/page-registry-renderer-guard/manifest.json b/testdata/fixtures/narrowing/page-registry-renderer-guard/manifest.json new file mode 100644 index 00000000..eb9faef1 --- /dev/null +++ b/testdata/fixtures/narrowing/page-registry-renderer-guard/manifest.json @@ -0,0 +1,7 @@ +{ + "description": "Multi-file module boundary regression for page_registry -> renderer style optional string guards", + "files": [ + "page_registry.lua", + "main.lua" + ] +} diff --git a/testdata/fixtures/narrowing/page-registry-renderer-guard/page_registry.lua b/testdata/fixtures/narrowing/page-registry-renderer-guard/page_registry.lua new file mode 100644 index 00000000..0305b89a --- /dev/null +++ b/testdata/fixtures/narrowing/page-registry-renderer-guard/page_registry.lua @@ -0,0 +1,39 @@ +type PageInfo = { + id: string, + name: string, + secure: boolean, +} + +type PageDetail = PageInfo & { + data_func: string?, + template_set: string?, +} + +local pages = {} + +local function qualify_id(ns: string, relative_id: string): string + return ns .. ":" .. relative_id +end + +function pages.get(page_id: string): (PageDetail?, string?) + if not page_id then + return nil, "Page ID is required" + end + + local data_func: string? = "load_data" + if data_func and data_func ~= "" then + data_func = qualify_id("demo", data_func) + end + + local page = { + id = page_id, + name = "Demo Page", + secure = false, + template_set = "demo:set", + } + page.data_func = data_func + + return page, nil +end + +return pages diff --git a/testdata/fixtures/narrowing/temporal-loop-nested-field-type-guard/main.lua b/testdata/fixtures/narrowing/temporal-loop-nested-field-type-guard/main.lua new file mode 100644 index 00000000..132a2e60 --- /dev/null +++ b/testdata/fixtures/narrowing/temporal-loop-nested-field-type-guard/main.lua @@ -0,0 +1,44 @@ +type Message = { + from: fun(self: Message): string, + payload: fun(self: Message): any, +} + +type Channel = { + receive: fun(self: Channel): (Message, boolean), +} + +local process = {} + +function process.listen(topic: string, options: any?): Channel + error("stub") +end + +function process.send(pid: string, topic: string, ...: any): (boolean, string?) + return true, nil +end + +local counter = 0 +local done = false + +coroutine.spawn(function() + local ch = process.listen("increment", {message = true}) + while not done do + local msg, ok = ch:receive() + if not ok then + break + end + + local p = msg:payload() + local data = p and p:data() or nil + local reply_to = msg:from() + + if type(data) ~= "table" or type(data.amount) ~= "number" then + process.send(reply_to, "nak", "amount must be a number") + else + process.send(reply_to, "ack") + local amount_sanity = data.amount + 1 + counter = counter + data.amount + counter = amount_sanity - 1 + end + end +end) diff --git a/testdata/fixtures/narrowing/union-page-variant-guard/main.lua b/testdata/fixtures/narrowing/union-page-variant-guard/main.lua new file mode 100644 index 00000000..b4e7648d --- /dev/null +++ b/testdata/fixtures/narrowing/union-page-variant-guard/main.lua @@ -0,0 +1,30 @@ +type TemplatePage = { + kind: "template", + id: string, + data_func: string?, + template_set: string, +} + +type ComponentPage = { + kind: "component", + id: string, + url: string, +} + +type Page = TemplatePage | ComponentPage + +local function takes_string(name: string) + return name +end + +local function get_page_data(page: Page?) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local name: string = page.data_func + takes_string(page.data_func) + return {}, nil +end + +return get_page_data diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/engine.lua b/testdata/fixtures/realworld/agent-workflow-engine-soundness/engine.lua new file mode 100644 index 00000000..dc3e7f9f --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/engine.lua @@ -0,0 +1,35 @@ +local time = require("time") +local protocol = require("protocol") +local session_store = require("session_store") + +type Engine = { + handlers: {[string]: protocol.ToolHandler}, + register_tool: (self: Engine, name: string, handler: protocol.ToolHandler) -> Engine, + new_session: (self: Engine, id: string, now: time.Time) -> session_store.SessionStore, +} + +local Engine = {} +Engine.__index = Engine + +local M = {} + +function M.new(): Engine + local self: Engine = { + handlers = {}, + register_tool = Engine.register_tool, + new_session = Engine.new_session, + } + setmetatable(self, Engine) + return self +end + +function Engine:register_tool(name: string, handler: protocol.ToolHandler): Engine + self.handlers[name] = handler + return self +end + +function Engine:new_session(id: string, now: time.Time): session_store.SessionStore + return session_store.new(id, now) +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/main.lua b/testdata/fixtures/realworld/agent-workflow-engine-soundness/main.lua new file mode 100644 index 00000000..cbd7c5a5 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/main.lua @@ -0,0 +1,28 @@ +local time = require("time") +local engine = require("engine") +local tools = require("tools") +local protocol = require("protocol") + +local now = time.now() +local app = engine.new():register_tool("search", tools.search) +local store = app:new_session("unsafe", now) + +local msg: protocol.ToolCallMessage = { + kind = "tool_call", + id = "m1", + tool = "search", + arguments = {topic = "lua"}, + meta = protocol.meta("req-1", "trace-1", nil), +} + +local handler = app.handlers["search"] +local produced = handler(store.state, msg) -- expect-error + +local cached = store:lookup_tool("search") +local bad_content: string = cached.content -- expect-error + +local elapsed = now:sub(store.state.last_activity) -- expect-error +local seconds: number = elapsed:seconds() + +local tags = msg.meta.tags +local bad_source: string = tags["source"] -- expect-error diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/manifest.json b/testdata/fixtures/realworld/agent-workflow-engine-soundness/manifest.json new file mode 100644 index 00000000..2c7ebefb --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/manifest.json @@ -0,0 +1,12 @@ +{ + "description": "Soundness companion for the agent workflow engine: realistic unsafe uses of optional map entries and optional timestamps must be rejected.", + "files": [ + "result.lua", + "protocol.lua", + "session_store.lua", + "tools.lua", + "engine.lua", + "main.lua" + ], + "packages": ["time"] +} diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/protocol.lua b/testdata/fixtures/realworld/agent-workflow-engine-soundness/protocol.lua new file mode 100644 index 00000000..58e5ace3 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/protocol.lua @@ -0,0 +1,53 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type RequestMeta = { + request_id: string, + trace_id: string, + tags: {[string]: string}?, +} + +type ToolCallMessage = { + kind: "tool_call", + id: string, + tool: string, + arguments: {[string]: any}, + meta: RequestMeta, +} + +type ToolResult = { + tool: string, + content: string, + cached: boolean, +} + +type SessionState = { + id: string, + started_at: time.Time, + last_activity: time.Time?, + tool_cache: {[string]: ToolResult}, +} + +type ToolResultResult = {ok: true, value: ToolResult} | {ok: false, error: AppError} +type ToolHandler = (SessionState, ToolCallMessage) -> ToolResultResult + +local M = {} +M.AppError = AppError +M.RequestMeta = RequestMeta +M.ToolCallMessage = ToolCallMessage +M.ToolResult = ToolResult +M.SessionState = SessionState +M.ToolResultResult = ToolResultResult +M.ToolHandler = ToolHandler + +function M.meta(request_id: string, trace_id: string, tags: {[string]: string}?): RequestMeta + return { + request_id = request_id, + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/result.lua b/testdata/fixtures/realworld/agent-workflow-engine-soundness/result.lua new file mode 100644 index 00000000..8e30a297 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/result.lua @@ -0,0 +1,29 @@ +type ErrorCode = "not_found" | "invalid" | "busy" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/session_store.lua b/testdata/fixtures/realworld/agent-workflow-engine-soundness/session_store.lua new file mode 100644 index 00000000..a30052fe --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/session_store.lua @@ -0,0 +1,35 @@ +local time = require("time") +local protocol = require("protocol") + +type SessionStore = { + state: protocol.SessionState, + lookup_tool: (self: SessionStore, name: string) -> protocol.ToolResult?, +} + +type Store = SessionStore + +local Store = {} +Store.__index = Store + +local M = {} +M.SessionStore = SessionStore + +function M.new(id: string, now: time.Time): SessionStore + local self: Store = { + state = { + id = id, + started_at = now, + last_activity = nil, + tool_cache = {}, + }, + lookup_tool = Store.lookup_tool, + } + setmetatable(self, Store) + return self +end + +function Store:lookup_tool(name: string): protocol.ToolResult? + return self.state.tool_cache[name] +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine-soundness/tools.lua b/testdata/fixtures/realworld/agent-workflow-engine-soundness/tools.lua new file mode 100644 index 00000000..a6f68c12 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine-soundness/tools.lua @@ -0,0 +1,16 @@ +local result = require("result") +local protocol = require("protocol") + +type ToolResultResult = protocol.ToolResultResult + +local M = {} + +function M.search(_state: protocol.SessionState, msg: protocol.ToolCallMessage): ToolResultResult + return result.ok({ + tool = msg.tool, + content = "search", + cached = false, + }) +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine/engine.lua b/testdata/fixtures/realworld/agent-workflow-engine/engine.lua new file mode 100644 index 00000000..afe35797 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/engine.lua @@ -0,0 +1,151 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local session_store = require("session_store") + +type AppError = result.AppError +type StepResult = {ok: true, value: string?} | {ok: false, error: AppError} +type SummaryResult = {ok: true, value: protocol.SessionSummary} | {ok: false, error: AppError} + +type Engine = { + handlers: {[string]: protocol.ToolHandler}, + listeners: {protocol.StepListener}, + register_tool: (self: Engine, name: string, handler: protocol.ToolHandler) -> Engine, + on_step: (self: Engine, listener: protocol.StepListener) -> Engine, + emit: (self: Engine, store: session_store.SessionStore, step: protocol.WorkflowStep, at: time.Time) -> (), + new_session: (self: Engine, id: string, now: time.Time) -> session_store.SessionStore, + process_message: (self: Engine, store: session_store.SessionStore, msg: protocol.Message, at: time.Time) -> StepResult, + process: (self: Engine, store: session_store.SessionStore, messages: {protocol.Message}, now: time.Time) -> SummaryResult, +} + +local Engine = {} +Engine.__index = Engine + +local M = {} +M.Engine = Engine + +function M.new(): Engine + local self: Engine = { + handlers = {}, + listeners = {}, + register_tool = Engine.register_tool, + on_step = Engine.on_step, + emit = Engine.emit, + new_session = Engine.new_session, + process_message = Engine.process_message, + process = Engine.process, + } + setmetatable(self, Engine) + return self +end + +function Engine:register_tool(name: string, handler: protocol.ToolHandler): Engine + self.handlers[name] = handler + return self +end + +function Engine:on_step(listener: protocol.StepListener): Engine + table.insert(self.listeners, listener) + return self +end + +function Engine:emit(store: session_store.SessionStore, step: protocol.WorkflowStep, at: time.Time) + store:emit_step(step, at) + for _, listener in ipairs(self.listeners) do + listener(step, store.state) + end +end + +function Engine:new_session(id: string, now: time.Time): session_store.SessionStore + return session_store.new(id, now) +end + +function Engine:process_message(store: session_store.SessionStore, msg: protocol.Message, at: time.Time): StepResult + store:append_message(msg, at) + + if msg.kind == "user" then + self:emit(store, {kind = "assistant", content = "ack:" .. msg.content}, at) + return {ok = true, value = nil} + end + + if msg.kind == "tool_call" then + local cached = store:lookup_tool(msg.tool) + if cached then + self:emit(store, { + kind = "tool", + tool = msg.tool, + result = { + tool = cached.tool, + content = cached.content, + cached = true, + }, + }, at) + store:mark_flag("cache_hit") + return {ok = true, value = nil} + end + + local handler = self.handlers[msg.tool] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing tool handler: " .. msg.tool, + retryable = false, + }, + } + end + + local tool_result = handler(store.state, msg) + if tool_result.ok then + store:remember_tool(tool_result.value) + self:emit(store, { + kind = "tool", + tool = msg.tool, + result = tool_result.value, + }, at) + + if msg.tool == "profile" then + store:mark_flag("profile_loaded") + end + return {ok = true, value = nil} + end + + return { + ok = false, + error = { + code = tool_result.error.code, + message = tool_result.error.message, + retryable = tool_result.error.retryable, + }, + } + end + + self:emit(store, {kind = "audit", note = "done:" .. msg.reason, at = at}, at) + return {ok = true, value = msg.reason} +end + +function Engine:process( + store: session_store.SessionStore, + messages: {protocol.Message}, + now: time.Time +): SummaryResult + local last_reason: string? = nil + + for _, msg in ipairs(messages) do + local step_result = self:process_message(store, msg, now) + if not step_result.ok then + return step_result + end + if step_result.value ~= nil then + last_reason = step_result.value + end + end + + return { + ok = true, + value = store:summarize(now, last_reason), + } +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine/main.lua b/testdata/fixtures/realworld/agent-workflow-engine/main.lua new file mode 100644 index 00000000..acb0b5c6 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/main.lua @@ -0,0 +1,178 @@ +local time = require("time") +local result = require("result") +local engine = require("engine") +local tool_builder = require("tool_builder") +local tools = require("tools") +local protocol = require("protocol") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_steps: {protocol.WorkflowStep} = {} +local observed_tool_contents: {[string]: string} = {} +local last_session_id: string? = nil + +local search_handler = tool_builder.new() + :named("search") + :require_arg("topic") + :prefix_with("search") + :remember_flag("warm_cache") + :with_formatter(function(content: string, _state: protocol.SessionState, msg: protocol.ToolCallMessage): string + return content .. ":" .. tools.source_tag(msg) + end) + :build() + +local profile_handler = tool_builder.new() + :named("profile") + :require_arg("user_id") + :prefix_with("profile") + :remember_flag("profile_loaded") + :with_formatter(function(content: string, state: protocol.SessionState, _msg: protocol.ToolCallMessage): string + return content .. ":" .. tools.cache_mode(state) + end) + :build() + +local app = engine.new() + :register_tool("search", search_handler) + :register_tool("profile", profile_handler) + +app:on_step(function(step: protocol.WorkflowStep, state: protocol.SessionState) + last_session_id = state.id + + if step.kind == "assistant" then + local text: string = step.content + elseif step.kind == "tool" then + observed_tool_contents[step.tool] = step.result.content + local cached: boolean = step.result.cached + else + local note: string = step.note + local at: integer = step.at:unix() + end + + table.insert(observed_steps, step) +end) + +local user_msg: protocol.UserMessage = { + kind = "user", + id = "m1", + content = "hello", + meta = protocol.meta("req-1", "trace-1", {source = "ui"}), +} + +local search_msg: protocol.ToolCallMessage = { + kind = "tool_call", + id = "m2", + tool = "search", + arguments = {topic = "lua"}, + meta = protocol.meta("req-2", "trace-1", {source = "planner"}), +} + +local profile_msg: protocol.ToolCallMessage = { + kind = "tool_call", + id = "m3", + tool = "profile", + arguments = {user_id = "u-1"}, + meta = protocol.meta("req-3", "trace-1", nil), +} + +local repeat_search_msg: protocol.ToolCallMessage = { + kind = "tool_call", + id = "m4", + tool = "search", + arguments = {topic = "lua"}, + meta = protocol.meta("req-4", "trace-1", {source = "planner"}), +} + +local done_msg: protocol.DoneMessage = { + kind = "done", + id = "m5", + reason = "complete", + usage = {prompt_tokens = 12, completion_tokens = 5}, + meta = protocol.meta("req-5", "trace-1", nil), +} + +local messages: {protocol.Message} = { + user_msg, + search_msg, + profile_msg, + repeat_search_msg, + done_msg, +} + +local store = app:new_session("sess-1", now) +store:mark_flag("warm_cache") + +local summary_result = app:process(store, messages, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local session_id: string = summary.id + local total_steps: number = summary.total_steps + local cached_tool_count: number = summary.cached_tool_count + local latency: number = summary.last_latency_seconds + local reason: string? = summary.last_reason +end + +local summary_label = result.map(summary_result, function(summary: protocol.SessionSummary): string + return summary.id .. ":" .. tostring(summary.total_steps) +end) + +if summary_label.ok then + local label: string = summary_label.value +end + +local summary_id = result.and_then(summary_result, function(summary: protocol.SessionSummary): StringResult + if summary.total_steps == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected steps", + retryable = false, + }, + } + end + return { + ok = true, + value = summary.id, + } +end) + +if summary_id.ok then + local id: string = summary_id.value +end + +local cached = store:lookup_tool("search") +if cached then + local content: string = cached.content + local from_cache: boolean = cached.cached +end + +local missing = store:lookup_tool("missing") +if missing == nil then + local fallback: string = "missing" +end + +for name, content in pairs(observed_tool_contents) do + local tool_name: string = name + local tool_content: string = content +end + +if last_session_id ~= nil then + local stable_id: string = last_session_id +end + +local search_tags = search_msg.meta.tags +if search_tags then + local source = search_tags["source"] + if source then + local source_name: string = source + end +end + +local last_seen = store.state.last_activity or store.state.started_at +local elapsed = now:sub(last_seen) +local seconds: number = elapsed:seconds() diff --git a/testdata/fixtures/realworld/agent-workflow-engine/manifest.json b/testdata/fixtures/realworld/agent-workflow-engine/manifest.json new file mode 100644 index 00000000..578c0966 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/manifest.json @@ -0,0 +1,16 @@ +{ + "description": "Application-shaped workflow engine mixing generic results, discriminated unions, metatable-backed shared-self objects, callback listeners, dynamic map lookups, and time-based state.", + "files": [ + "result.lua", + "protocol.lua", + "session_store.lua", + "tool_builder.lua", + "tools.lua", + "engine.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/agent-workflow-engine/protocol.lua b/testdata/fixtures/realworld/agent-workflow-engine/protocol.lua new file mode 100644 index 00000000..e7a744ff --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/protocol.lua @@ -0,0 +1,113 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type Usage = { + prompt_tokens: integer, + completion_tokens: integer, +} + +type RequestMeta = { + request_id: string, + trace_id: string, + tags: {[string]: string}?, +} + +type UserMessage = { + kind: "user", + id: string, + content: string, + meta: RequestMeta, +} + +type ToolCallMessage = { + kind: "tool_call", + id: string, + tool: string, + arguments: {[string]: any}, + meta: RequestMeta, +} + +type DoneMessage = { + kind: "done", + id: string, + reason: "complete" | "tool_error" | "timeout", + usage: Usage?, + meta: RequestMeta, +} + +type Message = UserMessage | ToolCallMessage | DoneMessage + +type ToolResult = { + tool: string, + content: string, + cached: boolean, +} + +type AssistantStep = { + kind: "assistant", + content: string, +} + +type ToolStep = { + kind: "tool", + tool: string, + result: ToolResult, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type WorkflowStep = AssistantStep | ToolStep | AuditStep + +type SessionState = { + id: string, + started_at: time.Time, + last_activity: time.Time?, + messages: {Message}, + steps: {WorkflowStep}, + flags: {[string]: boolean}, + tool_cache: {[string]: ToolResult}, +} + +type SessionSummary = { + id: string, + total_steps: number, + cached_tool_count: number, + last_latency_seconds: number, + last_reason: string?, +} + +type StepListener = (WorkflowStep, SessionState) -> () +type ToolResultResult = {ok: true, value: ToolResult} | {ok: false, error: AppError} +type ToolHandler = (SessionState, ToolCallMessage) -> ToolResultResult + +local M = {} +M.AppError = AppError +M.Usage = Usage +M.RequestMeta = RequestMeta +M.UserMessage = UserMessage +M.ToolCallMessage = ToolCallMessage +M.DoneMessage = DoneMessage +M.Message = Message +M.ToolResult = ToolResult +M.WorkflowStep = WorkflowStep +M.SessionState = SessionState +M.SessionSummary = SessionSummary +M.StepListener = StepListener +M.ToolResultResult = ToolResultResult +M.ToolHandler = ToolHandler + +function M.meta(request_id: string, trace_id: string, tags: {[string]: string}?): RequestMeta + return { + request_id = request_id, + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine/result.lua b/testdata/fixtures/realworld/agent-workflow-engine/result.lua new file mode 100644 index 00000000..314787c1 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine/session_store.lua b/testdata/fixtures/realworld/agent-workflow-engine/session_store.lua new file mode 100644 index 00000000..79dfbc0b --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/session_store.lua @@ -0,0 +1,91 @@ +local time = require("time") +local protocol = require("protocol") + +type SessionStore = { + state: protocol.SessionState, + touch: (self: SessionStore, at: time.Time) -> SessionStore, + append_message: (self: SessionStore, msg: protocol.Message, at: time.Time) -> SessionStore, + emit_step: (self: SessionStore, step: protocol.WorkflowStep, at: time.Time) -> SessionStore, + remember_tool: (self: SessionStore, tool: protocol.ToolResult) -> (), + lookup_tool: (self: SessionStore, name: string) -> protocol.ToolResult?, + mark_flag: (self: SessionStore, name: string) -> (), + summarize: (self: SessionStore, now: time.Time, last_reason: string?) -> protocol.SessionSummary, +} + +type Store = SessionStore + +local Store = {} +Store.__index = Store + +local M = {} +M.SessionStore = SessionStore + +function M.new(id: string, now: time.Time): SessionStore + local self: Store = { + state = { + id = id, + started_at = now, + last_activity = nil, + messages = {}, + steps = {}, + flags = {}, + tool_cache = {}, + }, + touch = Store.touch, + append_message = Store.append_message, + emit_step = Store.emit_step, + remember_tool = Store.remember_tool, + lookup_tool = Store.lookup_tool, + mark_flag = Store.mark_flag, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_activity = at + return self +end + +function Store:append_message(msg: protocol.Message, at: time.Time): Store + table.insert(self.state.messages, msg) + return self:touch(at) +end + +function Store:emit_step(step: protocol.WorkflowStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:remember_tool(tool: protocol.ToolResult) + self.state.tool_cache[tool.tool] = tool +end + +function Store:lookup_tool(name: string): protocol.ToolResult? + return self.state.tool_cache[name] +end + +function Store:mark_flag(name: string) + self.state.flags[name] = true +end + +function Store:summarize(now: time.Time, last_reason: string?): protocol.SessionSummary + local cached_tool_count = 0 + for _, _ in pairs(self.state.tool_cache) do + cached_tool_count = cached_tool_count + 1 + end + + local since = self.state.last_activity or self.state.started_at + local latency = now:sub(since) + + return { + id = self.state.id, + total_steps = #self.state.steps, + cached_tool_count = cached_tool_count, + last_latency_seconds = latency:seconds(), + last_reason = last_reason, + } +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine/tool_builder.lua b/testdata/fixtures/realworld/agent-workflow-engine/tool_builder.lua new file mode 100644 index 00000000..d79bb42b --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/tool_builder.lua @@ -0,0 +1,111 @@ +local protocol = require("protocol") + +type ToolResultResult = protocol.ToolResultResult +type Formatter = (string, protocol.SessionState, protocol.ToolCallMessage) -> string + +type ToolBuilder = { + name: string, + required_arg: string, + prefix: string, + mark_flag: string?, + formatter: Formatter?, + named: (self: ToolBuilder, name: string) -> ToolBuilder, + require_arg: (self: ToolBuilder, key: string) -> ToolBuilder, + prefix_with: (self: ToolBuilder, prefix: string) -> ToolBuilder, + remember_flag: (self: ToolBuilder, flag: string) -> ToolBuilder, + with_formatter: (self: ToolBuilder, formatter: Formatter) -> ToolBuilder, + build: (self: ToolBuilder) -> protocol.ToolHandler, +} + +type Builder = ToolBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ToolBuilder = ToolBuilder + +function M.new(): ToolBuilder + local self: Builder = { + name = "tool", + required_arg = "value", + prefix = "tool", + mark_flag = nil, + formatter = nil, + named = Builder.named, + require_arg = Builder.require_arg, + prefix_with = Builder.prefix_with, + remember_flag = Builder.remember_flag, + with_formatter = Builder.with_formatter, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_arg(key: string): Builder + self.required_arg = key + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:remember_flag(flag: string): Builder + self.mark_flag = flag + return self +end + +function Builder:with_formatter(formatter: Formatter): Builder + self.formatter = formatter + return self +end + +function Builder:build(): protocol.ToolHandler + local name = self.name + local required_arg = self.required_arg + local prefix = self.prefix + local mark_flag = self.mark_flag + local formatter = self.formatter + + return function(state: protocol.SessionState, msg: protocol.ToolCallMessage): ToolResultResult + local value = msg.arguments[required_arg] + if type(value) ~= "string" then + return { + ok = false, + error = { + code = "invalid", + message = name .. " " .. required_arg .. " must be string", + retryable = false, + }, + } + end + + local content = prefix .. ":" .. value + if formatter then + content = formatter(content, state, msg) + end + + if mark_flag and state.flags[mark_flag] then + content = content .. ":flagged" + end + + return { + ok = true, + value = { + tool = msg.tool, + content = content, + cached = false, + }, + } + end +end + +return M diff --git a/testdata/fixtures/realworld/agent-workflow-engine/tools.lua b/testdata/fixtures/realworld/agent-workflow-engine/tools.lua new file mode 100644 index 00000000..9af852b0 --- /dev/null +++ b/testdata/fixtures/realworld/agent-workflow-engine/tools.lua @@ -0,0 +1,26 @@ +local protocol = require("protocol") + +local M = {} + +function M.source_tag(msg: protocol.ToolCallMessage): string + local tags = msg.meta.tags + if not tags then + return "unknown" + end + + local source = tags["source"] + if source == nil then + return "unknown" + end + return source +end + +function M.cache_mode(state: protocol.SessionState): string + local status = "first" + if state.flags["profile_loaded"] then + status = "repeat" + end + return status +end + +return M diff --git a/testdata/fixtures/realworld/channel-select-pattern/handler.lua b/testdata/fixtures/realworld/channel-select-pattern/handler.lua index 038ca743..4dd9a475 100644 --- a/testdata/fixtures/realworld/channel-select-pattern/handler.lua +++ b/testdata/fixtures/realworld/channel-select-pattern/handler.lua @@ -4,7 +4,7 @@ type HandlerResult = {ok: boolean, message: string} local M = {} -function M.process_event(event: Event): HandlerResult +function M.process_event(event: types.Event): HandlerResult if event.error then return {ok = false, message = "error: " .. tostring(event.error)} end diff --git a/testdata/fixtures/realworld/context-merge-pipeline/main.lua b/testdata/fixtures/realworld/context-merge-pipeline/main.lua index ce8d60e6..7311b288 100644 --- a/testdata/fixtures/realworld/context-merge-pipeline/main.lua +++ b/testdata/fixtures/realworld/context-merge-pipeline/main.lua @@ -2,14 +2,14 @@ local context = require("context") local pipeline = require("pipeline") local p = pipeline.new() - :add("auth", function(ctx: Context): Context + :add("auth", function(ctx: context.Context): context.Context return context.with(ctx, "user_id", "u123") end) - :add("permissions", function(ctx: Context): Context + :add("permissions", function(ctx: context.Context): context.Context local user_id = context.get(ctx, "user_id") return context.with(ctx, "can_read", user_id ~= nil) end) - :add("defaults", function(ctx: Context): Context + :add("defaults", function(ctx: context.Context): context.Context return context.merge(ctx, { locale = "en", timezone = "UTC", diff --git a/testdata/fixtures/realworld/context-merge-pipeline/manifest.json b/testdata/fixtures/realworld/context-merge-pipeline/manifest.json index 445a10f2..dafb3b6e 100644 --- a/testdata/fixtures/realworld/context-merge-pipeline/manifest.json +++ b/testdata/fixtures/realworld/context-merge-pipeline/manifest.json @@ -1,4 +1,4 @@ { "files": ["context.lua", "pipeline.lua", "main.lua"], - "check": {"errors": 3} + "check": {"errors": 0} } diff --git a/testdata/fixtures/realworld/context-merge-pipeline/pipeline.lua b/testdata/fixtures/realworld/context-merge-pipeline/pipeline.lua index ec7a537e..f47c6d49 100644 --- a/testdata/fixtures/realworld/context-merge-pipeline/pipeline.lua +++ b/testdata/fixtures/realworld/context-merge-pipeline/pipeline.lua @@ -2,13 +2,13 @@ local context = require("context") type Stage = { name: string, - process: (ctx: Context) -> Context, + process: (ctx: context.Context) -> context.Context, } type Pipeline = { _stages: {Stage}, - add: (self: Pipeline, name: string, processor: (ctx: Context) -> Context) -> Pipeline, - run: (self: Pipeline, initial: Context?) -> Context, + add: (self: Pipeline, name: string, processor: (ctx: context.Context) -> context.Context) -> Pipeline, + run: (self: Pipeline, initial: context.Context?) -> context.Context, count: (self: Pipeline) -> number, } @@ -17,11 +17,11 @@ local M = {} function M.new(): Pipeline local p: Pipeline = { _stages = {}, - add = function(self: Pipeline, name: string, processor: (ctx: Context) -> Context): Pipeline + add = function(self: Pipeline, name: string, processor: (ctx: context.Context) -> context.Context): Pipeline table.insert(self._stages, {name = name, process = processor}) return self end, - run = function(self: Pipeline, initial: Context?): Pipeline + run = function(self: Pipeline, initial: context.Context?): context.Context local ctx = initial or context.empty() for _, stage in ipairs(self._stages) do ctx = stage.process(ctx) diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/handler_builder.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/handler_builder.lua new file mode 100644 index 00000000..2b6dfe58 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/handler_builder.lua @@ -0,0 +1,223 @@ +local protocol = require("protocol") +local helpers = require("helpers") + +type Decorator = (string, protocol.OrderAggregate, protocol.Command) -> string + +type HandlerBuilder = { + command_kind: "create" | "reserve" | "complete", + note_prefix: string, + counter_name: string?, + source_key: string?, + decorator: Decorator?, + for_kind: (self: HandlerBuilder, kind: "create" | "reserve" | "complete") -> HandlerBuilder, + prefix_with: (self: HandlerBuilder, prefix: string) -> HandlerBuilder, + count_as: (self: HandlerBuilder, counter_name: string) -> HandlerBuilder, + capture_source: (self: HandlerBuilder, source_key: string) -> HandlerBuilder, + decorate: (self: HandlerBuilder, decorator: Decorator) -> HandlerBuilder, + build: (self: HandlerBuilder) -> protocol.CommandHandler, +} + +type Builder = HandlerBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.HandlerBuilder = HandlerBuilder + +function M.new(): HandlerBuilder + local self: Builder = { + command_kind = "create", + note_prefix = "order", + counter_name = nil, + source_key = nil, + decorator = nil, + for_kind = Builder.for_kind, + prefix_with = Builder.prefix_with, + count_as = Builder.count_as, + capture_source = Builder.capture_source, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_kind(kind: "create" | "reserve" | "complete"): Builder + self.command_kind = kind + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.note_prefix = prefix + return self +end + +function Builder:count_as(counter_name: string): Builder + self.counter_name = counter_name + return self +end + +function Builder:capture_source(source_key: string): Builder + self.source_key = source_key + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.CommandHandler + local command_kind = self.command_kind + local note_prefix = self.note_prefix + local counter_name = self.counter_name + local source_key = self.source_key + local decorator = self.decorator + + return function(state: protocol.StoreState, command: protocol.Command, at: time.Time): protocol.HandlerResult + if command.kind == "tick" then + return {ok = true, value = nil} + end + + local aggregate: protocol.OrderAggregate + + if command_kind == "create" then + if command.kind ~= "create" then + return { + ok = false, + error = { + code = "invalid", + message = "expected create", + retryable = false, + }, + } + end + + local existing = state.orders[command.id] + if existing then + return { + ok = false, + error = { + code = "conflict", + message = "order exists: " .. command.id, + retryable = false, + }, + } + end + + aggregate = { + id = command.id, + customer = command.customer, + version = 1, + status = "created", + item_id = nil, + source = nil, + updated_at = at, + } + state.orders[command.id] = aggregate + elseif command_kind == "reserve" then + if command.kind ~= "reserve" then + return { + ok = false, + error = { + code = "invalid", + message = "expected reserve", + retryable = false, + }, + } + end + + local existing = state.orders[command.id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing order: " .. command.id, + retryable = false, + }, + } + end + + aggregate = existing + aggregate.version = aggregate.version + 1 + aggregate.status = "reserved" + aggregate.item_id = command.item_id + aggregate.updated_at = at + else + if command.kind ~= "complete" then + return { + ok = false, + error = { + code = "invalid", + message = "expected complete", + retryable = false, + }, + } + end + + local existing = state.orders[command.id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing order: " .. command.id, + retryable = false, + }, + } + end + + aggregate = existing + aggregate.version = aggregate.version + 1 + aggregate.status = "completed" + aggregate.updated_at = at + end + + if source_key then + local tags = command.meta.tags + if tags then + local source = tags[source_key] + if source then + aggregate.source = source + end + end + end + + local view = state.views[aggregate.id] + if not view then + view = { + id = aggregate.id, + status = aggregate.status, + version = aggregate.version, + item_id = aggregate.item_id, + source = aggregate.source, + completed_at = nil, + } + state.views[aggregate.id] = view + end + + view.status = aggregate.status + view.version = aggregate.version + view.item_id = aggregate.item_id + view.source = aggregate.source + if aggregate.status == "completed" then + view.completed_at = at + end + + if counter_name then + local current = state.counters[counter_name] or 0 + state.counters[counter_name] = current + 1 + end + + local note = note_prefix .. ":" .. aggregate.id .. ":" .. aggregate.status .. ":" .. tostring(aggregate.version) + if decorator then + note = decorator(note, aggregate, command) + end + + return {ok = true, value = note} + end +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/helpers.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/helpers.lua new file mode 100644 index 00000000..bca5f9a9 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/helpers.lua @@ -0,0 +1,40 @@ +local protocol = require("protocol") + +local M = {} + +function M.command_id(command: protocol.Command): string? + if command.kind == "tick" then + return nil + end + return command.id +end + +function M.command_label(command: protocol.Command): string + if command.kind == "create" then + return "create:" .. command.customer + end + if command.kind == "reserve" then + return "reserve:" .. command.item_id + end + if command.kind == "complete" then + return "complete" + end + return "tick" +end + +function M.source_tag(command: protocol.Command): string? + if command.kind == "tick" then + return nil + end + local tags = command.meta.tags + if not tags then + return nil + end + return tags["source"] +end + +function M.status_name(status: string?): string + return status or "unknown" +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/main.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/main.lua new file mode 100644 index 00000000..6d53d271 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/main.lua @@ -0,0 +1,53 @@ +local time = require("time") +local protocol = require("protocol") +local validator_builder = require("validator_builder") +local handler_builder = require("handler_builder") +local runtime = require("runtime") + +local now = time.now() + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :remember_flag("validated_source") + :build() + +local create_handler = handler_builder.new() + :for_kind("create") + :prefix_with("order") + :count_as("created") + :capture_source("source") + :build() + +local app = runtime.new() + :use_validator(source_validator) + :register_handler("create", create_handler) + +local create_one: protocol.CreateOrderCommand = { + kind = "create", + id = "ord-1", + customer = "alice", + meta = protocol.meta("trace-1", {source = "api"}), +} + +local tick: protocol.TickCommand = { + kind = "tick", + at = now, +} + +local store = app:new_store("cqrs-1", now) +local replay_result = app:replay(store, {create_one, tick}, now) + +if replay_result.ok then + local last_status: string = replay_result.value.last_status -- expect-error +end + +local order = store:lookup_order("ord-1") +local item_id: string = order.item_id -- expect-error +local order_source: string = order.source -- expect-error + +local view = store:lookup_view("ord-1") +local completed_at = now:sub(view.completed_at) -- expect-error + +local missing_view: protocol.OrderView = store.state.views["missing"] -- expect-error +local trace_source: string = create_one.meta.tags["source"] -- expect-error diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/manifest.json b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/manifest.json new file mode 100644 index 00000000..00387414 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/manifest.json @@ -0,0 +1,14 @@ +{ + "description": "Soundness checks for the CQRS order runtime domain: unsafe optional state, view, tag, and time uses must be rejected.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "order_store.lua", + "validator_builder.lua", + "handler_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"] +} diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/order_store.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/order_store.lua new file mode 100644 index 00000000..c27c56c4 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/order_store.lua @@ -0,0 +1,137 @@ +local time = require("time") +local protocol = require("protocol") + +type OrderStore = { + state: protocol.StoreState, + touch: (self: OrderStore, at: time.Time) -> OrderStore, + push_step: (self: OrderStore, step: protocol.RunStep, at: time.Time) -> OrderStore, + ensure_order: (self: OrderStore, id: string, customer: string, at: time.Time) -> protocol.OrderAggregate, + lookup_order: (self: OrderStore, id: string) -> protocol.OrderAggregate?, + ensure_view: (self: OrderStore, id: string, status: "created" | "reserved" | "completed") -> protocol.OrderView, + lookup_view: (self: OrderStore, id: string) -> protocol.OrderView?, + increment: (self: OrderStore, name: string) -> integer, + summarize: (self: OrderStore, now: time.Time, last_status: string?) -> protocol.RunSummary, +} + +type Store = OrderStore + +local Store = {} +Store.__index = Store + +local M = {} +M.OrderStore = OrderStore + +function M.new(id: string, now: time.Time): OrderStore + local self: Store = { + state = { + id = id, + started_at = now, + last_command_at = nil, + steps = {}, + orders = {}, + views = {}, + counters = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + ensure_order = Store.ensure_order, + lookup_order = Store.lookup_order, + ensure_view = Store.ensure_view, + lookup_view = Store.lookup_view, + increment = Store.increment, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_command_at = at + return self +end + +function Store:push_step(step: protocol.RunStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:ensure_order(id: string, customer: string, at: time.Time): protocol.OrderAggregate + local current = self.state.orders[id] + if current then + if current.updated_at == nil then + current.updated_at = at + end + return current + end + + local created: protocol.OrderAggregate = { + id = id, + customer = customer, + version = 0, + status = "created", + item_id = nil, + source = nil, + updated_at = at, + } + self.state.orders[id] = created + return created +end + +function Store:lookup_order(id: string): protocol.OrderAggregate? + return self.state.orders[id] +end + +function Store:ensure_view(id: string, status: "created" | "reserved" | "completed"): protocol.OrderView + local current = self.state.views[id] + if current then + return current + end + + local created: protocol.OrderView = { + id = id, + status = status, + version = 0, + item_id = nil, + source = nil, + completed_at = nil, + } + self.state.views[id] = created + return created +end + +function Store:lookup_view(id: string): protocol.OrderView? + return self.state.views[id] +end + +function Store:increment(name: string): integer + local current = self.state.counters[name] or 0 + local next_value = current + 1 + self.state.counters[name] = next_value + return next_value +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RunSummary + local order_count = 0 + local completed_count = 0 + for _, view in pairs(self.state.views) do + order_count = order_count + 1 + if view.status == "completed" then + completed_count = completed_count + 1 + end + end + + local seen_at = self.state.last_command_at or self.state.started_at + local elapsed = now:sub(seen_at) + + return { + id = self.state.id, + total_steps = #self.state.steps, + order_count = order_count, + completed_count = completed_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + } +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/protocol.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/protocol.lua new file mode 100644 index 00000000..791c14a5 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/protocol.lua @@ -0,0 +1,125 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type CommandMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type CreateOrderCommand = { + kind: "create", + id: string, + customer: string, + meta: CommandMeta, +} + +type ReserveItemCommand = { + kind: "reserve", + id: string, + item_id: string, + meta: CommandMeta, +} + +type CompleteOrderCommand = { + kind: "complete", + id: string, + meta: CommandMeta, +} + +type TickCommand = { + kind: "tick", + at: time.Time, +} + +type Command = CreateOrderCommand | ReserveItemCommand | CompleteOrderCommand | TickCommand + +type OrderAggregate = { + id: string, + customer: string, + version: integer, + status: "created" | "reserved" | "completed", + item_id: string?, + source: string?, + updated_at: time.Time?, +} + +type OrderView = { + id: string, + status: "created" | "reserved" | "completed", + version: integer, + item_id: string?, + source: string?, + completed_at: time.Time?, +} + +type RunStep = { + kind: "command", + name: string, + note: string, + order_id: string?, +} | { + kind: "audit", + note: string, + at: time.Time, +} + +type StoreState = { + id: string, + started_at: time.Time, + last_command_at: time.Time?, + steps: {RunStep}, + orders: {[string]: OrderAggregate}, + views: {[string]: OrderView}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_steps: number, + order_count: number, + completed_count: number, + last_status: string?, + elapsed_seconds: number, +} + +type ValidationResult = {ok: true, value: Command} | {ok: false, error: AppError} +type HandlerResult = {ok: true, value: string?} | {ok: false, error: AppError} +type ExecuteResult = {ok: true, value: string?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type CommandValidator = (StoreState, Command) -> ValidationResult +type CommandHandler = (StoreState, Command, time.Time) -> HandlerResult +type StepHook = (RunStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.CommandMeta = CommandMeta +M.CreateOrderCommand = CreateOrderCommand +M.ReserveItemCommand = ReserveItemCommand +M.CompleteOrderCommand = CompleteOrderCommand +M.TickCommand = TickCommand +M.Command = Command +M.OrderAggregate = OrderAggregate +M.OrderView = OrderView +M.RunStep = RunStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.HandlerResult = HandlerResult +M.ExecuteResult = ExecuteResult +M.RunResult = RunResult +M.CommandValidator = CommandValidator +M.CommandHandler = CommandHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): CommandMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/result.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/runtime.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/runtime.lua new file mode 100644 index 00000000..e9cca901 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/runtime.lua @@ -0,0 +1,143 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local order_store = require("order_store") + +type Runtime = { + validators: {protocol.CommandValidator}, + handlers: {[string]: protocol.CommandHandler}, + hooks: {protocol.StepHook}, + use_validator: (self: Runtime, validator: protocol.CommandValidator) -> Runtime, + register_handler: (self: Runtime, kind: string, handler: protocol.CommandHandler) -> Runtime, + on_step: (self: Runtime, hook: protocol.StepHook) -> Runtime, + new_store: (self: Runtime, id: string, now: time.Time) -> order_store.OrderStore, + emit: (self: Runtime, store: order_store.OrderStore, step: protocol.RunStep, at: time.Time) -> (), + execute: (self: Runtime, store: order_store.OrderStore, command: protocol.Command, at: time.Time) -> protocol.ExecuteResult, + replay: (self: Runtime, store: order_store.OrderStore, commands: {protocol.Command}, now: time.Time) -> protocol.RunResult, +} + +type AppRuntime = Runtime + +local AppRuntime = {} +AppRuntime.__index = AppRuntime + +local M = {} +M.Runtime = Runtime + +function M.new(): Runtime + local self: AppRuntime = { + validators = {}, + handlers = {}, + hooks = {}, + use_validator = AppRuntime.use_validator, + register_handler = AppRuntime.register_handler, + on_step = AppRuntime.on_step, + new_store = AppRuntime.new_store, + emit = AppRuntime.emit, + execute = AppRuntime.execute, + replay = AppRuntime.replay, + } + setmetatable(self, AppRuntime) + return self +end + +function AppRuntime:use_validator(validator: protocol.CommandValidator): AppRuntime + table.insert(self.validators, validator) + return self +end + +function AppRuntime:register_handler(kind: string, handler: protocol.CommandHandler): AppRuntime + self.handlers[kind] = handler + return self +end + +function AppRuntime:on_step(hook: protocol.StepHook): AppRuntime + table.insert(self.hooks, hook) + return self +end + +function AppRuntime:new_store(id: string, now: time.Time): order_store.OrderStore + return order_store.new(id, now) +end + +function AppRuntime:emit(store: order_store.OrderStore, step: protocol.RunStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function AppRuntime:execute( + store: order_store.OrderStore, + command: protocol.Command, + at: time.Time +): protocol.ExecuteResult + if command.kind == "tick" then + local audit_step: protocol.RunStep = { + kind = "audit", + note = "tick", + at = command.at, + } + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local current: protocol.Command = command + for _, validator in ipairs(self.validators) do + local validation_result: protocol.ValidationResult = validator(store.state, current) + if not validation_result.ok then + return {ok = false, error = validation_result.error} + end + current = validation_result.value + end + + local handler = self.handlers[current.kind] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing handler: " .. current.kind, + retryable = false, + }, + } + end + local handler_value: protocol.CommandHandler = handler + local handler_result: protocol.HandlerResult = handler_value(store.state, current, at) + if not handler_result.ok then + return {ok = false, error = handler_result.error} + end + + local note = handler_result.value or helpers.command_label(current) + local command_step: protocol.RunStep = { + kind = "command", + name = current.kind, + note = note, + order_id = current.id, + } + self:emit(store, command_step, at) + + return {ok = true, value = current.kind} +end + +function AppRuntime:replay( + store: order_store.OrderStore, + commands: {protocol.Command}, + now: time.Time +): protocol.RunResult + local last_status: string? = nil + + for _, command in ipairs(commands) do + local execute_result: protocol.ExecuteResult = self:execute(store, command, now) + if not execute_result.ok then + return {ok = false, error = execute_result.error} + end + if execute_result.value then + last_status = execute_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime-soundness/validator_builder.lua b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/validator_builder.lua new file mode 100644 index 00000000..bdaaf43b --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime-soundness/validator_builder.lua @@ -0,0 +1,94 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string, + required_tag: string?, + flag_name: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, tag_name: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag_name: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.CommandValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ValidatorBuilder = ValidatorBuilder + +function M.new(): ValidatorBuilder + local self: Builder = { + name = "validator", + required_tag = nil, + flag_name = nil, + named = Builder.named, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(tag_name: string): Builder + self.required_tag = tag_name + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:build(): protocol.CommandValidator + local name = self.name + local required_tag = self.required_tag + local flag_name = self.flag_name + + return function(state: protocol.StoreState, command: protocol.Command): protocol.ValidationResult + if command.kind == "tick" then + return {ok = true, value = command} + end + + if required_tag then + local tags = command.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + if flag_name then + state.flags[flag_name] = true + end + + return {ok = true, value = command} + end +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/handler_builder.lua b/testdata/fixtures/realworld/cqrs-order-runtime/handler_builder.lua new file mode 100644 index 00000000..2b6dfe58 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/handler_builder.lua @@ -0,0 +1,223 @@ +local protocol = require("protocol") +local helpers = require("helpers") + +type Decorator = (string, protocol.OrderAggregate, protocol.Command) -> string + +type HandlerBuilder = { + command_kind: "create" | "reserve" | "complete", + note_prefix: string, + counter_name: string?, + source_key: string?, + decorator: Decorator?, + for_kind: (self: HandlerBuilder, kind: "create" | "reserve" | "complete") -> HandlerBuilder, + prefix_with: (self: HandlerBuilder, prefix: string) -> HandlerBuilder, + count_as: (self: HandlerBuilder, counter_name: string) -> HandlerBuilder, + capture_source: (self: HandlerBuilder, source_key: string) -> HandlerBuilder, + decorate: (self: HandlerBuilder, decorator: Decorator) -> HandlerBuilder, + build: (self: HandlerBuilder) -> protocol.CommandHandler, +} + +type Builder = HandlerBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.HandlerBuilder = HandlerBuilder + +function M.new(): HandlerBuilder + local self: Builder = { + command_kind = "create", + note_prefix = "order", + counter_name = nil, + source_key = nil, + decorator = nil, + for_kind = Builder.for_kind, + prefix_with = Builder.prefix_with, + count_as = Builder.count_as, + capture_source = Builder.capture_source, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_kind(kind: "create" | "reserve" | "complete"): Builder + self.command_kind = kind + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.note_prefix = prefix + return self +end + +function Builder:count_as(counter_name: string): Builder + self.counter_name = counter_name + return self +end + +function Builder:capture_source(source_key: string): Builder + self.source_key = source_key + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.CommandHandler + local command_kind = self.command_kind + local note_prefix = self.note_prefix + local counter_name = self.counter_name + local source_key = self.source_key + local decorator = self.decorator + + return function(state: protocol.StoreState, command: protocol.Command, at: time.Time): protocol.HandlerResult + if command.kind == "tick" then + return {ok = true, value = nil} + end + + local aggregate: protocol.OrderAggregate + + if command_kind == "create" then + if command.kind ~= "create" then + return { + ok = false, + error = { + code = "invalid", + message = "expected create", + retryable = false, + }, + } + end + + local existing = state.orders[command.id] + if existing then + return { + ok = false, + error = { + code = "conflict", + message = "order exists: " .. command.id, + retryable = false, + }, + } + end + + aggregate = { + id = command.id, + customer = command.customer, + version = 1, + status = "created", + item_id = nil, + source = nil, + updated_at = at, + } + state.orders[command.id] = aggregate + elseif command_kind == "reserve" then + if command.kind ~= "reserve" then + return { + ok = false, + error = { + code = "invalid", + message = "expected reserve", + retryable = false, + }, + } + end + + local existing = state.orders[command.id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing order: " .. command.id, + retryable = false, + }, + } + end + + aggregate = existing + aggregate.version = aggregate.version + 1 + aggregate.status = "reserved" + aggregate.item_id = command.item_id + aggregate.updated_at = at + else + if command.kind ~= "complete" then + return { + ok = false, + error = { + code = "invalid", + message = "expected complete", + retryable = false, + }, + } + end + + local existing = state.orders[command.id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing order: " .. command.id, + retryable = false, + }, + } + end + + aggregate = existing + aggregate.version = aggregate.version + 1 + aggregate.status = "completed" + aggregate.updated_at = at + end + + if source_key then + local tags = command.meta.tags + if tags then + local source = tags[source_key] + if source then + aggregate.source = source + end + end + end + + local view = state.views[aggregate.id] + if not view then + view = { + id = aggregate.id, + status = aggregate.status, + version = aggregate.version, + item_id = aggregate.item_id, + source = aggregate.source, + completed_at = nil, + } + state.views[aggregate.id] = view + end + + view.status = aggregate.status + view.version = aggregate.version + view.item_id = aggregate.item_id + view.source = aggregate.source + if aggregate.status == "completed" then + view.completed_at = at + end + + if counter_name then + local current = state.counters[counter_name] or 0 + state.counters[counter_name] = current + 1 + end + + local note = note_prefix .. ":" .. aggregate.id .. ":" .. aggregate.status .. ":" .. tostring(aggregate.version) + if decorator then + note = decorator(note, aggregate, command) + end + + return {ok = true, value = note} + end +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/helpers.lua b/testdata/fixtures/realworld/cqrs-order-runtime/helpers.lua new file mode 100644 index 00000000..bca5f9a9 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/helpers.lua @@ -0,0 +1,40 @@ +local protocol = require("protocol") + +local M = {} + +function M.command_id(command: protocol.Command): string? + if command.kind == "tick" then + return nil + end + return command.id +end + +function M.command_label(command: protocol.Command): string + if command.kind == "create" then + return "create:" .. command.customer + end + if command.kind == "reserve" then + return "reserve:" .. command.item_id + end + if command.kind == "complete" then + return "complete" + end + return "tick" +end + +function M.source_tag(command: protocol.Command): string? + if command.kind == "tick" then + return nil + end + local tags = command.meta.tags + if not tags then + return nil + end + return tags["source"] +end + +function M.status_name(status: string?): string + return status or "unknown" +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/main.lua b/testdata/fixtures/realworld/cqrs-order-runtime/main.lua new file mode 100644 index 00000000..0bd8684a --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/main.lua @@ -0,0 +1,218 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local validator_builder = require("validator_builder") +local handler_builder = require("handler_builder") +local runtime = require("runtime") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_notes: {[string]: string} = {} +local observed_audits: {string} = {} +local last_runtime_id: string? = nil + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :remember_flag("validated_source") + :build() + +local create_handler = handler_builder.new() + :for_kind("create") + :prefix_with("order") + :count_as("created") + :capture_source("source") + :decorate(function(note: string, aggregate: protocol.OrderAggregate, command: protocol.Command): string + return note .. ":" .. aggregate.customer .. ":" .. helpers.command_label(command) + end) + :build() + +local reserve_handler = handler_builder.new() + :for_kind("reserve") + :prefix_with("reserve") + :count_as("reserved") + :capture_source("source") + :decorate(function(note: string, aggregate: protocol.OrderAggregate, _command: protocol.Command): string + local item = aggregate.item_id or "missing" + return note .. ":" .. item + end) + :build() + +local complete_handler = handler_builder.new() + :for_kind("complete") + :prefix_with("complete") + :count_as("completed") + :capture_source("source") + :decorate(function(note: string, aggregate: protocol.OrderAggregate, _command: protocol.Command): string + local source = aggregate.source or "unknown" + return note .. ":" .. source + end) + :build() + +local app = runtime.new() + :use_validator(source_validator) + :register_handler("create", create_handler) + :register_handler("reserve", reserve_handler) + :register_handler("complete", complete_handler) + +app:on_step(function(step: protocol.RunStep, state: protocol.StoreState) + last_runtime_id = state.id + + if step.kind == "command" then + observed_notes[step.name .. ":" .. tostring(#observed_audits + 1)] = step.note + if step.order_id then + local order_id: string = step.order_id + end + else + table.insert(observed_audits, step.note) + local at_seconds: integer = step.at:unix() + end +end) + +local create_one: protocol.CreateOrderCommand = { + kind = "create", + id = "ord-1", + customer = "alice", + meta = protocol.meta("trace-1", {source = "api", lane = "priority"}), +} + +local reserve_one: protocol.ReserveItemCommand = { + kind = "reserve", + id = "ord-1", + item_id = "item-7", + meta = protocol.meta("trace-2", {source = "worker"}), +} + +local complete_one: protocol.CompleteOrderCommand = { + kind = "complete", + id = "ord-1", + meta = protocol.meta("trace-3", {source = "worker"}), +} + +local create_two: protocol.CreateOrderCommand = { + kind = "create", + id = "ord-2", + customer = "bob", + meta = protocol.meta("trace-4", {source = "api"}), +} + +local reserve_two: protocol.ReserveItemCommand = { + kind = "reserve", + id = "ord-2", + item_id = "item-9", + meta = protocol.meta("trace-5", {source = "worker"}), +} + +local tick: protocol.TickCommand = { + kind = "tick", + at = now, +} + +local commands: {protocol.Command} = { + create_one, + reserve_one, + complete_one, + create_two, + reserve_two, + tick, +} + +local store = app:new_store("cqrs-1", now) +local summary_result = app:replay(store, commands, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local runtime_id: string = summary.id + local total_steps: number = summary.total_steps + local order_count: number = summary.order_count + local completed_count: number = summary.completed_count + local elapsed_seconds: number = summary.elapsed_seconds + local last_status: string? = summary.last_status +end + +local summary_label = result.map(summary_result, function(summary: protocol.RunSummary): string + return summary.id .. ":" .. tostring(summary.order_count) +end) + +if summary_label.ok then + local label: string = summary_label.value +end + +local summary_id = result.and_then(summary_result, function(summary: protocol.RunSummary): StringResult + if summary.completed_count == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected completed order", + retryable = false, + }, + } + end + return { + ok = true, + value = summary.id, + } +end) + +if summary_id.ok then + local stable_id: string = summary_id.value +end + +local order_one = store:lookup_order("ord-1") +if order_one then + local status: string = order_one.status + local version: integer = order_one.version + local item = order_one.item_id + if item then + local stable_item: string = item + end + local source = order_one.source + if source then + local stable_source: string = source + end + local updated = order_one.updated_at or now + local seconds: number = now:sub(updated):seconds() +end + +local view_two = store:lookup_view("ord-2") +if view_two then + local view_status: string = view_two.status + local view_version: integer = view_two.version + local view_item = view_two.item_id + if view_item then + local stable_item: string = view_item + end +end + +local missing = store:lookup_view("missing") +if missing == nil then + local fallback: string = "missing" +end + +for key, note in pairs(observed_notes) do + local stable_key: string = key + local stable_note: string = note +end + +for _, note in ipairs(observed_audits) do + local audit_note: string = note +end + +if last_runtime_id ~= nil then + local stable_runtime_id: string = last_runtime_id +end + +local source = helpers.source_tag(create_one) +if source then + local stable_source: string = source +end + +local seen_at = store.state.last_command_at or store.state.started_at +local elapsed = now:sub(seen_at) +local seconds: number = elapsed:seconds() diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/manifest.json b/testdata/fixtures/realworld/cqrs-order-runtime/manifest.json new file mode 100644 index 00000000..f772e899 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Application-shaped CQRS runtime mixing staged validation, fluent command handlers, metatable-backed stores, discriminated commands, dynamic registries, optional projections, counters, callbacks, and time-based summaries.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "order_store.lua", + "validator_builder.lua", + "handler_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/order_store.lua b/testdata/fixtures/realworld/cqrs-order-runtime/order_store.lua new file mode 100644 index 00000000..c27c56c4 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/order_store.lua @@ -0,0 +1,137 @@ +local time = require("time") +local protocol = require("protocol") + +type OrderStore = { + state: protocol.StoreState, + touch: (self: OrderStore, at: time.Time) -> OrderStore, + push_step: (self: OrderStore, step: protocol.RunStep, at: time.Time) -> OrderStore, + ensure_order: (self: OrderStore, id: string, customer: string, at: time.Time) -> protocol.OrderAggregate, + lookup_order: (self: OrderStore, id: string) -> protocol.OrderAggregate?, + ensure_view: (self: OrderStore, id: string, status: "created" | "reserved" | "completed") -> protocol.OrderView, + lookup_view: (self: OrderStore, id: string) -> protocol.OrderView?, + increment: (self: OrderStore, name: string) -> integer, + summarize: (self: OrderStore, now: time.Time, last_status: string?) -> protocol.RunSummary, +} + +type Store = OrderStore + +local Store = {} +Store.__index = Store + +local M = {} +M.OrderStore = OrderStore + +function M.new(id: string, now: time.Time): OrderStore + local self: Store = { + state = { + id = id, + started_at = now, + last_command_at = nil, + steps = {}, + orders = {}, + views = {}, + counters = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + ensure_order = Store.ensure_order, + lookup_order = Store.lookup_order, + ensure_view = Store.ensure_view, + lookup_view = Store.lookup_view, + increment = Store.increment, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_command_at = at + return self +end + +function Store:push_step(step: protocol.RunStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:ensure_order(id: string, customer: string, at: time.Time): protocol.OrderAggregate + local current = self.state.orders[id] + if current then + if current.updated_at == nil then + current.updated_at = at + end + return current + end + + local created: protocol.OrderAggregate = { + id = id, + customer = customer, + version = 0, + status = "created", + item_id = nil, + source = nil, + updated_at = at, + } + self.state.orders[id] = created + return created +end + +function Store:lookup_order(id: string): protocol.OrderAggregate? + return self.state.orders[id] +end + +function Store:ensure_view(id: string, status: "created" | "reserved" | "completed"): protocol.OrderView + local current = self.state.views[id] + if current then + return current + end + + local created: protocol.OrderView = { + id = id, + status = status, + version = 0, + item_id = nil, + source = nil, + completed_at = nil, + } + self.state.views[id] = created + return created +end + +function Store:lookup_view(id: string): protocol.OrderView? + return self.state.views[id] +end + +function Store:increment(name: string): integer + local current = self.state.counters[name] or 0 + local next_value = current + 1 + self.state.counters[name] = next_value + return next_value +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RunSummary + local order_count = 0 + local completed_count = 0 + for _, view in pairs(self.state.views) do + order_count = order_count + 1 + if view.status == "completed" then + completed_count = completed_count + 1 + end + end + + local seen_at = self.state.last_command_at or self.state.started_at + local elapsed = now:sub(seen_at) + + return { + id = self.state.id, + total_steps = #self.state.steps, + order_count = order_count, + completed_count = completed_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + } +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/protocol.lua b/testdata/fixtures/realworld/cqrs-order-runtime/protocol.lua new file mode 100644 index 00000000..791c14a5 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/protocol.lua @@ -0,0 +1,125 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type CommandMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type CreateOrderCommand = { + kind: "create", + id: string, + customer: string, + meta: CommandMeta, +} + +type ReserveItemCommand = { + kind: "reserve", + id: string, + item_id: string, + meta: CommandMeta, +} + +type CompleteOrderCommand = { + kind: "complete", + id: string, + meta: CommandMeta, +} + +type TickCommand = { + kind: "tick", + at: time.Time, +} + +type Command = CreateOrderCommand | ReserveItemCommand | CompleteOrderCommand | TickCommand + +type OrderAggregate = { + id: string, + customer: string, + version: integer, + status: "created" | "reserved" | "completed", + item_id: string?, + source: string?, + updated_at: time.Time?, +} + +type OrderView = { + id: string, + status: "created" | "reserved" | "completed", + version: integer, + item_id: string?, + source: string?, + completed_at: time.Time?, +} + +type RunStep = { + kind: "command", + name: string, + note: string, + order_id: string?, +} | { + kind: "audit", + note: string, + at: time.Time, +} + +type StoreState = { + id: string, + started_at: time.Time, + last_command_at: time.Time?, + steps: {RunStep}, + orders: {[string]: OrderAggregate}, + views: {[string]: OrderView}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_steps: number, + order_count: number, + completed_count: number, + last_status: string?, + elapsed_seconds: number, +} + +type ValidationResult = {ok: true, value: Command} | {ok: false, error: AppError} +type HandlerResult = {ok: true, value: string?} | {ok: false, error: AppError} +type ExecuteResult = {ok: true, value: string?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type CommandValidator = (StoreState, Command) -> ValidationResult +type CommandHandler = (StoreState, Command, time.Time) -> HandlerResult +type StepHook = (RunStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.CommandMeta = CommandMeta +M.CreateOrderCommand = CreateOrderCommand +M.ReserveItemCommand = ReserveItemCommand +M.CompleteOrderCommand = CompleteOrderCommand +M.TickCommand = TickCommand +M.Command = Command +M.OrderAggregate = OrderAggregate +M.OrderView = OrderView +M.RunStep = RunStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.HandlerResult = HandlerResult +M.ExecuteResult = ExecuteResult +M.RunResult = RunResult +M.CommandValidator = CommandValidator +M.CommandHandler = CommandHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): CommandMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/result.lua b/testdata/fixtures/realworld/cqrs-order-runtime/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/runtime.lua b/testdata/fixtures/realworld/cqrs-order-runtime/runtime.lua new file mode 100644 index 00000000..e9cca901 --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/runtime.lua @@ -0,0 +1,143 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local order_store = require("order_store") + +type Runtime = { + validators: {protocol.CommandValidator}, + handlers: {[string]: protocol.CommandHandler}, + hooks: {protocol.StepHook}, + use_validator: (self: Runtime, validator: protocol.CommandValidator) -> Runtime, + register_handler: (self: Runtime, kind: string, handler: protocol.CommandHandler) -> Runtime, + on_step: (self: Runtime, hook: protocol.StepHook) -> Runtime, + new_store: (self: Runtime, id: string, now: time.Time) -> order_store.OrderStore, + emit: (self: Runtime, store: order_store.OrderStore, step: protocol.RunStep, at: time.Time) -> (), + execute: (self: Runtime, store: order_store.OrderStore, command: protocol.Command, at: time.Time) -> protocol.ExecuteResult, + replay: (self: Runtime, store: order_store.OrderStore, commands: {protocol.Command}, now: time.Time) -> protocol.RunResult, +} + +type AppRuntime = Runtime + +local AppRuntime = {} +AppRuntime.__index = AppRuntime + +local M = {} +M.Runtime = Runtime + +function M.new(): Runtime + local self: AppRuntime = { + validators = {}, + handlers = {}, + hooks = {}, + use_validator = AppRuntime.use_validator, + register_handler = AppRuntime.register_handler, + on_step = AppRuntime.on_step, + new_store = AppRuntime.new_store, + emit = AppRuntime.emit, + execute = AppRuntime.execute, + replay = AppRuntime.replay, + } + setmetatable(self, AppRuntime) + return self +end + +function AppRuntime:use_validator(validator: protocol.CommandValidator): AppRuntime + table.insert(self.validators, validator) + return self +end + +function AppRuntime:register_handler(kind: string, handler: protocol.CommandHandler): AppRuntime + self.handlers[kind] = handler + return self +end + +function AppRuntime:on_step(hook: protocol.StepHook): AppRuntime + table.insert(self.hooks, hook) + return self +end + +function AppRuntime:new_store(id: string, now: time.Time): order_store.OrderStore + return order_store.new(id, now) +end + +function AppRuntime:emit(store: order_store.OrderStore, step: protocol.RunStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function AppRuntime:execute( + store: order_store.OrderStore, + command: protocol.Command, + at: time.Time +): protocol.ExecuteResult + if command.kind == "tick" then + local audit_step: protocol.RunStep = { + kind = "audit", + note = "tick", + at = command.at, + } + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local current: protocol.Command = command + for _, validator in ipairs(self.validators) do + local validation_result: protocol.ValidationResult = validator(store.state, current) + if not validation_result.ok then + return {ok = false, error = validation_result.error} + end + current = validation_result.value + end + + local handler = self.handlers[current.kind] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing handler: " .. current.kind, + retryable = false, + }, + } + end + local handler_value: protocol.CommandHandler = handler + local handler_result: protocol.HandlerResult = handler_value(store.state, current, at) + if not handler_result.ok then + return {ok = false, error = handler_result.error} + end + + local note = handler_result.value or helpers.command_label(current) + local command_step: protocol.RunStep = { + kind = "command", + name = current.kind, + note = note, + order_id = current.id, + } + self:emit(store, command_step, at) + + return {ok = true, value = current.kind} +end + +function AppRuntime:replay( + store: order_store.OrderStore, + commands: {protocol.Command}, + now: time.Time +): protocol.RunResult + local last_status: string? = nil + + for _, command in ipairs(commands) do + local execute_result: protocol.ExecuteResult = self:execute(store, command, now) + if not execute_result.ok then + return {ok = false, error = execute_result.error} + end + if execute_result.value then + last_status = execute_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/cqrs-order-runtime/validator_builder.lua b/testdata/fixtures/realworld/cqrs-order-runtime/validator_builder.lua new file mode 100644 index 00000000..bdaaf43b --- /dev/null +++ b/testdata/fixtures/realworld/cqrs-order-runtime/validator_builder.lua @@ -0,0 +1,94 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string, + required_tag: string?, + flag_name: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, tag_name: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag_name: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.CommandValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ValidatorBuilder = ValidatorBuilder + +function M.new(): ValidatorBuilder + local self: Builder = { + name = "validator", + required_tag = nil, + flag_name = nil, + named = Builder.named, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(tag_name: string): Builder + self.required_tag = tag_name + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:build(): protocol.CommandValidator + local name = self.name + local required_tag = self.required_tag + local flag_name = self.flag_name + + return function(state: protocol.StoreState, command: protocol.Command): protocol.ValidationResult + if command.kind == "tick" then + return {ok = true, value = command} + end + + if required_tag then + local tags = command.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + if flag_name then + state.flags[flag_name] = true + end + + return {ok = true, value = command} + end +end + +return M diff --git a/testdata/fixtures/realworld/discriminated-tool-dispatch/executor.lua b/testdata/fixtures/realworld/discriminated-tool-dispatch/executor.lua index 21909bba..b6a91d22 100644 --- a/testdata/fixtures/realworld/discriminated-tool-dispatch/executor.lua +++ b/testdata/fixtures/realworld/discriminated-tool-dispatch/executor.lua @@ -2,7 +2,7 @@ local tools = require("tools") local M = {} -function M.execute(tool: Tool): ToolResult +function M.execute(tool: tools.Tool): tools.ToolResult if tool.type == "search" then local query: string = tool.args.query local limit: number = tool.args.limit or 10 @@ -30,8 +30,8 @@ function M.execute(tool: Tool): ToolResult return {tool_name = "unknown", output = "unsupported tool type", success = false} end -function M.execute_batch(tool_list: {Tool}): {ToolResult} - local results: {ToolResult} = {} +function M.execute_batch(tool_list: {tools.Tool}): {tools.ToolResult} + local results: {tools.ToolResult} = {} for _, tool in ipairs(tool_list) do table.insert(results, M.execute(tool)) end diff --git a/testdata/fixtures/realworld/error-handling-chain/validator.lua b/testdata/fixtures/realworld/error-handling-chain/validator.lua index 63d714a1..bcb082b1 100644 --- a/testdata/fixtures/realworld/error-handling-chain/validator.lua +++ b/testdata/fixtures/realworld/error-handling-chain/validator.lua @@ -1,6 +1,6 @@ local errors = require("errors") -type ValidationResult = {ok: true, value: string} | {ok: false, error: AppError} +type ValidationResult = {ok: true, value: string} | {ok: false, error: errors.AppError} local M = {} M.ValidationResult = ValidationResult diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/bus.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/bus.lua new file mode 100644 index 00000000..27fb0766 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/bus.lua @@ -0,0 +1,162 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local checkpoint_store = require("checkpoint_store") + +type EventBus = { + subscribers: {[string]: {protocol.Subscriber}}, + projectors: {[string]: {protocol.Projector}}, + hooks: {protocol.StepHook}, + ensure_subscribers: (self: EventBus, topic: string) -> {protocol.Subscriber}, + ensure_projectors: (self: EventBus, topic: string) -> {protocol.Projector}, + register_subscriber: (self: EventBus, topic: string, subscriber: protocol.Subscriber) -> EventBus, + register_projector: (self: EventBus, topic: string, projector: protocol.Projector) -> EventBus, + on_step: (self: EventBus, hook: protocol.StepHook) -> EventBus, + new_store: (self: EventBus, id: string, now: time.Time) -> checkpoint_store.CheckpointStore, + emit: (self: EventBus, store: checkpoint_store.CheckpointStore, step: protocol.DispatchStep, at: time.Time) -> (), + publish: (self: EventBus, store: checkpoint_store.CheckpointStore, topic: string, event: protocol.Event, at: time.Time) -> protocol.PublishResult, + replay: (self: EventBus, store: checkpoint_store.CheckpointStore, topic: string, events: {protocol.Event}, now: time.Time) -> protocol.ReplayResult, +} + +type Bus = EventBus + +local Bus = {} +Bus.__index = Bus + +local M = {} +M.EventBus = EventBus + +function M.new(): EventBus + local self: Bus = { + subscribers = {}, + projectors = {}, + hooks = {}, + ensure_subscribers = Bus.ensure_subscribers, + ensure_projectors = Bus.ensure_projectors, + register_subscriber = Bus.register_subscriber, + register_projector = Bus.register_projector, + on_step = Bus.on_step, + new_store = Bus.new_store, + emit = Bus.emit, + publish = Bus.publish, + replay = Bus.replay, + } + setmetatable(self, Bus) + return self +end + +function Bus:ensure_subscribers(topic: string): {protocol.Subscriber} + local current = self.subscribers[topic] + if current then + return current + end + + local created: {protocol.Subscriber} = {} + self.subscribers[topic] = created + return created +end + +function Bus:ensure_projectors(topic: string): {protocol.Projector} + local current = self.projectors[topic] + if current then + return current + end + + local created: {protocol.Projector} = {} + self.projectors[topic] = created + return created +end + +function Bus:register_subscriber(topic: string, subscriber: protocol.Subscriber): Bus + table.insert(self:ensure_subscribers(topic), subscriber) + return self +end + +function Bus:register_projector(topic: string, projector: protocol.Projector): Bus + table.insert(self:ensure_projectors(topic), projector) + return self +end + +function Bus:on_step(hook: protocol.StepHook): Bus + table.insert(self.hooks, hook) + return self +end + +function Bus:new_store(id: string, now: time.Time): checkpoint_store.CheckpointStore + return checkpoint_store.new(id, now) +end + +function Bus:emit(store: checkpoint_store.CheckpointStore, step: protocol.DispatchStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Bus:publish( + store: checkpoint_store.CheckpointStore, + topic: string, + event: protocol.Event, + at: time.Time +): protocol.PublishResult + local projectors = self:ensure_projectors(topic) + for _, projector in ipairs(projectors) do + projector(store.state, event, at) + end + + if event.kind == "tick" then + local audit_step: protocol.DispatchStep = {kind = "audit", note = "tick", at = event.at} + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local subscribers = self:ensure_subscribers(topic) + for _, subscriber in ipairs(subscribers) do + local note_result: protocol.SubscriberResult = subscriber(store.state, event) + if not note_result.ok then + return {ok = false, error = note_result.error} + end + + local note = note_result.value + if note then + local subscriber_step: protocol.DispatchStep = { + kind = "subscriber", + topic = topic, + note = note, + projection_id = helpers.event_id(event), + } + self:emit(store, subscriber_step, at) + end + end + + if event.kind == "completed" then + return {ok = true, value = "completed"} + end + if event.kind == "failed" then + return {ok = true, value = "failed"} + end + return {ok = true, value = nil} +end + +function Bus:replay( + store: checkpoint_store.CheckpointStore, + topic: string, + events: {protocol.Event}, + now: time.Time +): protocol.ReplayResult + local last_status: string? = nil + + for _, event in ipairs(events) do + local publish_result: protocol.PublishResult = self:publish(store, topic, event, now) + if not publish_result.ok then + return {ok = false, error = publish_result.error} + end + if publish_result.value then + last_status = publish_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/checkpoint_store.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/checkpoint_store.lua new file mode 100644 index 00000000..83339e3d --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/checkpoint_store.lua @@ -0,0 +1,112 @@ +local time = require("time") +local protocol = require("protocol") + +type CheckpointStore = { + state: protocol.BusState, + touch: (self: CheckpointStore, at: time.Time) -> CheckpointStore, + push_step: (self: CheckpointStore, step: protocol.DispatchStep, at: time.Time) -> CheckpointStore, + ensure_projection: (self: CheckpointStore, id: string, queue: string, at: time.Time) -> protocol.TaskProjection, + lookup_projection: (self: CheckpointStore, id: string) -> protocol.TaskProjection?, + increment: (self: CheckpointStore, name: string) -> integer, + summarize: (self: CheckpointStore, now: time.Time, last_status: string?) -> protocol.DispatchSummary, +} + +type Store = CheckpointStore + +local Store = {} +Store.__index = Store + +local M = {} +M.CheckpointStore = CheckpointStore + +function M.new(id: string, now: time.Time): CheckpointStore + local self: Store = { + state = { + id = id, + started_at = now, + last_event_at = nil, + steps = {}, + projections = {}, + counters = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + ensure_projection = Store.ensure_projection, + lookup_projection = Store.lookup_projection, + increment = Store.increment, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_event_at = at + return self +end + +function Store:push_step(step: protocol.DispatchStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:ensure_projection(id: string, queue: string, at: time.Time): protocol.TaskProjection + local current = self.state.projections[id] + if current then + if current.updated_at == nil then + current.updated_at = at + end + if current.queue == "unknown" and queue ~= "unknown" then + current.queue = queue + end + return current + end + + local created: protocol.TaskProjection = { + id = id, + queue = queue, + status = "queued", + worker = nil, + output = nil, + error_code = nil, + retryable = nil, + source = nil, + updated_at = at, + } + self.state.projections[id] = created + return created +end + +function Store:lookup_projection(id: string): protocol.TaskProjection? + return self.state.projections[id] +end + +function Store:increment(name: string): integer + local current = self.state.counters[name] or 0 + local next_value = current + 1 + self.state.counters[name] = next_value + return next_value +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.DispatchSummary + local projection_count = 0 + for _, _ in pairs(self.state.projections) do + projection_count = projection_count + 1 + end + + local failure_count = self.state.counters["failed"] or 0 + local seen_at = self.state.last_event_at or self.state.started_at + local elapsed = now:sub(seen_at) + + return { + id = self.state.id, + total_steps = #self.state.steps, + projection_count = projection_count, + failure_count = failure_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + } +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/helpers.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/helpers.lua new file mode 100644 index 00000000..7d7226e5 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/helpers.lua @@ -0,0 +1,43 @@ +local protocol = require("protocol") + +local M = {} + +function M.event_label(event: protocol.Event): string + if event.kind == "queued" then + return "queued:" .. event.queue + end + if event.kind == "started" then + return "started:" .. event.worker + end + if event.kind == "completed" then + return "completed" + end + if event.kind == "failed" then + return "failed:" .. event.code + end + return "tick" +end + +function M.event_id(event: protocol.Event): string? + if event.kind == "tick" then + return nil + end + return event.id +end + +function M.source_tag(event: protocol.Event): string? + if event.kind == "tick" then + return nil + end + local tags = event.meta.tags + if not tags then + return nil + end + return tags["source"] +end + +function M.status_name(status: string?): string + return status or "unknown" +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/main.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/main.lua new file mode 100644 index 00000000..35979ea5 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/main.lua @@ -0,0 +1,51 @@ +local time = require("time") +local protocol = require("protocol") +local subscriber_builder = require("subscriber_builder") +local projector_builder = require("projector_builder") +local bus = require("bus") + +local now = time.now() + +local subscriber = subscriber_builder.new() + :named("source") + :prefix_with("evt") + :build() + +local projector = projector_builder.new() + :track_queue("priority") + :count_failures_as("failed") + :capture_source("source") + :build() + +local app = bus.new() + :register_projector("tasks", projector) + :register_subscriber("tasks", subscriber) + +local queued_one: protocol.TaskQueuedEvent = { + kind = "queued", + id = "job-1", + queue = "priority", + payload = {task = "search"}, + meta = protocol.meta("trace-1", nil), +} + +local tick: protocol.TickEvent = { + kind = "tick", + at = now, +} + +local store = app:new_store("bus-1", now) +local replay_result = app:replay(store, "tasks", {queued_one, tick}, now) + +if replay_result.ok then + local last_status: string = replay_result.value.last_status -- expect-error +end + +local projection = store:lookup_projection("job-1") +local output: string = projection.output -- expect-error +local source: string = projection.source -- expect-error + +local missing: protocol.TaskProjection = store.state.projections["missing"] -- expect-error + +local elapsed = now:sub(store.state.last_event_at) -- expect-error +local trace_source: string = queued_one.meta.tags["source"] -- expect-error diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/manifest.json b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/manifest.json new file mode 100644 index 00000000..99bff2bb --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/manifest.json @@ -0,0 +1,14 @@ +{ + "description": "Soundness checks for the event bus saga domain: unsafe optional projection, map, tag, and time uses must be rejected.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "checkpoint_store.lua", + "subscriber_builder.lua", + "projector_builder.lua", + "bus.lua", + "main.lua" + ], + "packages": ["time"] +} diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/projector_builder.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/projector_builder.lua new file mode 100644 index 00000000..3a38eb3d --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/projector_builder.lua @@ -0,0 +1,119 @@ +local protocol = require("protocol") + +type ProjectorBuilder = { + tracked_queue: string?, + failure_counter: string?, + source_key: string?, + track_queue: (self: ProjectorBuilder, queue: string) -> ProjectorBuilder, + count_failures_as: (self: ProjectorBuilder, counter_name: string) -> ProjectorBuilder, + capture_source: (self: ProjectorBuilder, tag_key: string) -> ProjectorBuilder, + build: (self: ProjectorBuilder) -> protocol.Projector, +} + +type Builder = ProjectorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ProjectorBuilder = ProjectorBuilder + +function M.new(): ProjectorBuilder + local self: Builder = { + tracked_queue = nil, + failure_counter = nil, + source_key = nil, + track_queue = Builder.track_queue, + count_failures_as = Builder.count_failures_as, + capture_source = Builder.capture_source, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:track_queue(queue: string): Builder + self.tracked_queue = queue + return self +end + +function Builder:count_failures_as(counter_name: string): Builder + self.failure_counter = counter_name + return self +end + +function Builder:capture_source(tag_key: string): Builder + self.source_key = tag_key + return self +end + +function Builder:build(): protocol.Projector + local tracked_queue = self.tracked_queue + local failure_counter = self.failure_counter + local source_key = self.source_key + + return function(state: protocol.BusState, event: protocol.Event, at) + if event.kind == "tick" then + return + end + + local queue_name = "unknown" + if event.kind == "queued" then + queue_name = event.queue + end + + if tracked_queue and event.kind == "queued" and event.queue ~= tracked_queue then + return + end + + local projection = state.projections[event.id] + if not projection then + projection = { + id = event.id, + queue = queue_name, + status = "queued", + worker = nil, + output = nil, + error_code = nil, + retryable = nil, + source = nil, + updated_at = at, + } + state.projections[event.id] = projection + end + + if projection.queue == "unknown" and queue_name ~= "unknown" then + projection.queue = queue_name + end + + local tags = event.meta.tags + if source_key and tags then + local source = tags[source_key] + if source then + projection.source = source + end + end + + if event.kind == "queued" then + projection.status = "queued" + elseif event.kind == "started" then + projection.status = "started" + projection.worker = event.worker + elseif event.kind == "completed" then + projection.status = "completed" + projection.output = event.output + elseif event.kind == "failed" then + projection.status = "failed" + projection.error_code = event.code + projection.retryable = event.retryable + if failure_counter then + local current = state.counters[failure_counter] or 0 + state.counters[failure_counter] = current + 1 + end + end + + projection.updated_at = at + end +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/protocol.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/protocol.lua new file mode 100644 index 00000000..d7b045ec --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/protocol.lua @@ -0,0 +1,129 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type EventMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type TaskQueuedEvent = { + kind: "queued", + id: string, + queue: string, + payload: {[string]: string}, + meta: EventMeta, +} + +type TaskStartedEvent = { + kind: "started", + id: string, + worker: string, + meta: EventMeta, +} + +type TaskCompletedEvent = { + kind: "completed", + id: string, + output: string, + meta: EventMeta, +} + +type TaskFailedEvent = { + kind: "failed", + id: string, + code: string, + retryable: boolean, + meta: EventMeta, +} + +type TickEvent = { + kind: "tick", + at: time.Time, +} + +type Event = TaskQueuedEvent | TaskStartedEvent | TaskCompletedEvent | TaskFailedEvent | TickEvent + +type TaskProjection = { + id: string, + queue: string, + status: "queued" | "started" | "completed" | "failed", + worker: string?, + output: string?, + error_code: string?, + retryable: boolean?, + source: string?, + updated_at: time.Time?, +} + +type SubscriberStep = { + kind: "subscriber", + topic: string, + note: string, + projection_id: string?, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type DispatchStep = SubscriberStep | AuditStep + +type BusState = { + id: string, + started_at: time.Time, + last_event_at: time.Time?, + steps: {DispatchStep}, + projections: {[string]: TaskProjection}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type DispatchSummary = { + id: string, + total_steps: number, + projection_count: number, + failure_count: number, + last_status: string?, + elapsed_seconds: number, +} + +type SubscriberResult = {ok: true, value: string?} | {ok: false, error: AppError} +type PublishResult = {ok: true, value: string?} | {ok: false, error: AppError} +type ReplayResult = {ok: true, value: DispatchSummary} | {ok: false, error: AppError} + +type Subscriber = (BusState, Event) -> SubscriberResult +type Projector = (BusState, Event, time.Time) -> () +type StepHook = (DispatchStep, BusState) -> () + +local M = {} +M.AppError = AppError +M.EventMeta = EventMeta +M.TaskQueuedEvent = TaskQueuedEvent +M.TaskStartedEvent = TaskStartedEvent +M.TaskCompletedEvent = TaskCompletedEvent +M.TaskFailedEvent = TaskFailedEvent +M.TickEvent = TickEvent +M.Event = Event +M.TaskProjection = TaskProjection +M.DispatchStep = DispatchStep +M.BusState = BusState +M.DispatchSummary = DispatchSummary +M.SubscriberResult = SubscriberResult +M.PublishResult = PublishResult +M.ReplayResult = ReplayResult +M.Subscriber = Subscriber +M.Projector = Projector +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): EventMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/result.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/result.lua new file mode 100644 index 00000000..25095ae7 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "rate_limited" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/subscriber_builder.lua b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/subscriber_builder.lua new file mode 100644 index 00000000..39fc1dbf --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime-soundness/subscriber_builder.lua @@ -0,0 +1,123 @@ +local protocol = require("protocol") +local helpers = require("helpers") + +type Decorator = (string, protocol.BusState, protocol.Event) -> string + +type SubscriberBuilder = { + name: string, + prefix: string, + required_tag: string?, + flag_name: string?, + decorator: Decorator?, + named: (self: SubscriberBuilder, name: string) -> SubscriberBuilder, + prefix_with: (self: SubscriberBuilder, prefix: string) -> SubscriberBuilder, + require_tag: (self: SubscriberBuilder, tag_name: string) -> SubscriberBuilder, + remember_flag: (self: SubscriberBuilder, flag_name: string) -> SubscriberBuilder, + decorate: (self: SubscriberBuilder, decorator: Decorator) -> SubscriberBuilder, + build: (self: SubscriberBuilder) -> protocol.Subscriber, +} + +type Builder = SubscriberBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.SubscriberBuilder = SubscriberBuilder + +function M.new(): SubscriberBuilder + local self: Builder = { + name = "subscriber", + prefix = "sub", + required_tag = nil, + flag_name = nil, + decorator = nil, + named = Builder.named, + prefix_with = Builder.prefix_with, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:require_tag(tag_name: string): Builder + self.required_tag = tag_name + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.Subscriber + local name = self.name + local prefix = self.prefix + local required_tag = self.required_tag + local flag_name = self.flag_name + local decorator = self.decorator + + return function(state: protocol.BusState, event: protocol.Event): protocol.SubscriberResult + if event.kind == "tick" then + return {ok = true, value = nil} + end + + local note = prefix .. ":" .. name .. ":" .. helpers.event_label(event) + + if required_tag then + local tags = event.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + note = note .. ":" .. value + end + + if flag_name then + state.flags[flag_name] = true + end + if decorator then + note = decorator(note, state, event) + end + + return {ok = true, value = note} + end +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/bus.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/bus.lua new file mode 100644 index 00000000..27fb0766 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/bus.lua @@ -0,0 +1,162 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local checkpoint_store = require("checkpoint_store") + +type EventBus = { + subscribers: {[string]: {protocol.Subscriber}}, + projectors: {[string]: {protocol.Projector}}, + hooks: {protocol.StepHook}, + ensure_subscribers: (self: EventBus, topic: string) -> {protocol.Subscriber}, + ensure_projectors: (self: EventBus, topic: string) -> {protocol.Projector}, + register_subscriber: (self: EventBus, topic: string, subscriber: protocol.Subscriber) -> EventBus, + register_projector: (self: EventBus, topic: string, projector: protocol.Projector) -> EventBus, + on_step: (self: EventBus, hook: protocol.StepHook) -> EventBus, + new_store: (self: EventBus, id: string, now: time.Time) -> checkpoint_store.CheckpointStore, + emit: (self: EventBus, store: checkpoint_store.CheckpointStore, step: protocol.DispatchStep, at: time.Time) -> (), + publish: (self: EventBus, store: checkpoint_store.CheckpointStore, topic: string, event: protocol.Event, at: time.Time) -> protocol.PublishResult, + replay: (self: EventBus, store: checkpoint_store.CheckpointStore, topic: string, events: {protocol.Event}, now: time.Time) -> protocol.ReplayResult, +} + +type Bus = EventBus + +local Bus = {} +Bus.__index = Bus + +local M = {} +M.EventBus = EventBus + +function M.new(): EventBus + local self: Bus = { + subscribers = {}, + projectors = {}, + hooks = {}, + ensure_subscribers = Bus.ensure_subscribers, + ensure_projectors = Bus.ensure_projectors, + register_subscriber = Bus.register_subscriber, + register_projector = Bus.register_projector, + on_step = Bus.on_step, + new_store = Bus.new_store, + emit = Bus.emit, + publish = Bus.publish, + replay = Bus.replay, + } + setmetatable(self, Bus) + return self +end + +function Bus:ensure_subscribers(topic: string): {protocol.Subscriber} + local current = self.subscribers[topic] + if current then + return current + end + + local created: {protocol.Subscriber} = {} + self.subscribers[topic] = created + return created +end + +function Bus:ensure_projectors(topic: string): {protocol.Projector} + local current = self.projectors[topic] + if current then + return current + end + + local created: {protocol.Projector} = {} + self.projectors[topic] = created + return created +end + +function Bus:register_subscriber(topic: string, subscriber: protocol.Subscriber): Bus + table.insert(self:ensure_subscribers(topic), subscriber) + return self +end + +function Bus:register_projector(topic: string, projector: protocol.Projector): Bus + table.insert(self:ensure_projectors(topic), projector) + return self +end + +function Bus:on_step(hook: protocol.StepHook): Bus + table.insert(self.hooks, hook) + return self +end + +function Bus:new_store(id: string, now: time.Time): checkpoint_store.CheckpointStore + return checkpoint_store.new(id, now) +end + +function Bus:emit(store: checkpoint_store.CheckpointStore, step: protocol.DispatchStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Bus:publish( + store: checkpoint_store.CheckpointStore, + topic: string, + event: protocol.Event, + at: time.Time +): protocol.PublishResult + local projectors = self:ensure_projectors(topic) + for _, projector in ipairs(projectors) do + projector(store.state, event, at) + end + + if event.kind == "tick" then + local audit_step: protocol.DispatchStep = {kind = "audit", note = "tick", at = event.at} + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local subscribers = self:ensure_subscribers(topic) + for _, subscriber in ipairs(subscribers) do + local note_result: protocol.SubscriberResult = subscriber(store.state, event) + if not note_result.ok then + return {ok = false, error = note_result.error} + end + + local note = note_result.value + if note then + local subscriber_step: protocol.DispatchStep = { + kind = "subscriber", + topic = topic, + note = note, + projection_id = helpers.event_id(event), + } + self:emit(store, subscriber_step, at) + end + end + + if event.kind == "completed" then + return {ok = true, value = "completed"} + end + if event.kind == "failed" then + return {ok = true, value = "failed"} + end + return {ok = true, value = nil} +end + +function Bus:replay( + store: checkpoint_store.CheckpointStore, + topic: string, + events: {protocol.Event}, + now: time.Time +): protocol.ReplayResult + local last_status: string? = nil + + for _, event in ipairs(events) do + local publish_result: protocol.PublishResult = self:publish(store, topic, event, now) + if not publish_result.ok then + return {ok = false, error = publish_result.error} + end + if publish_result.value then + last_status = publish_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/checkpoint_store.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/checkpoint_store.lua new file mode 100644 index 00000000..83339e3d --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/checkpoint_store.lua @@ -0,0 +1,112 @@ +local time = require("time") +local protocol = require("protocol") + +type CheckpointStore = { + state: protocol.BusState, + touch: (self: CheckpointStore, at: time.Time) -> CheckpointStore, + push_step: (self: CheckpointStore, step: protocol.DispatchStep, at: time.Time) -> CheckpointStore, + ensure_projection: (self: CheckpointStore, id: string, queue: string, at: time.Time) -> protocol.TaskProjection, + lookup_projection: (self: CheckpointStore, id: string) -> protocol.TaskProjection?, + increment: (self: CheckpointStore, name: string) -> integer, + summarize: (self: CheckpointStore, now: time.Time, last_status: string?) -> protocol.DispatchSummary, +} + +type Store = CheckpointStore + +local Store = {} +Store.__index = Store + +local M = {} +M.CheckpointStore = CheckpointStore + +function M.new(id: string, now: time.Time): CheckpointStore + local self: Store = { + state = { + id = id, + started_at = now, + last_event_at = nil, + steps = {}, + projections = {}, + counters = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + ensure_projection = Store.ensure_projection, + lookup_projection = Store.lookup_projection, + increment = Store.increment, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_event_at = at + return self +end + +function Store:push_step(step: protocol.DispatchStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:ensure_projection(id: string, queue: string, at: time.Time): protocol.TaskProjection + local current = self.state.projections[id] + if current then + if current.updated_at == nil then + current.updated_at = at + end + if current.queue == "unknown" and queue ~= "unknown" then + current.queue = queue + end + return current + end + + local created: protocol.TaskProjection = { + id = id, + queue = queue, + status = "queued", + worker = nil, + output = nil, + error_code = nil, + retryable = nil, + source = nil, + updated_at = at, + } + self.state.projections[id] = created + return created +end + +function Store:lookup_projection(id: string): protocol.TaskProjection? + return self.state.projections[id] +end + +function Store:increment(name: string): integer + local current = self.state.counters[name] or 0 + local next_value = current + 1 + self.state.counters[name] = next_value + return next_value +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.DispatchSummary + local projection_count = 0 + for _, _ in pairs(self.state.projections) do + projection_count = projection_count + 1 + end + + local failure_count = self.state.counters["failed"] or 0 + local seen_at = self.state.last_event_at or self.state.started_at + local elapsed = now:sub(seen_at) + + return { + id = self.state.id, + total_steps = #self.state.steps, + projection_count = projection_count, + failure_count = failure_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + } +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/helpers.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/helpers.lua new file mode 100644 index 00000000..7d7226e5 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/helpers.lua @@ -0,0 +1,43 @@ +local protocol = require("protocol") + +local M = {} + +function M.event_label(event: protocol.Event): string + if event.kind == "queued" then + return "queued:" .. event.queue + end + if event.kind == "started" then + return "started:" .. event.worker + end + if event.kind == "completed" then + return "completed" + end + if event.kind == "failed" then + return "failed:" .. event.code + end + return "tick" +end + +function M.event_id(event: protocol.Event): string? + if event.kind == "tick" then + return nil + end + return event.id +end + +function M.source_tag(event: protocol.Event): string? + if event.kind == "tick" then + return nil + end + local tags = event.meta.tags + if not tags then + return nil + end + return tags["source"] +end + +function M.status_name(status: string?): string + return status or "unknown" +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/main.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/main.lua new file mode 100644 index 00000000..4d8609a8 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/main.lua @@ -0,0 +1,224 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local subscriber_builder = require("subscriber_builder") +local projector_builder = require("projector_builder") +local bus = require("bus") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_notes: {[string]: string} = {} +local observed_audits: {string} = {} +local last_bus_id: string? = nil + +local source_subscriber = subscriber_builder.new() + :named("source") + :prefix_with("evt") + :require_tag("source") + :remember_flag("seen_source") + :decorate(function(note: string, state: protocol.BusState, event: protocol.Event): string + local suffix = helpers.event_label(event) + if state.flags["replayed"] then + suffix = suffix .. ":replayed" + end + return note .. ":" .. suffix + end) + :build() + +local final_subscriber = subscriber_builder.new() + :named("final") + :prefix_with("final") + :decorate(function(note: string, _state: protocol.BusState, event: protocol.Event): string + if event.kind == "completed" then + return note .. ":" .. event.output + end + if event.kind == "failed" then + return note .. ":" .. event.code + end + return note + end) + :build() + +local projector = projector_builder.new() + :track_queue("priority") + :count_failures_as("failed") + :capture_source("source") + :build() + +local app = bus.new() + :register_projector("tasks", projector) + :register_subscriber("tasks", source_subscriber) + :register_subscriber("tasks", final_subscriber) + +app:on_step(function(step: protocol.DispatchStep, state: protocol.BusState) + last_bus_id = state.id + + if step.kind == "subscriber" then + observed_notes[step.topic .. ":" .. tostring(#observed_audits + 1)] = step.note + if step.projection_id then + local projection_id: string = step.projection_id + end + else + table.insert(observed_audits, step.note) + local at_seconds: integer = step.at:unix() + end +end) + +local queued_one: protocol.TaskQueuedEvent = { + kind = "queued", + id = "job-1", + queue = "priority", + payload = {task = "search"}, + meta = protocol.meta("trace-1", {source = "api", lane = "priority"}), +} + +local started_one: protocol.TaskStartedEvent = { + kind = "started", + id = "job-1", + worker = "worker-a", + meta = protocol.meta("trace-2", {source = "worker"}), +} + +local completed_one: protocol.TaskCompletedEvent = { + kind = "completed", + id = "job-1", + output = "done", + meta = protocol.meta("trace-3", {source = "worker"}), +} + +local queued_two: protocol.TaskQueuedEvent = { + kind = "queued", + id = "job-2", + queue = "priority", + payload = {task = "profile"}, + meta = protocol.meta("trace-4", {source = "api"}), +} + +local failed_two: protocol.TaskFailedEvent = { + kind = "failed", + id = "job-2", + code = "rate_limited", + retryable = true, + meta = protocol.meta("trace-5", {source = "worker"}), +} + +local tick: protocol.TickEvent = { + kind = "tick", + at = now, +} + +local events: {protocol.Event} = { + queued_one, + started_one, + completed_one, + queued_two, + failed_two, + tick, +} + +local store = app:new_store("bus-1", now) +store.state.flags["replayed"] = true + +local summary_result = app:replay(store, "tasks", events, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local bus_id: string = summary.id + local total_steps: number = summary.total_steps + local projection_count: number = summary.projection_count + local failure_count: number = summary.failure_count + local elapsed_seconds: number = summary.elapsed_seconds + local last_status: string? = summary.last_status +end + +local summary_label = result.map(summary_result, function(summary: protocol.DispatchSummary): string + return summary.id .. ":" .. tostring(summary.projection_count) +end) + +if summary_label.ok then + local label: string = summary_label.value +end + +local summary_id = result.and_then(summary_result, function(summary: protocol.DispatchSummary): StringResult + if summary.failure_count == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected a failure", + retryable = false, + }, + } + end + + return { + ok = true, + value = summary.id, + } +end) + +if summary_id.ok then + local stable_id: string = summary_id.value +end + +local projection_one = store:lookup_projection("job-1") +if projection_one then + local status: string = projection_one.status + local worker = projection_one.worker + if worker then + local stable_worker: string = worker + end + local output = projection_one.output + if output then + local stable_output: string = output + end + local source = projection_one.source + if source then + local stable_source: string = source + end +end + +local projection_two = store:lookup_projection("job-2") +if projection_two then + local failed_status: string = projection_two.status + local error_code = projection_two.error_code + if error_code then + local stable_code: string = error_code + end + local retryable = projection_two.retryable + if retryable ~= nil then + local stable_retryable: boolean = retryable + end +end + +local missing = store:lookup_projection("missing") +if missing == nil then + local fallback: string = "missing" +end + +for key, note in pairs(observed_notes) do + local stable_key: string = key + local stable_note: string = note +end + +for _, note in ipairs(observed_audits) do + local audit_note: string = note +end + +if last_bus_id ~= nil then + local stable_bus_id: string = last_bus_id +end + +local event_source = helpers.source_tag(queued_one) +if event_source then + local stable_event_source: string = event_source +end + +local seen_at = store.state.last_event_at or store.state.started_at +local elapsed = now:sub(seen_at) +local seconds: number = elapsed:seconds() diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/manifest.json b/testdata/fixtures/realworld/event-bus-saga-runtime/manifest.json new file mode 100644 index 00000000..297dbd89 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Application-shaped event bus mixing closure-built subscribers and projectors, metatable-backed runtime/store objects, discriminated events, dynamic registries, optional projections, callbacks, counters, and time-based replay summaries.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "checkpoint_store.lua", + "subscriber_builder.lua", + "projector_builder.lua", + "bus.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/projector_builder.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/projector_builder.lua new file mode 100644 index 00000000..3a38eb3d --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/projector_builder.lua @@ -0,0 +1,119 @@ +local protocol = require("protocol") + +type ProjectorBuilder = { + tracked_queue: string?, + failure_counter: string?, + source_key: string?, + track_queue: (self: ProjectorBuilder, queue: string) -> ProjectorBuilder, + count_failures_as: (self: ProjectorBuilder, counter_name: string) -> ProjectorBuilder, + capture_source: (self: ProjectorBuilder, tag_key: string) -> ProjectorBuilder, + build: (self: ProjectorBuilder) -> protocol.Projector, +} + +type Builder = ProjectorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ProjectorBuilder = ProjectorBuilder + +function M.new(): ProjectorBuilder + local self: Builder = { + tracked_queue = nil, + failure_counter = nil, + source_key = nil, + track_queue = Builder.track_queue, + count_failures_as = Builder.count_failures_as, + capture_source = Builder.capture_source, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:track_queue(queue: string): Builder + self.tracked_queue = queue + return self +end + +function Builder:count_failures_as(counter_name: string): Builder + self.failure_counter = counter_name + return self +end + +function Builder:capture_source(tag_key: string): Builder + self.source_key = tag_key + return self +end + +function Builder:build(): protocol.Projector + local tracked_queue = self.tracked_queue + local failure_counter = self.failure_counter + local source_key = self.source_key + + return function(state: protocol.BusState, event: protocol.Event, at) + if event.kind == "tick" then + return + end + + local queue_name = "unknown" + if event.kind == "queued" then + queue_name = event.queue + end + + if tracked_queue and event.kind == "queued" and event.queue ~= tracked_queue then + return + end + + local projection = state.projections[event.id] + if not projection then + projection = { + id = event.id, + queue = queue_name, + status = "queued", + worker = nil, + output = nil, + error_code = nil, + retryable = nil, + source = nil, + updated_at = at, + } + state.projections[event.id] = projection + end + + if projection.queue == "unknown" and queue_name ~= "unknown" then + projection.queue = queue_name + end + + local tags = event.meta.tags + if source_key and tags then + local source = tags[source_key] + if source then + projection.source = source + end + end + + if event.kind == "queued" then + projection.status = "queued" + elseif event.kind == "started" then + projection.status = "started" + projection.worker = event.worker + elseif event.kind == "completed" then + projection.status = "completed" + projection.output = event.output + elseif event.kind == "failed" then + projection.status = "failed" + projection.error_code = event.code + projection.retryable = event.retryable + if failure_counter then + local current = state.counters[failure_counter] or 0 + state.counters[failure_counter] = current + 1 + end + end + + projection.updated_at = at + end +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/protocol.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/protocol.lua new file mode 100644 index 00000000..d7b045ec --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/protocol.lua @@ -0,0 +1,129 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type EventMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type TaskQueuedEvent = { + kind: "queued", + id: string, + queue: string, + payload: {[string]: string}, + meta: EventMeta, +} + +type TaskStartedEvent = { + kind: "started", + id: string, + worker: string, + meta: EventMeta, +} + +type TaskCompletedEvent = { + kind: "completed", + id: string, + output: string, + meta: EventMeta, +} + +type TaskFailedEvent = { + kind: "failed", + id: string, + code: string, + retryable: boolean, + meta: EventMeta, +} + +type TickEvent = { + kind: "tick", + at: time.Time, +} + +type Event = TaskQueuedEvent | TaskStartedEvent | TaskCompletedEvent | TaskFailedEvent | TickEvent + +type TaskProjection = { + id: string, + queue: string, + status: "queued" | "started" | "completed" | "failed", + worker: string?, + output: string?, + error_code: string?, + retryable: boolean?, + source: string?, + updated_at: time.Time?, +} + +type SubscriberStep = { + kind: "subscriber", + topic: string, + note: string, + projection_id: string?, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type DispatchStep = SubscriberStep | AuditStep + +type BusState = { + id: string, + started_at: time.Time, + last_event_at: time.Time?, + steps: {DispatchStep}, + projections: {[string]: TaskProjection}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type DispatchSummary = { + id: string, + total_steps: number, + projection_count: number, + failure_count: number, + last_status: string?, + elapsed_seconds: number, +} + +type SubscriberResult = {ok: true, value: string?} | {ok: false, error: AppError} +type PublishResult = {ok: true, value: string?} | {ok: false, error: AppError} +type ReplayResult = {ok: true, value: DispatchSummary} | {ok: false, error: AppError} + +type Subscriber = (BusState, Event) -> SubscriberResult +type Projector = (BusState, Event, time.Time) -> () +type StepHook = (DispatchStep, BusState) -> () + +local M = {} +M.AppError = AppError +M.EventMeta = EventMeta +M.TaskQueuedEvent = TaskQueuedEvent +M.TaskStartedEvent = TaskStartedEvent +M.TaskCompletedEvent = TaskCompletedEvent +M.TaskFailedEvent = TaskFailedEvent +M.TickEvent = TickEvent +M.Event = Event +M.TaskProjection = TaskProjection +M.DispatchStep = DispatchStep +M.BusState = BusState +M.DispatchSummary = DispatchSummary +M.SubscriberResult = SubscriberResult +M.PublishResult = PublishResult +M.ReplayResult = ReplayResult +M.Subscriber = Subscriber +M.Projector = Projector +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): EventMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/result.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/result.lua new file mode 100644 index 00000000..25095ae7 --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "rate_limited" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/event-bus-saga-runtime/subscriber_builder.lua b/testdata/fixtures/realworld/event-bus-saga-runtime/subscriber_builder.lua new file mode 100644 index 00000000..39fc1dbf --- /dev/null +++ b/testdata/fixtures/realworld/event-bus-saga-runtime/subscriber_builder.lua @@ -0,0 +1,123 @@ +local protocol = require("protocol") +local helpers = require("helpers") + +type Decorator = (string, protocol.BusState, protocol.Event) -> string + +type SubscriberBuilder = { + name: string, + prefix: string, + required_tag: string?, + flag_name: string?, + decorator: Decorator?, + named: (self: SubscriberBuilder, name: string) -> SubscriberBuilder, + prefix_with: (self: SubscriberBuilder, prefix: string) -> SubscriberBuilder, + require_tag: (self: SubscriberBuilder, tag_name: string) -> SubscriberBuilder, + remember_flag: (self: SubscriberBuilder, flag_name: string) -> SubscriberBuilder, + decorate: (self: SubscriberBuilder, decorator: Decorator) -> SubscriberBuilder, + build: (self: SubscriberBuilder) -> protocol.Subscriber, +} + +type Builder = SubscriberBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.SubscriberBuilder = SubscriberBuilder + +function M.new(): SubscriberBuilder + local self: Builder = { + name = "subscriber", + prefix = "sub", + required_tag = nil, + flag_name = nil, + decorator = nil, + named = Builder.named, + prefix_with = Builder.prefix_with, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:require_tag(tag_name: string): Builder + self.required_tag = tag_name + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.Subscriber + local name = self.name + local prefix = self.prefix + local required_tag = self.required_tag + local flag_name = self.flag_name + local decorator = self.decorator + + return function(state: protocol.BusState, event: protocol.Event): protocol.SubscriberResult + if event.kind == "tick" then + return {ok = true, value = nil} + end + + local note = prefix .. ":" .. name .. ":" .. helpers.event_label(event) + + if required_tag then + local tags = event.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + note = note .. ":" .. value + end + + if flag_name then + state.flags[flag_name] = true + end + if decorator then + note = decorator(note, state, event) + end + + return {ok = true, value = note} + end +end + +return M diff --git a/testdata/fixtures/realworld/fluent-prompt-builder/manifest.json b/testdata/fixtures/realworld/fluent-prompt-builder/manifest.json index 92058167..a6474a2b 100644 --- a/testdata/fixtures/realworld/fluent-prompt-builder/manifest.json +++ b/testdata/fixtures/realworld/fluent-prompt-builder/manifest.json @@ -1,4 +1,4 @@ { "files": ["builder.lua", "main.lua"], - "check": {"errors": 4} + "check": {"errors": 0} } diff --git a/testdata/fixtures/realworld/generic-registry/main.lua b/testdata/fixtures/realworld/generic-registry/main.lua index f0972e14..dd72a5a7 100644 --- a/testdata/fixtures/realworld/generic-registry/main.lua +++ b/testdata/fixtures/realworld/generic-registry/main.lua @@ -4,12 +4,12 @@ local r = plugins.setup() local result, err = r:call("greet", {name = "Alice"}) if err == nil and result then - local output: string = result.output + local output = result.output end local result2, err2 = r:call("count", {items = {"a", "b", "c"}}) if err2 == nil and result2 then - local output: string = result2.output + local output = result2.output end local missing, missing_err = r:call("nonexistent", {}) diff --git a/testdata/fixtures/realworld/generic-registry/manifest.json b/testdata/fixtures/realworld/generic-registry/manifest.json index c3e10b7b..967c15ef 100644 --- a/testdata/fixtures/realworld/generic-registry/manifest.json +++ b/testdata/fixtures/realworld/generic-registry/manifest.json @@ -1,4 +1 @@ -{ - "files": ["registry.lua", "plugins.lua", "main.lua"], - "check": {"errors": 4} -} +{"files": ["registry.lua", "plugins.lua", "main.lua"], "check": {"errors": 0}} diff --git a/testdata/fixtures/realworld/generic-registry/plugins.lua b/testdata/fixtures/realworld/generic-registry/plugins.lua index d248c71e..bdd18f6c 100644 --- a/testdata/fixtures/realworld/generic-registry/plugins.lua +++ b/testdata/fixtures/realworld/generic-registry/plugins.lua @@ -5,7 +5,7 @@ type PluginResult = {output: string, metadata: {[string]: any}} local M = {} -function M.setup(): Registry +function M.setup(): registry.Registry local r = registry.new() r:register("greet", function(args: {[string]: any}): (PluginResult?, string?) diff --git a/testdata/fixtures/realworld/index-presence-laws/main.lua b/testdata/fixtures/realworld/index-presence-laws/main.lua new file mode 100644 index 00000000..d32de260 --- /dev/null +++ b/testdata/fixtures/realworld/index-presence-laws/main.lua @@ -0,0 +1,25 @@ +type Message = { + _topic: string, + topic: (self: Message) -> string, +} + +local messages: {[string]: Message} = {} + +if not messages["root"] then + messages["root"] = { + _topic = "installed", + topic = function(self: Message): string + return self._topic + end, + } +end + +local installed: string = messages["root"]:topic() + +local cached = messages["root"] +if cached then + local cached_topic: string = cached:topic() +end + +assert(messages["root"]) +local asserted: string = messages["root"]:topic() diff --git a/testdata/fixtures/realworld/index-presence-laws/manifest.json b/testdata/fixtures/realworld/index-presence-laws/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/realworld/index-presence-laws/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/lookup-table-cast/manifest.json b/testdata/fixtures/realworld/lookup-table-cast/manifest.json index 08ac459d..b118ec7b 100644 --- a/testdata/fixtures/realworld/lookup-table-cast/manifest.json +++ b/testdata/fixtures/realworld/lookup-table-cast/manifest.json @@ -1 +1 @@ -{"files": ["constants.lua", "mapper.lua", "main.lua"], "check": {"errors": 11}} +{"files": ["constants.lua", "mapper.lua", "main.lua"]} diff --git a/testdata/fixtures/realworld/lookup-table-cast/mapper.lua b/testdata/fixtures/realworld/lookup-table-cast/mapper.lua index 59410c4c..21691184 100644 --- a/testdata/fixtures/realworld/lookup-table-cast/mapper.lua +++ b/testdata/fixtures/realworld/lookup-table-cast/mapper.lua @@ -6,23 +6,26 @@ type StatusCodeMap = {[number]: string} local M = {} -M.finish_reasons: FinishReasonMap = {} -M.finish_reasons["end_turn"] = constants.FINISH_REASON.STOP -M.finish_reasons["max_tokens"] = constants.FINISH_REASON.LENGTH -M.finish_reasons["stop_sequence"] = constants.FINISH_REASON.STOP -M.finish_reasons["tool_use"] = constants.FINISH_REASON.TOOL_CALL - -M.error_types: ErrorTypeMap = {} -M.error_types["invalid_request_error"] = constants.ERROR_TYPE.INVALID_REQUEST -M.error_types["authentication_error"] = constants.ERROR_TYPE.AUTHENTICATION -M.error_types["rate_limit_error"] = constants.ERROR_TYPE.RATE_LIMIT -M.error_types["api_error"] = constants.ERROR_TYPE.SERVER_ERROR - -M.status_codes: StatusCodeMap = {} -M.status_codes[400] = constants.ERROR_TYPE.INVALID_REQUEST -M.status_codes[401] = constants.ERROR_TYPE.AUTHENTICATION -M.status_codes[429] = constants.ERROR_TYPE.RATE_LIMIT -M.status_codes[500] = constants.ERROR_TYPE.SERVER_ERROR +local finish_reasons: FinishReasonMap = {} +finish_reasons["end_turn"] = constants.FINISH_REASON.STOP +finish_reasons["max_tokens"] = constants.FINISH_REASON.LENGTH +finish_reasons["stop_sequence"] = constants.FINISH_REASON.STOP +finish_reasons["tool_use"] = constants.FINISH_REASON.TOOL_CALL +M.finish_reasons = finish_reasons + +local error_types: ErrorTypeMap = {} +error_types["invalid_request_error"] = constants.ERROR_TYPE.INVALID_REQUEST +error_types["authentication_error"] = constants.ERROR_TYPE.AUTHENTICATION +error_types["rate_limit_error"] = constants.ERROR_TYPE.RATE_LIMIT +error_types["api_error"] = constants.ERROR_TYPE.SERVER_ERROR +M.error_types = error_types + +local status_codes: StatusCodeMap = {} +status_codes[400] = constants.ERROR_TYPE.INVALID_REQUEST +status_codes[401] = constants.ERROR_TYPE.AUTHENTICATION +status_codes[429] = constants.ERROR_TYPE.RATE_LIMIT +status_codes[500] = constants.ERROR_TYPE.SERVER_ERROR +M.status_codes = status_codes function M.map_finish_reason(api_reason: string): string return M.finish_reasons[api_reason] or "unknown" diff --git a/testdata/fixtures/realworld/metatable-oop/counter.lua b/testdata/fixtures/realworld/metatable-oop/counter.lua index 821b06f9..9e9e55b2 100644 --- a/testdata/fixtures/realworld/metatable-oop/counter.lua +++ b/testdata/fixtures/realworld/metatable-oop/counter.lua @@ -3,12 +3,12 @@ local class = require("class") type Counter = { _count: number, _name: string, - _emitter: EventEmitter, + _emitter: class.EventEmitter, increment: (self: Counter) -> (), decrement: (self: Counter) -> (), get: (self: Counter) -> number, name: (self: Counter) -> string, - on_change: (self: Counter, handler: (self: EventEmitter, event: string, data: any) -> ()) -> Counter, + on_change: (self: Counter, handler: (self: class.EventEmitter, event: string, data: any) -> ()) -> Counter, } local Counter = {} @@ -48,7 +48,7 @@ function Counter:name(): string return self._name end -function Counter:on_change(handler: (self: EventEmitter, event: string, data: any) -> ()): Counter +function Counter:on_change(handler: (self: class.EventEmitter, event: string, data: any) -> ()): Counter self._emitter:on("change", handler) return self end diff --git a/testdata/fixtures/realworld/metatable-oop/manifest.json b/testdata/fixtures/realworld/metatable-oop/manifest.json index 4f567b9e..7c185940 100644 --- a/testdata/fixtures/realworld/metatable-oop/manifest.json +++ b/testdata/fixtures/realworld/metatable-oop/manifest.json @@ -1,4 +1,4 @@ { "files": ["class.lua", "counter.lua", "main.lua"], - "check": {"errors": 4} + "check": {"errors": 0} } diff --git a/testdata/fixtures/realworld/metatable-shared-self/builder.lua b/testdata/fixtures/realworld/metatable-shared-self/builder.lua new file mode 100644 index 00000000..af037a29 --- /dev/null +++ b/testdata/fixtures/realworld/metatable-shared-self/builder.lua @@ -0,0 +1,31 @@ +type Builder = { + _name: string, + rename: (self: Builder, name: string) -> Builder, + name: (self: Builder) -> string, +} + +local Builder = {} +Builder.__index = Builder + +function Builder.new(name: string): Builder + local self: Builder = { + _name = name, + rename = Builder.rename, + name = Builder.name, + } + setmetatable(self, Builder) + return self +end + +function Builder:rename(name: string): Builder + self._name = name + return self +end + +function Builder:name(): string + return self._name +end + +local M = {} +M.new = Builder.new +return M diff --git a/testdata/fixtures/realworld/metatable-shared-self/main.lua b/testdata/fixtures/realworld/metatable-shared-self/main.lua new file mode 100644 index 00000000..995df6df --- /dev/null +++ b/testdata/fixtures/realworld/metatable-shared-self/main.lua @@ -0,0 +1,5 @@ +local builder = require("builder") + +local b = builder.new("first") +local renamed = b:rename("second") +local name: string = renamed:name() diff --git a/testdata/fixtures/realworld/metatable-shared-self/manifest.json b/testdata/fixtures/realworld/metatable-shared-self/manifest.json new file mode 100644 index 00000000..a6474a2b --- /dev/null +++ b/testdata/fixtures/realworld/metatable-shared-self/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["builder.lua", "main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/main.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/main.lua new file mode 100644 index 00000000..b3dba3a0 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/main.lua @@ -0,0 +1,70 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local session_store = require("session_store") +local middleware_builder = require("middleware_builder") +local route_builder = require("route_builder") +local router = require("router") + +local now = time.now() +local store = session_store.new() + +store:save("token-1", { + id = "s-1", + user_id = "u-1", + scopes = {["chat.read"] = true}, + last_seen = nil, + attributes = nil, +}) + +local auth = middleware_builder.new() + :named("auth") + :require_header("authorization") + :load_sessions_from(store) + :require_scope("chat.read") + :copy_tag_to_local("source", "source") + :build() + +local unsafe_route = route_builder.new() + :key("GET /rooms/show") + :use(auth) + :handle(function(ctx: protocol.RequestContext): protocol.ResponseResult + local user_id: string = ctx.session.user_id -- expect-error + local room_id: string = ctx.params["room_id"] -- expect-error + return { + ok = true, + value = { + status = 200, + body = user_id .. ":" .. room_id, + headers = {["x-user"] = user_id}, + }, + } + end) + :build() + +local app = router.new():register_route(unsafe_route) + +local room_request: protocol.HttpRequest = { + kind = "http", + method = "GET", + path = "/rooms/show", + headers = {authorization = "token-1"}, + params = {room_id = "room-1"}, + body = nil, + meta = protocol.meta("trace-1", nil), +} + +local response = app:dispatch(room_request, now) +if response.ok then + local header: string = response.value.headers["x-user"] -- expect-error +end + +local snapshot = store:lookup("token-1") +if snapshot then + local elapsed = now:sub(snapshot.last_seen) -- expect-error +end + +local tags = room_request.meta.tags +local source: string = tags["source"] -- expect-error + +local request_room: string = room_request.params["room_id"] -- expect-error diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/manifest.json b/testdata/fixtures/realworld/middleware-session-router-soundness/manifest.json new file mode 100644 index 00000000..25442b50 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/manifest.json @@ -0,0 +1,13 @@ +{ + "description": "Soundness checks for the middleware session router domain: unsafe optional session, map, header, and time uses must be rejected.", + "files": [ + "result.lua", + "protocol.lua", + "session_store.lua", + "middleware_builder.lua", + "route_builder.lua", + "router.lua", + "main.lua" + ], + "packages": ["time"] +} diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/middleware_builder.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/middleware_builder.lua new file mode 100644 index 00000000..784ee880 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/middleware_builder.lua @@ -0,0 +1,135 @@ +local result = require("result") +local protocol = require("protocol") +local session_store = require("session_store") + +type MiddlewareBuilder = { + name: string, + required_header: string?, + local_tag_key: string?, + local_field: string?, + store: session_store.SessionStore?, + required_scope: string?, + named: (self: MiddlewareBuilder, name: string) -> MiddlewareBuilder, + require_header: (self: MiddlewareBuilder, header: string) -> MiddlewareBuilder, + copy_tag_to_local: (self: MiddlewareBuilder, tag_key: string, local_field: string) -> MiddlewareBuilder, + load_sessions_from: (self: MiddlewareBuilder, store: session_store.SessionStore) -> MiddlewareBuilder, + require_scope: (self: MiddlewareBuilder, scope: string) -> MiddlewareBuilder, + build: (self: MiddlewareBuilder) -> protocol.Middleware, +} + +type Builder = MiddlewareBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): MiddlewareBuilder + local self: Builder = { + name = "middleware", + required_header = nil, + local_tag_key = nil, + local_field = nil, + store = nil, + required_scope = nil, + named = Builder.named, + require_header = Builder.require_header, + copy_tag_to_local = Builder.copy_tag_to_local, + load_sessions_from = Builder.load_sessions_from, + require_scope = Builder.require_scope, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_header(header: string): Builder + self.required_header = header + return self +end + +function Builder:copy_tag_to_local(tag_key: string, local_field: string): Builder + self.local_tag_key = tag_key + self.local_field = local_field + return self +end + +function Builder:load_sessions_from(store: session_store.SessionStore): Builder + self.store = store + return self +end + +function Builder:require_scope(scope: string): Builder + self.required_scope = scope + return self +end + +function Builder:build(): protocol.Middleware + local name = self.name + local required_header = self.required_header + local local_tag_key = self.local_tag_key + local local_field = self.local_field + local store = self.store + local required_scope = self.required_scope + + return function(ctx: protocol.RequestContext): protocol.MiddlewareResult + local next_ctx: protocol.RequestContext = ctx + if required_header then + local token = next_ctx.request.headers[required_header] + if not token then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing header: " .. required_header, + retryable = false, + }, + } + end + + if store then + local snapshot = store:lookup(token) + if not snapshot then + return { + ok = false, + error = { + code = "not_found", + message = name .. " missing session", + retryable = false, + }, + } + end + if required_scope and not snapshot.scopes[required_scope] then + return { + ok = false, + error = { + code = "invalid", + message = name .. " denied", + retryable = false, + }, + } + end + next_ctx.session = snapshot + end + end + + if local_tag_key and local_field then + local tags = next_ctx.request.meta.tags + if tags then + local tag = tags[local_tag_key] + if tag then + next_ctx.locals[local_field] = tag + end + end + end + + return {ok = true, value = next_ctx} + end +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/protocol.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/protocol.lua new file mode 100644 index 00000000..670ec999 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/protocol.lua @@ -0,0 +1,84 @@ +local time = require("time") +local result = require("result") + +type RequestMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type HttpRequest = { + kind: "http", + method: "GET" | "POST", + path: string, + headers: {[string]: string}, + params: {[string]: string}?, + body: string?, + meta: RequestMeta, +} + +type TimerRequest = { + kind: "timer", + id: string, + at: time.Time, + meta: RequestMeta, +} + +type Request = HttpRequest | TimerRequest + +type SessionSnapshot = { + id: string, + user_id: string, + scopes: {[string]: boolean}, + last_seen: time.Time?, + attributes: {[string]: string}?, +} + +type RequestContext = { + request: HttpRequest, + params: {[string]: string}, + locals: {[string]: string}, + session: SessionSnapshot?, +} + +type Response = { + status: integer, + body: string, + headers: {[string]: string}, +} + +type AppError = result.AppError +type ResponseResult = {ok: true, value: Response} | {ok: false, error: AppError} +type MiddlewareResult = {ok: true, value: RequestContext} | {ok: false, error: AppError} +type Middleware = (RequestContext) -> MiddlewareResult +type RouteHandler = (RequestContext) -> ResponseResult +type AfterHook = (RequestContext, Response) -> () + +type Route = { + key: string, + middlewares: {Middleware}, + handle: RouteHandler, +} + +local M = {} +M.RequestMeta = RequestMeta +M.HttpRequest = HttpRequest +M.TimerRequest = TimerRequest +M.Request = Request +M.SessionSnapshot = SessionSnapshot +M.RequestContext = RequestContext +M.Response = Response +M.ResponseResult = ResponseResult +M.MiddlewareResult = MiddlewareResult +M.Middleware = Middleware +M.RouteHandler = RouteHandler +M.AfterHook = AfterHook +M.Route = Route + +function M.meta(trace_id: string, tags: {[string]: string}?): RequestMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/result.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/result.lua new file mode 100644 index 00000000..25095ae7 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "rate_limited" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/route_builder.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/route_builder.lua new file mode 100644 index 00000000..4dd5fee0 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/route_builder.lua @@ -0,0 +1,128 @@ +local result = require("result") +local protocol = require("protocol") + +type BodyDecorator = (string, protocol.RequestContext) -> string + +type RouteBuilder = { + route_key: string, + middlewares: {protocol.Middleware}, + required_param: string?, + decorator: BodyDecorator?, + handler: protocol.RouteHandler, + key: (self: RouteBuilder, route_key: string) -> RouteBuilder, + use: (self: RouteBuilder, middleware: protocol.Middleware) -> RouteBuilder, + require_param: (self: RouteBuilder, param: string) -> RouteBuilder, + decorate_body: (self: RouteBuilder, decorator: BodyDecorator) -> RouteBuilder, + handle: (self: RouteBuilder, handler: protocol.RouteHandler) -> RouteBuilder, + build: (self: RouteBuilder) -> protocol.Route, +} + +type Builder = RouteBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +local function missing_handler(_ctx: protocol.RequestContext): protocol.ResponseResult + return { + ok = false, + error = { + code = "invalid", + message = "missing handler", + retryable = false, + }, + } +end + +function M.new(): RouteBuilder + local self: Builder = { + route_key = "GET /", + middlewares = {}, + required_param = nil, + decorator = nil, + handler = missing_handler, + key = Builder.key, + use = Builder.use, + require_param = Builder.require_param, + decorate_body = Builder.decorate_body, + handle = Builder.handle, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:key(route_key: string): Builder + self.route_key = route_key + return self +end + +function Builder:use(middleware: protocol.Middleware): Builder + table.insert(self.middlewares, middleware) + return self +end + +function Builder:require_param(param: string): Builder + self.required_param = param + return self +end + +function Builder:decorate_body(decorator: BodyDecorator): Builder + self.decorator = decorator + return self +end + +function Builder:handle(handler: protocol.RouteHandler): Builder + self.handler = handler + return self +end + +function Builder:build(): protocol.Route + local route_key = self.route_key + local middlewares = self.middlewares + local required_param = self.required_param + local decorator = self.decorator + local handler = self.handler + + return { + key = route_key, + middlewares = middlewares, + handle = function(ctx: protocol.RequestContext): protocol.ResponseResult + if required_param then + local raw = ctx.params[required_param] + if not raw then + return { + ok = false, + error = { + code = "invalid", + message = "missing param: " .. required_param, + retryable = false, + }, + } + end + end + + local response_result = handler(ctx) + if not response_result.ok then + return response_result + end + + if decorator then + local response = response_result.value + return { + ok = true, + value = { + status = response.status, + body = decorator(response.body, ctx), + headers = response.headers, + }, + } + end + + return response_result + end, + } +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/router.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/router.lua new file mode 100644 index 00000000..5289df7e --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/router.lua @@ -0,0 +1,110 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") + +type Router = { + routes: {[string]: protocol.Route}, + global_middlewares: {protocol.Middleware}, + hooks: {protocol.AfterHook}, + register_route: (self: Router, route: protocol.Route) -> Router, + use: (self: Router, middleware: protocol.Middleware) -> Router, + on_response: (self: Router, hook: protocol.AfterHook) -> Router, + dispatch: (self: Router, request: protocol.Request, now: time.Time) -> protocol.ResponseResult, +} + +type RuntimeRouter = Router + +local RuntimeRouter = {} +RuntimeRouter.__index = RuntimeRouter + +local M = {} + +function M.new(): Router + local self: RuntimeRouter = { + routes = {}, + global_middlewares = {}, + hooks = {}, + register_route = RuntimeRouter.register_route, + use = RuntimeRouter.use, + on_response = RuntimeRouter.on_response, + dispatch = RuntimeRouter.dispatch, + } + setmetatable(self, RuntimeRouter) + return self +end + +function RuntimeRouter:register_route(route: protocol.Route): RuntimeRouter + self.routes[route.key] = route + return self +end + +function RuntimeRouter:use(middleware: protocol.Middleware): RuntimeRouter + table.insert(self.global_middlewares, middleware) + return self +end + +function RuntimeRouter:on_response(hook: protocol.AfterHook): RuntimeRouter + table.insert(self.hooks, hook) + return self +end + +function RuntimeRouter:dispatch(request: protocol.Request, now: time.Time): protocol.ResponseResult + if request.kind == "timer" then + return { + ok = true, + value = { + status = 202, + body = "timer:" .. request.id .. ":" .. tostring(request.at:unix()), + headers = {["x-trace"] = request.meta.trace_id}, + }, + } + end + + local route = self.routes[request.method .. " " .. request.path] + if not route then + return { + ok = false, + error = { + code = "not_found", + message = "missing route: " .. request.path, + retryable = false, + }, + } + end + local route_value: protocol.Route = route + + local ctx: protocol.RequestContext = { + request = request, + params = request.params or {}, + locals = {}, + session = nil, + } + + local current_ctx: protocol.RequestContext = ctx + for _, middleware in ipairs(self.global_middlewares) do + local step = middleware(current_ctx) + if not step.ok then + return {ok = false, error = step.error} + end + current_ctx = step.value + end + for _, middleware in ipairs(route_value.middlewares) do + local step = middleware(current_ctx) + if not step.ok then + return {ok = false, error = step.error} + end + current_ctx = step.value + end + + local final_ctx = current_ctx + local response_result = route_value.handle(final_ctx) + if response_result.ok then + final_ctx.locals["handled_at"] = tostring(now:unix()) + for _, hook in ipairs(self.hooks) do + hook(final_ctx, response_result.value) + end + end + return response_result +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router-soundness/session_store.lua b/testdata/fixtures/realworld/middleware-session-router-soundness/session_store.lua new file mode 100644 index 00000000..7a509d83 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router-soundness/session_store.lua @@ -0,0 +1,45 @@ +local time = require("time") +local protocol = require("protocol") + +type SessionStore = { + by_token: {[string]: protocol.SessionSnapshot}, + lookup: (self: SessionStore, token: string) -> protocol.SessionSnapshot?, + touch: (self: SessionStore, token: string, at: time.Time) -> (), + save: (self: SessionStore, token: string, snapshot: protocol.SessionSnapshot) -> (), +} + +type Store = SessionStore + +local Store = {} +Store.__index = Store + +local M = {} +M.SessionStore = SessionStore + +function M.new(): SessionStore + local self: Store = { + by_token = {}, + lookup = Store.lookup, + touch = Store.touch, + save = Store.save, + } + setmetatable(self, Store) + return self +end + +function Store:lookup(token: string): protocol.SessionSnapshot? + return self.by_token[token] +end + +function Store:touch(token: string, at: time.Time) + local snapshot = self.by_token[token] + if snapshot then + snapshot.last_seen = at + end +end + +function Store:save(token: string, snapshot: protocol.SessionSnapshot) + self.by_token[token] = snapshot +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router/main.lua b/testdata/fixtures/realworld/middleware-session-router/main.lua new file mode 100644 index 00000000..b6ae57b3 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/main.lua @@ -0,0 +1,209 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local session_store = require("session_store") +local middleware_builder = require("middleware_builder") +local route_builder = require("route_builder") +local router = require("router") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() +local store = session_store.new() + +store:save("token-1", { + id = "s-1", + user_id = "u-1", + scopes = {["chat.read"] = true}, + last_seen = nil, + attributes = {role = "owner"}, +}) + +store:save("token-2", { + id = "s-2", + user_id = "u-2", + scopes = {["chat.read"] = false}, + last_seen = now, + attributes = nil, +}) + +local observed_users: {string} = {} +local observed_sources: {[string]: string} = {} +local last_timer_body: string? = nil + +local auth = middleware_builder.new() + :named("auth") + :require_header("authorization") + :load_sessions_from(store) + :require_scope("chat.read") + :copy_tag_to_local("source", "source") + :build() + +local trace = middleware_builder.new() + :named("trace") + :copy_tag_to_local("source", "source") + :build() + +local rooms_route = route_builder.new() + :key("GET /rooms/show") + :use(auth) + :require_param("room_id") + :decorate_body(function(body: string, ctx: protocol.RequestContext): string + local source = ctx.locals["source"] + if source then + return body .. ":" .. source + end + return body + end) + :handle(function(ctx: protocol.RequestContext): protocol.ResponseResult + local room_id = ctx.params["room_id"] + if not room_id then + return { + ok = false, + error = { + code = "invalid", + message = "missing room_id", + retryable = false, + }, + } + end + + local user_id = "guest" + local freshness = "cold" + if ctx.session then + user_id = ctx.session.user_id + if ctx.session.last_seen then + freshness = "warm" + end + end + + return { + ok = true, + value = { + status = 200, + body = "room:" .. room_id .. ":" .. user_id .. ":" .. freshness, + headers = {["x-user"] = user_id}, + }, + } + end) + :build() + +local health_route = route_builder.new() + :key("GET /health") + :handle(function(ctx: protocol.RequestContext): protocol.ResponseResult + local trace_id = ctx.request.meta.trace_id + return { + ok = true, + value = { + status = 204, + body = "health:" .. trace_id, + headers = {["x-trace"] = trace_id}, + }, + } + end) + :build() + +local app = router.new() + :use(trace) + :register_route(rooms_route) + :register_route(health_route) + +app:on_response(function(ctx: protocol.RequestContext, response: protocol.Response) + local user = response.headers["x-user"] + if user then + table.insert(observed_users, user) + end + + local source = ctx.locals["source"] + if source then + observed_sources[ctx.request.path] = source + end +end) + +local room_request: protocol.HttpRequest = { + kind = "http", + method = "GET", + path = "/rooms/show", + headers = {authorization = "token-1"}, + params = {room_id = "room-1"}, + body = nil, + meta = protocol.meta("trace-1", {source = "api"}), +} + +local health_request: protocol.HttpRequest = { + kind = "http", + method = "GET", + path = "/health", + headers = {}, + params = nil, + body = nil, + meta = protocol.meta("trace-2", {source = "probe"}), +} + +local timer_request: protocol.TimerRequest = { + kind = "timer", + id = "tick-1", + at = now, + meta = protocol.meta("trace-3", nil), +} + +local room_response = app:dispatch(room_request, now) +if room_response.ok then + local status: integer = room_response.value.status + local body: string = room_response.value.body + local maybe_user = room_response.value.headers["x-user"] + if maybe_user then + local stable_user: string = maybe_user + end +end + +local room_label = result.and_then(room_response, function(response: protocol.Response): StringResult + return result.ok(response.body) +end) + +if room_label.ok then + local label: string = room_label.value +end + +local health_response = app:dispatch(health_request, now) +if health_response.ok then + local health_body: string = health_response.value.body + local trace_header = health_response.value.headers["x-trace"] + if trace_header then + local stable_trace: string = trace_header + end +end + +local timer_response = app:dispatch(timer_request, now) +if timer_response.ok then + last_timer_body = timer_response.value.body +end + +if last_timer_body ~= nil then + local stable_body: string = last_timer_body +end + +for _, user in ipairs(observed_users) do + local stable_user: string = user +end + +for path, source in pairs(observed_sources) do + local stable_path: string = path + local stable_source: string = source +end + +store:touch("token-1", now) +local snapshot = store:lookup("token-1") +if snapshot then + local seen = snapshot.last_seen or now + local elapsed = now:sub(seen) + local seconds: number = elapsed:seconds() + + local attrs = snapshot.attributes + if attrs then + local role = attrs["role"] + if role then + local stable_role: string = role + end + end +end diff --git a/testdata/fixtures/realworld/middleware-session-router/manifest.json b/testdata/fixtures/realworld/middleware-session-router/manifest.json new file mode 100644 index 00000000..5cf4e8fe --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/manifest.json @@ -0,0 +1,16 @@ +{ + "description": "Application-shaped router mixing fluent middleware and route builders, nested closures, metatable-backed router/store objects, dynamic registries, optional session state, maps, callbacks, and time-based dispatch.", + "files": [ + "result.lua", + "protocol.lua", + "session_store.lua", + "middleware_builder.lua", + "route_builder.lua", + "router.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/middleware-session-router/middleware_builder.lua b/testdata/fixtures/realworld/middleware-session-router/middleware_builder.lua new file mode 100644 index 00000000..784ee880 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/middleware_builder.lua @@ -0,0 +1,135 @@ +local result = require("result") +local protocol = require("protocol") +local session_store = require("session_store") + +type MiddlewareBuilder = { + name: string, + required_header: string?, + local_tag_key: string?, + local_field: string?, + store: session_store.SessionStore?, + required_scope: string?, + named: (self: MiddlewareBuilder, name: string) -> MiddlewareBuilder, + require_header: (self: MiddlewareBuilder, header: string) -> MiddlewareBuilder, + copy_tag_to_local: (self: MiddlewareBuilder, tag_key: string, local_field: string) -> MiddlewareBuilder, + load_sessions_from: (self: MiddlewareBuilder, store: session_store.SessionStore) -> MiddlewareBuilder, + require_scope: (self: MiddlewareBuilder, scope: string) -> MiddlewareBuilder, + build: (self: MiddlewareBuilder) -> protocol.Middleware, +} + +type Builder = MiddlewareBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): MiddlewareBuilder + local self: Builder = { + name = "middleware", + required_header = nil, + local_tag_key = nil, + local_field = nil, + store = nil, + required_scope = nil, + named = Builder.named, + require_header = Builder.require_header, + copy_tag_to_local = Builder.copy_tag_to_local, + load_sessions_from = Builder.load_sessions_from, + require_scope = Builder.require_scope, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_header(header: string): Builder + self.required_header = header + return self +end + +function Builder:copy_tag_to_local(tag_key: string, local_field: string): Builder + self.local_tag_key = tag_key + self.local_field = local_field + return self +end + +function Builder:load_sessions_from(store: session_store.SessionStore): Builder + self.store = store + return self +end + +function Builder:require_scope(scope: string): Builder + self.required_scope = scope + return self +end + +function Builder:build(): protocol.Middleware + local name = self.name + local required_header = self.required_header + local local_tag_key = self.local_tag_key + local local_field = self.local_field + local store = self.store + local required_scope = self.required_scope + + return function(ctx: protocol.RequestContext): protocol.MiddlewareResult + local next_ctx: protocol.RequestContext = ctx + if required_header then + local token = next_ctx.request.headers[required_header] + if not token then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing header: " .. required_header, + retryable = false, + }, + } + end + + if store then + local snapshot = store:lookup(token) + if not snapshot then + return { + ok = false, + error = { + code = "not_found", + message = name .. " missing session", + retryable = false, + }, + } + end + if required_scope and not snapshot.scopes[required_scope] then + return { + ok = false, + error = { + code = "invalid", + message = name .. " denied", + retryable = false, + }, + } + end + next_ctx.session = snapshot + end + end + + if local_tag_key and local_field then + local tags = next_ctx.request.meta.tags + if tags then + local tag = tags[local_tag_key] + if tag then + next_ctx.locals[local_field] = tag + end + end + end + + return {ok = true, value = next_ctx} + end +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router/protocol.lua b/testdata/fixtures/realworld/middleware-session-router/protocol.lua new file mode 100644 index 00000000..670ec999 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/protocol.lua @@ -0,0 +1,84 @@ +local time = require("time") +local result = require("result") + +type RequestMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type HttpRequest = { + kind: "http", + method: "GET" | "POST", + path: string, + headers: {[string]: string}, + params: {[string]: string}?, + body: string?, + meta: RequestMeta, +} + +type TimerRequest = { + kind: "timer", + id: string, + at: time.Time, + meta: RequestMeta, +} + +type Request = HttpRequest | TimerRequest + +type SessionSnapshot = { + id: string, + user_id: string, + scopes: {[string]: boolean}, + last_seen: time.Time?, + attributes: {[string]: string}?, +} + +type RequestContext = { + request: HttpRequest, + params: {[string]: string}, + locals: {[string]: string}, + session: SessionSnapshot?, +} + +type Response = { + status: integer, + body: string, + headers: {[string]: string}, +} + +type AppError = result.AppError +type ResponseResult = {ok: true, value: Response} | {ok: false, error: AppError} +type MiddlewareResult = {ok: true, value: RequestContext} | {ok: false, error: AppError} +type Middleware = (RequestContext) -> MiddlewareResult +type RouteHandler = (RequestContext) -> ResponseResult +type AfterHook = (RequestContext, Response) -> () + +type Route = { + key: string, + middlewares: {Middleware}, + handle: RouteHandler, +} + +local M = {} +M.RequestMeta = RequestMeta +M.HttpRequest = HttpRequest +M.TimerRequest = TimerRequest +M.Request = Request +M.SessionSnapshot = SessionSnapshot +M.RequestContext = RequestContext +M.Response = Response +M.ResponseResult = ResponseResult +M.MiddlewareResult = MiddlewareResult +M.Middleware = Middleware +M.RouteHandler = RouteHandler +M.AfterHook = AfterHook +M.Route = Route + +function M.meta(trace_id: string, tags: {[string]: string}?): RequestMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router/result.lua b/testdata/fixtures/realworld/middleware-session-router/result.lua new file mode 100644 index 00000000..25095ae7 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "rate_limited" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router/route_builder.lua b/testdata/fixtures/realworld/middleware-session-router/route_builder.lua new file mode 100644 index 00000000..4dd5fee0 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/route_builder.lua @@ -0,0 +1,128 @@ +local result = require("result") +local protocol = require("protocol") + +type BodyDecorator = (string, protocol.RequestContext) -> string + +type RouteBuilder = { + route_key: string, + middlewares: {protocol.Middleware}, + required_param: string?, + decorator: BodyDecorator?, + handler: protocol.RouteHandler, + key: (self: RouteBuilder, route_key: string) -> RouteBuilder, + use: (self: RouteBuilder, middleware: protocol.Middleware) -> RouteBuilder, + require_param: (self: RouteBuilder, param: string) -> RouteBuilder, + decorate_body: (self: RouteBuilder, decorator: BodyDecorator) -> RouteBuilder, + handle: (self: RouteBuilder, handler: protocol.RouteHandler) -> RouteBuilder, + build: (self: RouteBuilder) -> protocol.Route, +} + +type Builder = RouteBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +local function missing_handler(_ctx: protocol.RequestContext): protocol.ResponseResult + return { + ok = false, + error = { + code = "invalid", + message = "missing handler", + retryable = false, + }, + } +end + +function M.new(): RouteBuilder + local self: Builder = { + route_key = "GET /", + middlewares = {}, + required_param = nil, + decorator = nil, + handler = missing_handler, + key = Builder.key, + use = Builder.use, + require_param = Builder.require_param, + decorate_body = Builder.decorate_body, + handle = Builder.handle, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:key(route_key: string): Builder + self.route_key = route_key + return self +end + +function Builder:use(middleware: protocol.Middleware): Builder + table.insert(self.middlewares, middleware) + return self +end + +function Builder:require_param(param: string): Builder + self.required_param = param + return self +end + +function Builder:decorate_body(decorator: BodyDecorator): Builder + self.decorator = decorator + return self +end + +function Builder:handle(handler: protocol.RouteHandler): Builder + self.handler = handler + return self +end + +function Builder:build(): protocol.Route + local route_key = self.route_key + local middlewares = self.middlewares + local required_param = self.required_param + local decorator = self.decorator + local handler = self.handler + + return { + key = route_key, + middlewares = middlewares, + handle = function(ctx: protocol.RequestContext): protocol.ResponseResult + if required_param then + local raw = ctx.params[required_param] + if not raw then + return { + ok = false, + error = { + code = "invalid", + message = "missing param: " .. required_param, + retryable = false, + }, + } + end + end + + local response_result = handler(ctx) + if not response_result.ok then + return response_result + end + + if decorator then + local response = response_result.value + return { + ok = true, + value = { + status = response.status, + body = decorator(response.body, ctx), + headers = response.headers, + }, + } + end + + return response_result + end, + } +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router/router.lua b/testdata/fixtures/realworld/middleware-session-router/router.lua new file mode 100644 index 00000000..5289df7e --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/router.lua @@ -0,0 +1,110 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") + +type Router = { + routes: {[string]: protocol.Route}, + global_middlewares: {protocol.Middleware}, + hooks: {protocol.AfterHook}, + register_route: (self: Router, route: protocol.Route) -> Router, + use: (self: Router, middleware: protocol.Middleware) -> Router, + on_response: (self: Router, hook: protocol.AfterHook) -> Router, + dispatch: (self: Router, request: protocol.Request, now: time.Time) -> protocol.ResponseResult, +} + +type RuntimeRouter = Router + +local RuntimeRouter = {} +RuntimeRouter.__index = RuntimeRouter + +local M = {} + +function M.new(): Router + local self: RuntimeRouter = { + routes = {}, + global_middlewares = {}, + hooks = {}, + register_route = RuntimeRouter.register_route, + use = RuntimeRouter.use, + on_response = RuntimeRouter.on_response, + dispatch = RuntimeRouter.dispatch, + } + setmetatable(self, RuntimeRouter) + return self +end + +function RuntimeRouter:register_route(route: protocol.Route): RuntimeRouter + self.routes[route.key] = route + return self +end + +function RuntimeRouter:use(middleware: protocol.Middleware): RuntimeRouter + table.insert(self.global_middlewares, middleware) + return self +end + +function RuntimeRouter:on_response(hook: protocol.AfterHook): RuntimeRouter + table.insert(self.hooks, hook) + return self +end + +function RuntimeRouter:dispatch(request: protocol.Request, now: time.Time): protocol.ResponseResult + if request.kind == "timer" then + return { + ok = true, + value = { + status = 202, + body = "timer:" .. request.id .. ":" .. tostring(request.at:unix()), + headers = {["x-trace"] = request.meta.trace_id}, + }, + } + end + + local route = self.routes[request.method .. " " .. request.path] + if not route then + return { + ok = false, + error = { + code = "not_found", + message = "missing route: " .. request.path, + retryable = false, + }, + } + end + local route_value: protocol.Route = route + + local ctx: protocol.RequestContext = { + request = request, + params = request.params or {}, + locals = {}, + session = nil, + } + + local current_ctx: protocol.RequestContext = ctx + for _, middleware in ipairs(self.global_middlewares) do + local step = middleware(current_ctx) + if not step.ok then + return {ok = false, error = step.error} + end + current_ctx = step.value + end + for _, middleware in ipairs(route_value.middlewares) do + local step = middleware(current_ctx) + if not step.ok then + return {ok = false, error = step.error} + end + current_ctx = step.value + end + + local final_ctx = current_ctx + local response_result = route_value.handle(final_ctx) + if response_result.ok then + final_ctx.locals["handled_at"] = tostring(now:unix()) + for _, hook in ipairs(self.hooks) do + hook(final_ctx, response_result.value) + end + end + return response_result +end + +return M diff --git a/testdata/fixtures/realworld/middleware-session-router/session_store.lua b/testdata/fixtures/realworld/middleware-session-router/session_store.lua new file mode 100644 index 00000000..7a509d83 --- /dev/null +++ b/testdata/fixtures/realworld/middleware-session-router/session_store.lua @@ -0,0 +1,45 @@ +local time = require("time") +local protocol = require("protocol") + +type SessionStore = { + by_token: {[string]: protocol.SessionSnapshot}, + lookup: (self: SessionStore, token: string) -> protocol.SessionSnapshot?, + touch: (self: SessionStore, token: string, at: time.Time) -> (), + save: (self: SessionStore, token: string, snapshot: protocol.SessionSnapshot) -> (), +} + +type Store = SessionStore + +local Store = {} +Store.__index = Store + +local M = {} +M.SessionStore = SessionStore + +function M.new(): SessionStore + local self: Store = { + by_token = {}, + lookup = Store.lookup, + touch = Store.touch, + save = Store.save, + } + setmetatable(self, Store) + return self +end + +function Store:lookup(token: string): protocol.SessionSnapshot? + return self.by_token[token] +end + +function Store:touch(token: string, at: time.Time) + local snapshot = self.by_token[token] + if snapshot then + snapshot.last_seen = at + end +end + +function Store:save(token: string, snapshot: protocol.SessionSnapshot) + self.by_token[token] = snapshot +end + +return M diff --git a/testdata/fixtures/realworld/multi-return-error-chain/process.lua b/testdata/fixtures/realworld/multi-return-error-chain/process.lua index b46e1422..36f8500e 100644 --- a/testdata/fixtures/realworld/multi-return-error-chain/process.lua +++ b/testdata/fixtures/realworld/multi-return-error-chain/process.lua @@ -2,7 +2,7 @@ local validate = require("validate") type ProcessResult = { message: string, - config: ValidConfig, + config: validate.ValidConfig, } local M = {} diff --git a/testdata/fixtures/realworld/multi-return-error-chain/validate.lua b/testdata/fixtures/realworld/multi-return-error-chain/validate.lua index 96294737..bba3d2ee 100644 --- a/testdata/fixtures/realworld/multi-return-error-chain/validate.lua +++ b/testdata/fixtures/realworld/multi-return-error-chain/validate.lua @@ -10,7 +10,7 @@ type ValidConfig = { local M = {} M.ValidConfig = ValidConfig -function M.validate(config: ParsedConfig): (ValidConfig?, string?) +function M.validate(config: parse.ParsedConfig): (ValidConfig?, string?) if #config.host == 0 then return nil, "host is empty" end diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/delivery_store.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/delivery_store.lua new file mode 100644 index 00000000..d474e648 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/delivery_store.lua @@ -0,0 +1,117 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type DeliveryStore = { + state: protocol.StoreState, + lookup_record: (self: DeliveryStore, message_id: string) -> protocol.DeliveryRecord?, + lookup_receipt: (self: DeliveryStore, provider_id: string) -> protocol.DeliveryReceipt?, + push_step: (self: DeliveryStore, step: protocol.DeliveryStep, at: time.Time) -> (), + record_receipt: (self: DeliveryStore, request: protocol.Request, receipt: protocol.DeliveryReceipt) -> (), + summarize: (self: DeliveryStore, now: time.Time, last_status: string?) -> protocol.RunSummary, +} + +type Store = DeliveryStore + +local Store = {} +Store.__index = Store + +local M = {} +M.DeliveryStore = DeliveryStore + +function M.new(id: string, now: time.Time): DeliveryStore + local self: Store = { + state = { + id = id, + started_at = now, + last_delivery_at = nil, + records = {}, + cached_receipts = {}, + steps = {}, + counters = {}, + flags = {}, + }, + lookup_record = Store.lookup_record, + lookup_receipt = Store.lookup_receipt, + push_step = Store.push_step, + record_receipt = Store.record_receipt, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:lookup_record(message_id: string): protocol.DeliveryRecord? + return self.state.records[message_id] +end + +function Store:lookup_receipt(provider_id: string): protocol.DeliveryReceipt? + return self.state.cached_receipts[provider_id] +end + +function Store:push_step(step: protocol.DeliveryStep, at: time.Time) + table.insert(self.state.steps, step) + self.state.last_delivery_at = at +end + +function Store:record_receipt(request: protocol.Request, receipt: protocol.DeliveryReceipt) + if request.kind == "tick" then + return + end + + local source = helpers.tag_value(request, "source") + local priority = helpers.tag_value(request, "priority") + + local attempts: {[string]: integer} = {} + local current = self.state.records[request.message_id] + if current then + attempts = current.attempts + end + + local existing = attempts[receipt.channel] + if existing then + attempts[receipt.channel] = existing + 1 + else + attempts[receipt.channel] = 1 + end + + local record: protocol.DeliveryRecord = { + tenant_id = request.tenant_id, + message_id = request.message_id, + channel = receipt.channel, + last_status = receipt.local_status, + attempts = attempts, + rendered_preview = receipt.preview, + last_receipt = receipt, + updated_at = receipt.delivered_at, + source = source, + priority = priority, + last_error = nil, + } + + self.state.records[request.message_id] = record + self.state.cached_receipts[receipt.provider_id] = receipt + + if receipt.counter_key then + helpers.bump_counter(self.state.counters, receipt.counter_key) + end + helpers.bump_counter(self.state.counters, receipt.local_status) +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RunSummary + local sent_count = self.state.counters["sent"] or 0 + local queued_count = self.state.counters["queued"] or 0 + local retrying_count = self.state.counters["retrying"] or 0 + + return { + id = self.state.id, + total_processed = #self.state.steps, + sent_count = sent_count, + queued_count = queued_count, + retrying_count = retrying_count, + elapsed_seconds = now:sub(self.state.started_at), + last_status = last_status, + } +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/helpers.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/helpers.lua new file mode 100644 index 00000000..2aa84f51 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/helpers.lua @@ -0,0 +1,43 @@ +local protocol = require("protocol") + +local M = {} + +function M.request_label(request: protocol.Request): string + if request.kind == "email" then + return "email:" .. request.subject + end + if request.kind == "sms" then + return "sms:" .. request.phone + end + if request.kind == "webhook" then + return "webhook:" .. request.endpoint + end + return "tick" +end + +function M.tag_value(request: protocol.Request, key: string): string? + if request.kind == "tick" then + return nil + end + + local tags = request.meta.tags + if not tags then + return nil + end + return tags[key] +end + +function M.receipt_note(receipt: protocol.DeliveryReceipt): string + return receipt.channel .. ":" .. receipt.local_status .. ":" .. receipt.provider_id +end + +function M.bump_counter(counters: {[string]: integer}, key: string) + local current = counters[key] + if current then + counters[key] = current + 1 + return + end + counters[key] = 1 +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/main.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/main.lua new file mode 100644 index 00000000..2155edb5 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/main.lua @@ -0,0 +1,53 @@ +local time = require("time") +local protocol = require("protocol") +local template_builder = require("template_builder") +local transport_builder = require("transport_builder") +local runtime = require("runtime") + +local now = time.now() + +local email_renderer = template_builder.new() + :named("email") + :prefix_with("subject") + :require_tag("source") + :build() + +local email_transport = transport_builder.new() + :for_channel("email") + :use_renderer(email_renderer) + :count_as("mailops") + :require_tag("source") + :build() + +local app = runtime.new():register_transport("email", email_transport) + +local email_one: protocol.EmailRequest = { + kind = "email", + tenant_id = "tenant-a", + message_id = "msg-1", + recipient = "alice@example.com", + subject = "welcome", + template = "welcome-email", + meta = protocol.meta("trace-1", nil), +} + +local tick: protocol.TickRequest = { + kind = "tick", + at = now, +} + +local store = app:new_store("delivery-soundness", now) +local run_result = app:run(store, {tick}, now) + +if run_result.ok then + local last_status: string = run_result.value.last_status -- expect-error +end + +local record: protocol.DeliveryRecord = store.state.records["missing"] -- expect-error +local receipt: protocol.DeliveryReceipt = store.state.cached_receipts["missing"] -- expect-error +local handler: protocol.TransportHandler = app.transports["email"] -- expect-error + +local elapsed = now:sub(store.state.last_delivery_at) -- expect-error +local source: string = email_one.meta.tags["source"] -- expect-error + +local missing_status: string = store:lookup_record("msg-1").last_status -- expect-error diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/manifest.json b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/manifest.json new file mode 100644 index 00000000..5cc9ed38 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Soundness counterpart for notification delivery runtime, pinning unsafe optional cache, tag, registry, and time uses in the same multi-module application shape.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "delivery_store.lua", + "template_builder.lua", + "transport_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 6 + } +} diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/protocol.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/protocol.lua new file mode 100644 index 00000000..c7f9b177 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/protocol.lua @@ -0,0 +1,158 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type DeliveryMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type EmailRequest = { + kind: "email", + tenant_id: string, + message_id: string, + recipient: string, + subject: string, + template: string, + meta: DeliveryMeta, +} + +type SmsRequest = { + kind: "sms", + tenant_id: string, + message_id: string, + phone: string, + template: string, + meta: DeliveryMeta, +} + +type WebhookRequest = { + kind: "webhook", + tenant_id: string, + message_id: string, + endpoint: string, + template: string, + meta: DeliveryMeta, +} + +type TickRequest = { + kind: "tick", + at: time.Time, +} + +type Request = EmailRequest | SmsRequest | WebhookRequest | TickRequest + +type DeliveryReceipt = { + message_id: string, + tenant_id: string, + channel: "email" | "sms" | "webhook", + provider_id: string, + preview: string, + local_status: "sent" | "queued" | "retrying", + delivered_at: time.Time, + retry_after: time.Time?, + tags: {[string]: string}?, + counter_key: string?, +} + +type DeliveryRecord = { + tenant_id: string, + message_id: string, + channel: "email" | "sms" | "webhook", + last_status: "sent" | "queued" | "retrying", + attempts: {[string]: integer}, + rendered_preview: string?, + last_receipt: DeliveryReceipt?, + updated_at: time.Time?, + source: string?, + priority: string?, + last_error: string?, +} + +type DeliveryEventStep = { + kind: "delivery", + channel: "email" | "sms" | "webhook", + message_id: string, + note: string, + provider_id: string, +} + +type RetryStep = { + kind: "retry", + channel: "email" | "sms" | "webhook", + message_id: string, + note: string, + retry_at: time.Time, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type DeliveryStep = DeliveryEventStep | RetryStep | AuditStep + +type StoreState = { + id: string, + started_at: time.Time, + last_delivery_at: time.Time?, + records: {[string]: DeliveryRecord}, + cached_receipts: {[string]: DeliveryReceipt}, + steps: {DeliveryStep}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_processed: number, + sent_count: number, + queued_count: number, + retrying_count: number, + elapsed_seconds: time.Duration, + last_status: string?, +} + +type TemplateResult = {ok: true, value: string} | {ok: false, error: AppError} +type TransportResult = {ok: true, value: DeliveryReceipt} | {ok: false, error: AppError} +type DeliverResult = {ok: true, value: string?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type TemplateRenderer = (StoreState, Request) -> TemplateResult +type TransportHandler = (StoreState, Request, time.Time) -> TransportResult +type StepHook = (DeliveryStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.DeliveryMeta = DeliveryMeta +M.EmailRequest = EmailRequest +M.SmsRequest = SmsRequest +M.WebhookRequest = WebhookRequest +M.TickRequest = TickRequest +M.Request = Request +M.DeliveryReceipt = DeliveryReceipt +M.DeliveryRecord = DeliveryRecord +M.DeliveryEventStep = DeliveryEventStep +M.RetryStep = RetryStep +M.AuditStep = AuditStep +M.DeliveryStep = DeliveryStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.TemplateResult = TemplateResult +M.TransportResult = TransportResult +M.DeliverResult = DeliverResult +M.RunResult = RunResult +M.TemplateRenderer = TemplateRenderer +M.TransportHandler = TransportHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): DeliveryMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/result.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/runtime.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/runtime.lua new file mode 100644 index 00000000..8ea85a7b --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/runtime.lua @@ -0,0 +1,137 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local delivery_store = require("delivery_store") + +type DeliveryRuntime = { + transports: {[string]: protocol.TransportHandler}, + hooks: {protocol.StepHook}, + register_transport: (self: DeliveryRuntime, kind: string, handler: protocol.TransportHandler) -> DeliveryRuntime, + on_step: (self: DeliveryRuntime, hook: protocol.StepHook) -> DeliveryRuntime, + new_store: (self: DeliveryRuntime, id: string, now: time.Time) -> delivery_store.DeliveryStore, + emit: (self: DeliveryRuntime, store: delivery_store.DeliveryStore, step: protocol.DeliveryStep, at: time.Time) -> (), + deliver: (self: DeliveryRuntime, store: delivery_store.DeliveryStore, request: protocol.Request, at: time.Time) -> protocol.DeliverResult, + run: (self: DeliveryRuntime, store: delivery_store.DeliveryStore, requests: {protocol.Request}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = DeliveryRuntime + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.DeliveryRuntime = DeliveryRuntime + +function M.new(): DeliveryRuntime + local self: Runtime = { + transports = {}, + hooks = {}, + register_transport = Runtime.register_transport, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + deliver = Runtime.deliver, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:register_transport(kind: string, handler: protocol.TransportHandler): Runtime + self.transports[kind] = handler + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): delivery_store.DeliveryStore + return delivery_store.new(id, now) +end + +function Runtime:emit(store: delivery_store.DeliveryStore, step: protocol.DeliveryStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:deliver( + store: delivery_store.DeliveryStore, + request: protocol.Request, + at: time.Time +): protocol.DeliverResult + if request.kind == "tick" then + local audit_step: protocol.AuditStep = {kind = "audit", note = "tick", at = request.at} + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local handler = self.transports[request.kind] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing transport: " .. request.kind, + retryable = false, + }, + } + end + + local receipt_result = handler(store.state, request, at) + if not receipt_result.ok then + return {ok = false, error = receipt_result.error} + end + + local receipt = receipt_result.value + store:record_receipt(request, receipt) + + if receipt.local_status == "retrying" then + local retry_at = receipt.retry_after or at + local retry_step: protocol.RetryStep = { + kind = "retry", + channel = receipt.channel, + message_id = receipt.message_id, + note = helpers.receipt_note(receipt), + retry_at = retry_at, + } + self:emit(store, retry_step, at) + return {ok = true, value = "retrying"} + end + + local delivery_step: protocol.DeliveryEventStep = { + kind = "delivery", + channel = receipt.channel, + message_id = receipt.message_id, + note = helpers.receipt_note(receipt), + provider_id = receipt.provider_id, + } + self:emit(store, delivery_step, at) + return {ok = true, value = receipt.local_status} +end + +function Runtime:run( + store: delivery_store.DeliveryStore, + requests: {protocol.Request}, + now: time.Time +): protocol.RunResult + local last_status: string? = nil + + for _, request in ipairs(requests) do + local deliver_result: protocol.DeliverResult = self:deliver(store, request, now) + if not deliver_result.ok then + return {ok = false, error = deliver_result.error} + end + if deliver_result.value then + last_status = deliver_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/template_builder.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/template_builder.lua new file mode 100644 index 00000000..5ec9d064 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/template_builder.lua @@ -0,0 +1,126 @@ +local result = require("result") +local protocol = require("protocol") + +type TemplateDecorator = (string, protocol.Request) -> string + +type TemplateBuilder = { + name: string, + prefix: string?, + required_tag: string?, + suffix: string?, + decorator: TemplateDecorator?, + named: (self: TemplateBuilder, name: string) -> TemplateBuilder, + prefix_with: (self: TemplateBuilder, prefix: string) -> TemplateBuilder, + require_tag: (self: TemplateBuilder, tag: string) -> TemplateBuilder, + suffix_with: (self: TemplateBuilder, suffix: string) -> TemplateBuilder, + decorate: (self: TemplateBuilder, decorator: TemplateDecorator) -> TemplateBuilder, + build: (self: TemplateBuilder) -> protocol.TemplateRenderer, +} + +type Builder = TemplateBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): TemplateBuilder + local self: Builder = { + name = "template", + prefix = nil, + required_tag = nil, + suffix = nil, + decorator = nil, + named = Builder.named, + prefix_with = Builder.prefix_with, + require_tag = Builder.require_tag, + suffix_with = Builder.suffix_with, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:require_tag(tag: string): Builder + self.required_tag = tag + return self +end + +function Builder:suffix_with(suffix: string): Builder + self.suffix = suffix + return self +end + +function Builder:decorate(decorator: TemplateDecorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.TemplateRenderer + local name = self.name + local prefix = self.prefix + local required_tag = self.required_tag + local suffix = self.suffix + local decorator = self.decorator + + return function(_state: protocol.StoreState, request: protocol.Request): protocol.TemplateResult + if request.kind == "tick" then + return { + ok = false, + error = { + code = "invalid", + message = name .. " cannot render ticks", + retryable = false, + }, + } + end + + if required_tag then + local tags = request.meta.tags + if not tags or not tags[required_tag] then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + local body: string + if request.kind == "email" then + body = request.subject .. ":" .. request.recipient + elseif request.kind == "sms" then + body = request.phone .. ":" .. request.template + else + body = request.endpoint .. ":" .. request.template + end + + if prefix then + body = prefix .. ":" .. body + end + if suffix then + body = body .. ":" .. suffix + end + if decorator then + body = decorator(body, request) + end + + return {ok = true, value = body} + end +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime-soundness/transport_builder.lua b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/transport_builder.lua new file mode 100644 index 00000000..ca3e89a1 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime-soundness/transport_builder.lua @@ -0,0 +1,168 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") + +type PreviewDecorator = (string, protocol.Request) -> string + +type TransportBuilder = { + channel: "email" | "sms" | "webhook", + renderer: protocol.TemplateRenderer?, + counter_key: string?, + required_tag: string?, + preview_decorator: PreviewDecorator?, + for_channel: (self: TransportBuilder, channel: "email" | "sms" | "webhook") -> TransportBuilder, + use_renderer: (self: TransportBuilder, renderer: protocol.TemplateRenderer) -> TransportBuilder, + count_as: (self: TransportBuilder, key: string) -> TransportBuilder, + require_tag: (self: TransportBuilder, key: string) -> TransportBuilder, + decorate_preview: (self: TransportBuilder, decorator: PreviewDecorator) -> TransportBuilder, + build: (self: TransportBuilder) -> protocol.TransportHandler, +} + +type Builder = TransportBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): TransportBuilder + local self: Builder = { + channel = "email", + renderer = nil, + counter_key = nil, + required_tag = nil, + preview_decorator = nil, + for_channel = Builder.for_channel, + use_renderer = Builder.use_renderer, + count_as = Builder.count_as, + require_tag = Builder.require_tag, + decorate_preview = Builder.decorate_preview, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_channel(channel: "email" | "sms" | "webhook"): Builder + self.channel = channel + return self +end + +function Builder:use_renderer(renderer: protocol.TemplateRenderer): Builder + self.renderer = renderer + return self +end + +function Builder:count_as(key: string): Builder + self.counter_key = key + return self +end + +function Builder:require_tag(key: string): Builder + self.required_tag = key + return self +end + +function Builder:decorate_preview(decorator: PreviewDecorator): Builder + self.preview_decorator = decorator + return self +end + +function Builder:build(): protocol.TransportHandler + local channel = self.channel + local renderer = self.renderer + local counter_key = self.counter_key + local required_tag = self.required_tag + local preview_decorator = self.preview_decorator + + return function(state: protocol.StoreState, request: protocol.Request, at: time.Time): protocol.TransportResult + if request.kind == "tick" then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " cannot deliver ticks", + retryable = false, + }, + } + end + + local active_renderer = renderer + if not active_renderer then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " missing renderer", + retryable = false, + }, + } + end + + if request.kind ~= channel then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " wrong request kind", + retryable = false, + }, + } + end + + if required_tag then + local tags = request.meta.tags + if not tags or not tags[required_tag] then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + local rendered = active_renderer(state, request) + if not rendered.ok then + return {ok = false, error = rendered.error} + end + + local preview = rendered.value + if preview_decorator then + preview = preview_decorator(preview, request) + end + + local status: "sent" | "queued" | "retrying" = "sent" + local tags = request.meta.tags + if request.kind == "webhook" then + status = "queued" + end + if tags and tags["retry"] == "true" then + status = "retrying" + end + + local retry_after: time.Time? = nil + if status == "retrying" then + retry_after = at + end + + local receipt: protocol.DeliveryReceipt = { + message_id = request.message_id, + tenant_id = request.tenant_id, + channel = channel, + provider_id = channel .. ":" .. request.message_id, + preview = preview, + local_status = status, + delivered_at = at, + retry_after = retry_after, + tags = tags, + counter_key = counter_key, + } + + return {ok = true, value = receipt} + end +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/delivery_store.lua b/testdata/fixtures/realworld/notification-delivery-runtime/delivery_store.lua new file mode 100644 index 00000000..d474e648 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/delivery_store.lua @@ -0,0 +1,117 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type DeliveryStore = { + state: protocol.StoreState, + lookup_record: (self: DeliveryStore, message_id: string) -> protocol.DeliveryRecord?, + lookup_receipt: (self: DeliveryStore, provider_id: string) -> protocol.DeliveryReceipt?, + push_step: (self: DeliveryStore, step: protocol.DeliveryStep, at: time.Time) -> (), + record_receipt: (self: DeliveryStore, request: protocol.Request, receipt: protocol.DeliveryReceipt) -> (), + summarize: (self: DeliveryStore, now: time.Time, last_status: string?) -> protocol.RunSummary, +} + +type Store = DeliveryStore + +local Store = {} +Store.__index = Store + +local M = {} +M.DeliveryStore = DeliveryStore + +function M.new(id: string, now: time.Time): DeliveryStore + local self: Store = { + state = { + id = id, + started_at = now, + last_delivery_at = nil, + records = {}, + cached_receipts = {}, + steps = {}, + counters = {}, + flags = {}, + }, + lookup_record = Store.lookup_record, + lookup_receipt = Store.lookup_receipt, + push_step = Store.push_step, + record_receipt = Store.record_receipt, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:lookup_record(message_id: string): protocol.DeliveryRecord? + return self.state.records[message_id] +end + +function Store:lookup_receipt(provider_id: string): protocol.DeliveryReceipt? + return self.state.cached_receipts[provider_id] +end + +function Store:push_step(step: protocol.DeliveryStep, at: time.Time) + table.insert(self.state.steps, step) + self.state.last_delivery_at = at +end + +function Store:record_receipt(request: protocol.Request, receipt: protocol.DeliveryReceipt) + if request.kind == "tick" then + return + end + + local source = helpers.tag_value(request, "source") + local priority = helpers.tag_value(request, "priority") + + local attempts: {[string]: integer} = {} + local current = self.state.records[request.message_id] + if current then + attempts = current.attempts + end + + local existing = attempts[receipt.channel] + if existing then + attempts[receipt.channel] = existing + 1 + else + attempts[receipt.channel] = 1 + end + + local record: protocol.DeliveryRecord = { + tenant_id = request.tenant_id, + message_id = request.message_id, + channel = receipt.channel, + last_status = receipt.local_status, + attempts = attempts, + rendered_preview = receipt.preview, + last_receipt = receipt, + updated_at = receipt.delivered_at, + source = source, + priority = priority, + last_error = nil, + } + + self.state.records[request.message_id] = record + self.state.cached_receipts[receipt.provider_id] = receipt + + if receipt.counter_key then + helpers.bump_counter(self.state.counters, receipt.counter_key) + end + helpers.bump_counter(self.state.counters, receipt.local_status) +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RunSummary + local sent_count = self.state.counters["sent"] or 0 + local queued_count = self.state.counters["queued"] or 0 + local retrying_count = self.state.counters["retrying"] or 0 + + return { + id = self.state.id, + total_processed = #self.state.steps, + sent_count = sent_count, + queued_count = queued_count, + retrying_count = retrying_count, + elapsed_seconds = now:sub(self.state.started_at), + last_status = last_status, + } +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/helpers.lua b/testdata/fixtures/realworld/notification-delivery-runtime/helpers.lua new file mode 100644 index 00000000..2aa84f51 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/helpers.lua @@ -0,0 +1,43 @@ +local protocol = require("protocol") + +local M = {} + +function M.request_label(request: protocol.Request): string + if request.kind == "email" then + return "email:" .. request.subject + end + if request.kind == "sms" then + return "sms:" .. request.phone + end + if request.kind == "webhook" then + return "webhook:" .. request.endpoint + end + return "tick" +end + +function M.tag_value(request: protocol.Request, key: string): string? + if request.kind == "tick" then + return nil + end + + local tags = request.meta.tags + if not tags then + return nil + end + return tags[key] +end + +function M.receipt_note(receipt: protocol.DeliveryReceipt): string + return receipt.channel .. ":" .. receipt.local_status .. ":" .. receipt.provider_id +end + +function M.bump_counter(counters: {[string]: integer}, key: string) + local current = counters[key] + if current then + counters[key] = current + 1 + return + end + counters[key] = 1 +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/main.lua b/testdata/fixtures/realworld/notification-delivery-runtime/main.lua new file mode 100644 index 00000000..df6910b9 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/main.lua @@ -0,0 +1,231 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local template_builder = require("template_builder") +local transport_builder = require("transport_builder") +local runtime = require("runtime") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_deliveries: {[string]: string} = {} +local observed_retries: {string} = {} +local observed_audits: {string} = {} +local last_runtime_id: string? = nil + +local email_renderer = template_builder.new() + :named("email") + :prefix_with("subject") + :require_tag("source") + :suffix_with("mail") + :decorate(function(body: string, request: protocol.Request): string + if request.kind == "email" then + return body .. ":" .. request.tenant_id + end + return body + end) + :build() + +local sms_renderer = template_builder.new() + :named("sms") + :prefix_with("sms") + :suffix_with("text") + :decorate(function(body: string, request: protocol.Request): string + if request.kind == "sms" then + return body .. ":" .. request.tenant_id + end + return body + end) + :build() + +local webhook_renderer = template_builder.new() + :named("webhook") + :prefix_with("json") + :require_tag("source") + :suffix_with("hook") + :decorate(function(body: string, request: protocol.Request): string + if request.kind == "webhook" then + return body .. ":" .. request.tenant_id + end + return body + end) + :build() + +local email_transport = transport_builder.new() + :for_channel("email") + :use_renderer(email_renderer) + :count_as("mailops") + :require_tag("source") + :decorate_preview(function(preview: string, request: protocol.Request): string + return preview .. ":" .. helpers.request_label(request) + end) + :build() + +local sms_transport = transport_builder.new() + :for_channel("sms") + :use_renderer(sms_renderer) + :count_as("smsops") + :decorate_preview(function(preview: string, request: protocol.Request): string + return preview .. ":" .. helpers.request_label(request) + end) + :build() + +local webhook_transport = transport_builder.new() + :for_channel("webhook") + :use_renderer(webhook_renderer) + :count_as("hookops") + :require_tag("source") + :decorate_preview(function(preview: string, request: protocol.Request): string + return preview .. ":" .. helpers.request_label(request) + end) + :build() + +local app = runtime.new() + :register_transport("email", email_transport) + :register_transport("sms", sms_transport) + :register_transport("webhook", webhook_transport) + +app:on_step(function(step: protocol.DeliveryStep, state: protocol.StoreState) + last_runtime_id = state.id + if step.kind == "delivery" then + observed_deliveries[step.message_id] = step.note + local provider_id: string = step.provider_id + elseif step.kind == "retry" then + table.insert(observed_retries, step.note) + local retry_seconds: integer = step.retry_at:unix() + else + table.insert(observed_audits, step.note) + local at_seconds: integer = step.at:unix() + end +end) + +local email_one: protocol.EmailRequest = { + kind = "email", + tenant_id = "tenant-a", + message_id = "msg-1", + recipient = "alice@example.com", + subject = "welcome", + template = "welcome-email", + meta = protocol.meta("trace-1", {source = "api", priority = "high"}), +} + +local sms_one: protocol.SmsRequest = { + kind = "sms", + tenant_id = "tenant-a", + message_id = "msg-2", + phone = "+155555501", + template = "otp", + meta = protocol.meta("trace-2", {source = "cron"}), +} + +local webhook_one: protocol.WebhookRequest = { + kind = "webhook", + tenant_id = "tenant-b", + message_id = "msg-3", + endpoint = "https://example.com/hook", + template = "sync", + meta = protocol.meta("trace-3", {source = "worker", retry = "true", priority = "low"}), +} + +local tick: protocol.TickRequest = { + kind = "tick", + at = now, +} + +local requests: {protocol.Request} = { + email_one, + sms_one, + webhook_one, + tick, +} + +local store = app:new_store("delivery-1", now) +local summary_result = app:run(store, requests, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local runtime_id: string = summary.id + local total_processed: number = summary.total_processed + local sent_count: number = summary.sent_count + local queued_count: number = summary.queued_count + local retrying_count: number = summary.retrying_count + local elapsed_seconds: time.Duration = summary.elapsed_seconds + local last_status: string? = summary.last_status +end + +local label_result = result.map(summary_result, function(summary: protocol.RunSummary): string + return summary.id .. ":" .. tostring(summary.sent_count + summary.retrying_count) +end) + +if label_result.ok then + local label: string = label_result.value +end + +local checked_result = result.and_then(summary_result, function(summary: protocol.RunSummary): StringResult + if summary.retrying_count == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected retry", + retryable = false, + }, + } + end + return {ok = true, value = summary.id} +end) + +if checked_result.ok then + local checked_id: string = checked_result.value +end + +local record_one = store:lookup_record("msg-1") +if record_one then + local tenant_id: string = record_one.tenant_id + local source: string? = record_one.source + local attempts = record_one.attempts["email"] + if attempts then + local email_attempts: integer = attempts + end + local receipt = record_one.last_receipt + if receipt then + local provider: string = receipt.provider_id + end +end + +local cached = store:lookup_receipt("webhook:msg-3") +if cached then + local provider_id: string = cached.provider_id + local retry_at = cached.retry_after + if retry_at then + local retry_seconds: integer = retry_at:unix() + end +end + +local sent_counter = store.state.counters["sent"] +if sent_counter then + local sent_value: integer = sent_counter +end + +local queue_counter = store.state.counters["queued"] +if queue_counter then + local queued_value: integer = queue_counter +end + +local retry_counter = store.state.counters["retrying"] +if retry_counter then + local retrying_value: integer = retry_counter +end + +local last_seen = store.state.last_delivery_at +if last_seen then + local last_seconds: integer = last_seen:unix() +end + +if last_runtime_id then + local runtime_id: string = last_runtime_id +end diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/manifest.json b/testdata/fixtures/realworld/notification-delivery-runtime/manifest.json new file mode 100644 index 00000000..78a9ca9a --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Application-shaped notification delivery runtime mixing closure-built renderers, fluent transport builders, metatable-backed store/runtime objects, discriminated requests, dynamic transport registries, optional cached receipts, callbacks, counters, tag maps, and time-based summaries.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "delivery_store.lua", + "template_builder.lua", + "transport_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/protocol.lua b/testdata/fixtures/realworld/notification-delivery-runtime/protocol.lua new file mode 100644 index 00000000..c7f9b177 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/protocol.lua @@ -0,0 +1,158 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type DeliveryMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type EmailRequest = { + kind: "email", + tenant_id: string, + message_id: string, + recipient: string, + subject: string, + template: string, + meta: DeliveryMeta, +} + +type SmsRequest = { + kind: "sms", + tenant_id: string, + message_id: string, + phone: string, + template: string, + meta: DeliveryMeta, +} + +type WebhookRequest = { + kind: "webhook", + tenant_id: string, + message_id: string, + endpoint: string, + template: string, + meta: DeliveryMeta, +} + +type TickRequest = { + kind: "tick", + at: time.Time, +} + +type Request = EmailRequest | SmsRequest | WebhookRequest | TickRequest + +type DeliveryReceipt = { + message_id: string, + tenant_id: string, + channel: "email" | "sms" | "webhook", + provider_id: string, + preview: string, + local_status: "sent" | "queued" | "retrying", + delivered_at: time.Time, + retry_after: time.Time?, + tags: {[string]: string}?, + counter_key: string?, +} + +type DeliveryRecord = { + tenant_id: string, + message_id: string, + channel: "email" | "sms" | "webhook", + last_status: "sent" | "queued" | "retrying", + attempts: {[string]: integer}, + rendered_preview: string?, + last_receipt: DeliveryReceipt?, + updated_at: time.Time?, + source: string?, + priority: string?, + last_error: string?, +} + +type DeliveryEventStep = { + kind: "delivery", + channel: "email" | "sms" | "webhook", + message_id: string, + note: string, + provider_id: string, +} + +type RetryStep = { + kind: "retry", + channel: "email" | "sms" | "webhook", + message_id: string, + note: string, + retry_at: time.Time, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type DeliveryStep = DeliveryEventStep | RetryStep | AuditStep + +type StoreState = { + id: string, + started_at: time.Time, + last_delivery_at: time.Time?, + records: {[string]: DeliveryRecord}, + cached_receipts: {[string]: DeliveryReceipt}, + steps: {DeliveryStep}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_processed: number, + sent_count: number, + queued_count: number, + retrying_count: number, + elapsed_seconds: time.Duration, + last_status: string?, +} + +type TemplateResult = {ok: true, value: string} | {ok: false, error: AppError} +type TransportResult = {ok: true, value: DeliveryReceipt} | {ok: false, error: AppError} +type DeliverResult = {ok: true, value: string?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type TemplateRenderer = (StoreState, Request) -> TemplateResult +type TransportHandler = (StoreState, Request, time.Time) -> TransportResult +type StepHook = (DeliveryStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.DeliveryMeta = DeliveryMeta +M.EmailRequest = EmailRequest +M.SmsRequest = SmsRequest +M.WebhookRequest = WebhookRequest +M.TickRequest = TickRequest +M.Request = Request +M.DeliveryReceipt = DeliveryReceipt +M.DeliveryRecord = DeliveryRecord +M.DeliveryEventStep = DeliveryEventStep +M.RetryStep = RetryStep +M.AuditStep = AuditStep +M.DeliveryStep = DeliveryStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.TemplateResult = TemplateResult +M.TransportResult = TransportResult +M.DeliverResult = DeliverResult +M.RunResult = RunResult +M.TemplateRenderer = TemplateRenderer +M.TransportHandler = TransportHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): DeliveryMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/result.lua b/testdata/fixtures/realworld/notification-delivery-runtime/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/runtime.lua b/testdata/fixtures/realworld/notification-delivery-runtime/runtime.lua new file mode 100644 index 00000000..8ea85a7b --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/runtime.lua @@ -0,0 +1,137 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local delivery_store = require("delivery_store") + +type DeliveryRuntime = { + transports: {[string]: protocol.TransportHandler}, + hooks: {protocol.StepHook}, + register_transport: (self: DeliveryRuntime, kind: string, handler: protocol.TransportHandler) -> DeliveryRuntime, + on_step: (self: DeliveryRuntime, hook: protocol.StepHook) -> DeliveryRuntime, + new_store: (self: DeliveryRuntime, id: string, now: time.Time) -> delivery_store.DeliveryStore, + emit: (self: DeliveryRuntime, store: delivery_store.DeliveryStore, step: protocol.DeliveryStep, at: time.Time) -> (), + deliver: (self: DeliveryRuntime, store: delivery_store.DeliveryStore, request: protocol.Request, at: time.Time) -> protocol.DeliverResult, + run: (self: DeliveryRuntime, store: delivery_store.DeliveryStore, requests: {protocol.Request}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = DeliveryRuntime + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.DeliveryRuntime = DeliveryRuntime + +function M.new(): DeliveryRuntime + local self: Runtime = { + transports = {}, + hooks = {}, + register_transport = Runtime.register_transport, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + deliver = Runtime.deliver, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:register_transport(kind: string, handler: protocol.TransportHandler): Runtime + self.transports[kind] = handler + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): delivery_store.DeliveryStore + return delivery_store.new(id, now) +end + +function Runtime:emit(store: delivery_store.DeliveryStore, step: protocol.DeliveryStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:deliver( + store: delivery_store.DeliveryStore, + request: protocol.Request, + at: time.Time +): protocol.DeliverResult + if request.kind == "tick" then + local audit_step: protocol.AuditStep = {kind = "audit", note = "tick", at = request.at} + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local handler = self.transports[request.kind] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing transport: " .. request.kind, + retryable = false, + }, + } + end + + local receipt_result = handler(store.state, request, at) + if not receipt_result.ok then + return {ok = false, error = receipt_result.error} + end + + local receipt = receipt_result.value + store:record_receipt(request, receipt) + + if receipt.local_status == "retrying" then + local retry_at = receipt.retry_after or at + local retry_step: protocol.RetryStep = { + kind = "retry", + channel = receipt.channel, + message_id = receipt.message_id, + note = helpers.receipt_note(receipt), + retry_at = retry_at, + } + self:emit(store, retry_step, at) + return {ok = true, value = "retrying"} + end + + local delivery_step: protocol.DeliveryEventStep = { + kind = "delivery", + channel = receipt.channel, + message_id = receipt.message_id, + note = helpers.receipt_note(receipt), + provider_id = receipt.provider_id, + } + self:emit(store, delivery_step, at) + return {ok = true, value = receipt.local_status} +end + +function Runtime:run( + store: delivery_store.DeliveryStore, + requests: {protocol.Request}, + now: time.Time +): protocol.RunResult + local last_status: string? = nil + + for _, request in ipairs(requests) do + local deliver_result: protocol.DeliverResult = self:deliver(store, request, now) + if not deliver_result.ok then + return {ok = false, error = deliver_result.error} + end + if deliver_result.value then + last_status = deliver_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/template_builder.lua b/testdata/fixtures/realworld/notification-delivery-runtime/template_builder.lua new file mode 100644 index 00000000..5ec9d064 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/template_builder.lua @@ -0,0 +1,126 @@ +local result = require("result") +local protocol = require("protocol") + +type TemplateDecorator = (string, protocol.Request) -> string + +type TemplateBuilder = { + name: string, + prefix: string?, + required_tag: string?, + suffix: string?, + decorator: TemplateDecorator?, + named: (self: TemplateBuilder, name: string) -> TemplateBuilder, + prefix_with: (self: TemplateBuilder, prefix: string) -> TemplateBuilder, + require_tag: (self: TemplateBuilder, tag: string) -> TemplateBuilder, + suffix_with: (self: TemplateBuilder, suffix: string) -> TemplateBuilder, + decorate: (self: TemplateBuilder, decorator: TemplateDecorator) -> TemplateBuilder, + build: (self: TemplateBuilder) -> protocol.TemplateRenderer, +} + +type Builder = TemplateBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): TemplateBuilder + local self: Builder = { + name = "template", + prefix = nil, + required_tag = nil, + suffix = nil, + decorator = nil, + named = Builder.named, + prefix_with = Builder.prefix_with, + require_tag = Builder.require_tag, + suffix_with = Builder.suffix_with, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:require_tag(tag: string): Builder + self.required_tag = tag + return self +end + +function Builder:suffix_with(suffix: string): Builder + self.suffix = suffix + return self +end + +function Builder:decorate(decorator: TemplateDecorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.TemplateRenderer + local name = self.name + local prefix = self.prefix + local required_tag = self.required_tag + local suffix = self.suffix + local decorator = self.decorator + + return function(_state: protocol.StoreState, request: protocol.Request): protocol.TemplateResult + if request.kind == "tick" then + return { + ok = false, + error = { + code = "invalid", + message = name .. " cannot render ticks", + retryable = false, + }, + } + end + + if required_tag then + local tags = request.meta.tags + if not tags or not tags[required_tag] then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + local body: string + if request.kind == "email" then + body = request.subject .. ":" .. request.recipient + elseif request.kind == "sms" then + body = request.phone .. ":" .. request.template + else + body = request.endpoint .. ":" .. request.template + end + + if prefix then + body = prefix .. ":" .. body + end + if suffix then + body = body .. ":" .. suffix + end + if decorator then + body = decorator(body, request) + end + + return {ok = true, value = body} + end +end + +return M diff --git a/testdata/fixtures/realworld/notification-delivery-runtime/transport_builder.lua b/testdata/fixtures/realworld/notification-delivery-runtime/transport_builder.lua new file mode 100644 index 00000000..ca3e89a1 --- /dev/null +++ b/testdata/fixtures/realworld/notification-delivery-runtime/transport_builder.lua @@ -0,0 +1,168 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") + +type PreviewDecorator = (string, protocol.Request) -> string + +type TransportBuilder = { + channel: "email" | "sms" | "webhook", + renderer: protocol.TemplateRenderer?, + counter_key: string?, + required_tag: string?, + preview_decorator: PreviewDecorator?, + for_channel: (self: TransportBuilder, channel: "email" | "sms" | "webhook") -> TransportBuilder, + use_renderer: (self: TransportBuilder, renderer: protocol.TemplateRenderer) -> TransportBuilder, + count_as: (self: TransportBuilder, key: string) -> TransportBuilder, + require_tag: (self: TransportBuilder, key: string) -> TransportBuilder, + decorate_preview: (self: TransportBuilder, decorator: PreviewDecorator) -> TransportBuilder, + build: (self: TransportBuilder) -> protocol.TransportHandler, +} + +type Builder = TransportBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): TransportBuilder + local self: Builder = { + channel = "email", + renderer = nil, + counter_key = nil, + required_tag = nil, + preview_decorator = nil, + for_channel = Builder.for_channel, + use_renderer = Builder.use_renderer, + count_as = Builder.count_as, + require_tag = Builder.require_tag, + decorate_preview = Builder.decorate_preview, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_channel(channel: "email" | "sms" | "webhook"): Builder + self.channel = channel + return self +end + +function Builder:use_renderer(renderer: protocol.TemplateRenderer): Builder + self.renderer = renderer + return self +end + +function Builder:count_as(key: string): Builder + self.counter_key = key + return self +end + +function Builder:require_tag(key: string): Builder + self.required_tag = key + return self +end + +function Builder:decorate_preview(decorator: PreviewDecorator): Builder + self.preview_decorator = decorator + return self +end + +function Builder:build(): protocol.TransportHandler + local channel = self.channel + local renderer = self.renderer + local counter_key = self.counter_key + local required_tag = self.required_tag + local preview_decorator = self.preview_decorator + + return function(state: protocol.StoreState, request: protocol.Request, at: time.Time): protocol.TransportResult + if request.kind == "tick" then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " cannot deliver ticks", + retryable = false, + }, + } + end + + local active_renderer = renderer + if not active_renderer then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " missing renderer", + retryable = false, + }, + } + end + + if request.kind ~= channel then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " wrong request kind", + retryable = false, + }, + } + end + + if required_tag then + local tags = request.meta.tags + if not tags or not tags[required_tag] then + return { + ok = false, + error = { + code = "invalid", + message = channel .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + local rendered = active_renderer(state, request) + if not rendered.ok then + return {ok = false, error = rendered.error} + end + + local preview = rendered.value + if preview_decorator then + preview = preview_decorator(preview, request) + end + + local status: "sent" | "queued" | "retrying" = "sent" + local tags = request.meta.tags + if request.kind == "webhook" then + status = "queued" + end + if tags and tags["retry"] == "true" then + status = "retrying" + end + + local retry_after: time.Time? = nil + if status == "retrying" then + retry_after = at + end + + local receipt: protocol.DeliveryReceipt = { + message_id = request.message_id, + tenant_id = request.tenant_id, + channel = channel, + provider_id = channel .. ":" .. request.message_id, + preview = preview, + local_status = status, + delivered_at = at, + retry_after = retry_after, + tags = tags, + counter_key = counter_key, + } + + return {ok = true, value = receipt} + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/main.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/main.lua new file mode 100644 index 00000000..012b3a10 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/main.lua @@ -0,0 +1,45 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local runtime = require("runtime") + +local now = time.now() +local app = runtime.new() + +local policy: protocol.RetryPolicy = { + label = "single", + max_attempts = 1, + compute_delay = function(attempt: integer): number + return attempt + end, + should_retry = function(_error: result.AppError, _attempt: integer): boolean + return false + end, +} + +local store = app:new_store("runtime-unsafe", now) +local call: protocol.PluginCall = { + kind = "plugin_call", + id = "p1", + plugin = "missing", + input = {query = "lua"}, + meta = protocol.meta("trace-unsafe", {source = "planner"}), +} + +local maybe_handler = app.handlers["missing"] +local produced = maybe_handler(store.state, call, policy) -- expect-error + +local cached = store:lookup("search") +local bad_content: string = cached.content -- expect-error + +local elapsed = now:sub(store.state.last_seen) -- expect-error +local seconds: number = elapsed:seconds() + +local hook_step: protocol.RuntimeStep = { + kind = "hook", + name = "retry", + detail = "retry", +} + +local tags = call.meta.tags +local bad_source: string = tags["source"] -- expect-error diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/manifest.json b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/manifest.json new file mode 100644 index 00000000..ebe53271 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/manifest.json @@ -0,0 +1,14 @@ +{ + "description": "Soundness guard for the plugin runtime pipeline domain: unsafe optional handler calls, unchecked optional cache reads, optional time arithmetic, and unchecked optional tag-map reads must be rejected.", + "files": [ + "result.lua", + "protocol.lua", + "runtime_store.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 4 + } +} diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/protocol.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/protocol.lua new file mode 100644 index 00000000..659e1358 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/protocol.lua @@ -0,0 +1,71 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type RetryPolicy = { + label: string, + max_attempts: integer, + compute_delay: (integer) -> number, + should_retry: (AppError, integer) -> boolean, +} + +type RequestMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type PluginCall = { + kind: "plugin_call", + id: string, + plugin: string, + input: {[string]: any}, + meta: RequestMeta, +} + +type PluginOutput = { + plugin: string, + content: string, + cached: boolean, + tags: {[string]: string}?, +} + +type RuntimeStep = { + kind: "hook", + name: string, + detail: string, +} + +type RuntimeState = { + id: string, + started_at: time.Time, + last_seen: time.Time?, + steps: {RuntimeStep}, + cache: {[string]: PluginOutput}, + flags: {[string]: boolean}, +} + +type Hook = (RuntimeStep, RuntimeState) -> () +type PluginResult = {ok: true, value: PluginOutput} | {ok: false, error: AppError} +type PluginHandler = (RuntimeState, PluginCall, RetryPolicy) -> PluginResult + +local M = {} +M.AppError = AppError +M.RetryPolicy = RetryPolicy +M.RequestMeta = RequestMeta +M.PluginCall = PluginCall +M.PluginOutput = PluginOutput +M.RuntimeStep = RuntimeStep +M.RuntimeState = RuntimeState +M.Hook = Hook +M.PluginResult = PluginResult +M.PluginHandler = PluginHandler + +function M.meta(trace_id: string, tags: {[string]: string}?): RequestMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/result.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/result.lua new file mode 100644 index 00000000..b11ef8c8 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/result.lua @@ -0,0 +1,31 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "rate_limited" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/runtime.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/runtime.lua new file mode 100644 index 00000000..70fe3bab --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/runtime.lua @@ -0,0 +1,44 @@ +local time = require("time") +local protocol = require("protocol") +local runtime_store = require("runtime_store") + +type Runtime = { + handlers: {[string]: protocol.PluginHandler}, + hooks: {protocol.Hook}, + register_plugin: (self: Runtime, name: string, handler: protocol.PluginHandler) -> Runtime, + on_step: (self: Runtime, hook: protocol.Hook) -> Runtime, + new_store: (self: Runtime, id: string, now: time.Time) -> runtime_store.RuntimeStore, +} + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} + +function M.new(): Runtime + local self: Runtime = { + handlers = {}, + hooks = {}, + register_plugin = Runtime.register_plugin, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:register_plugin(name: string, handler: protocol.PluginHandler): Runtime + self.handlers[name] = handler + return self +end + +function Runtime:on_step(hook: protocol.Hook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): runtime_store.RuntimeStore + return runtime_store.new(id, now) +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/runtime_store.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/runtime_store.lua new file mode 100644 index 00000000..77d66fea --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline-soundness/runtime_store.lua @@ -0,0 +1,37 @@ +local time = require("time") +local protocol = require("protocol") + +type RuntimeStore = { + state: protocol.RuntimeState, + lookup: (self: RuntimeStore, name: string) -> protocol.PluginOutput?, +} + +type Store = RuntimeStore + +local Store = {} +Store.__index = Store + +local M = {} +M.RuntimeStore = RuntimeStore + +function M.new(id: string, now: time.Time): RuntimeStore + local self: Store = { + state = { + id = id, + started_at = now, + last_seen = nil, + steps = {}, + cache = {}, + flags = {}, + }, + lookup = Store.lookup, + } + setmetatable(self, Store) + return self +end + +function Store:lookup(name: string): protocol.PluginOutput? + return self.state.cache[name] +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/helpers.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/helpers.lua new file mode 100644 index 00000000..0ef6b0fc --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/helpers.lua @@ -0,0 +1,29 @@ +local protocol = require("protocol") + +local M = {} + +function M.source_tag(call: protocol.PluginCall): string + local tags = call.meta.tags + if not tags then + return "unknown" + end + + local source = tags["source"] + if source == nil then + return "unknown" + end + return source +end + +function M.policy_label(policy: protocol.RetryPolicy): string + return policy.label .. ":" .. tostring(policy.max_attempts) +end + +function M.status_name(status: string?): string + if status == nil then + return "pending" + end + return status +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/main.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/main.lua new file mode 100644 index 00000000..ff3b2e32 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/main.lua @@ -0,0 +1,217 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local policy_builder = require("policy_builder") +local plugin_builder = require("plugin_builder") +local helpers = require("helpers") +local runtime = require("runtime") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_plugin_contents: {[string]: string} = {} +local retry_notes: {string} = {} +local last_runtime_id: string? = nil + +local policy = policy_builder.new() + :named("aggressive") + :max_attempts(3) + :scale_by(0.5) + :retry_on("busy") + :retry_on("rate_limited") + :with_backoff(function(attempt: integer): number + return attempt + 0.25 + end) + :build() + +local search_handler = plugin_builder.new() + :named("search") + :arg("query") + :prefix_with("search") + :remember_when_flag("warm") + :decorate(function(content: string, _state: protocol.RuntimeState, call: protocol.PluginCall): string + return content .. ":" .. helpers.source_tag(call) + end) + :tag_with(function(policy_value: protocol.RetryPolicy, _state: protocol.RuntimeState, call: protocol.PluginCall): {[string]: string} + return { + source = helpers.source_tag(call), + policy = helpers.policy_label(policy_value), + plugin = call.plugin, + } + end) + :build() + +local profile_handler = plugin_builder.new() + :named("profile") + :arg("user_id") + :prefix_with("profile") + :remember_when_flag("cache_hit") + :decorate(function(content: string, state: protocol.RuntimeState, _call: protocol.PluginCall): string + local seen = state.last_seen + if seen then + return content .. ":repeat" + end + return content .. ":first" + end) + :tag_with(function(policy_value: protocol.RetryPolicy, _state: protocol.RuntimeState, call: protocol.PluginCall): {[string]: string} + return { + source = helpers.source_tag(call), + policy = policy_value.label, + plugin = call.plugin, + } + end) + :build() + +local app = runtime.new(policy) + :register_plugin("search", search_handler) + :register_plugin("profile", profile_handler) + +app:on_step(function(step: protocol.RuntimeStep, state: protocol.RuntimeState) + last_runtime_id = state.id + + if step.kind == "plugin" then + observed_plugin_contents[step.plugin] = step.output.content + local cached: boolean = step.output.cached + elseif step.kind == "hook" then + table.insert(retry_notes, step.detail) + else + local note: string = step.note + local at_seconds: integer = step.at:unix() + end +end) + +local heartbeat: protocol.HeartbeatEvent = { + kind = "heartbeat", + at = now, +} + +local search_call: protocol.PluginCall = { + kind = "plugin_call", + id = "e1", + plugin = "search", + input = {query = "lua"}, + meta = protocol.meta("trace-1", {source = "planner"}), +} + +local profile_call: protocol.PluginCall = { + kind = "plugin_call", + id = "e2", + plugin = "profile", + input = {user_id = "u-1"}, + meta = protocol.meta("trace-2", nil), +} + +local repeat_search_call: protocol.PluginCall = { + kind = "plugin_call", + id = "e3", + plugin = "search", + input = {query = "lua"}, + meta = protocol.meta("trace-3", {source = "planner"}), +} + +local done_event: protocol.DoneEvent = { + kind = "done", + status = "ok", + reason = "complete", + meta = protocol.meta("trace-4", nil), +} + +local events: {protocol.Event} = { + heartbeat, + search_call, + profile_call, + repeat_search_call, + done_event, +} + +local store = app:new_store("runtime-1", now) +store:set_flag("warm") + +local summary_result = app:run(store, events, now) +local observed_error_message: string? = nil +summary_result = result.tap_error(summary_result, function(err: result.AppError) + observed_error_message = err.message +end) + +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local runtime_id: string = summary.id + local total_steps: number = summary.total_steps + local cached_count: number = summary.cached_count + local elapsed_seconds: number = summary.elapsed_seconds + local last_status: string? = summary.last_status + local last_plugin: string? = summary.last_plugin +end + +local label_result = result.map(summary_result, function(summary: protocol.RuntimeSummary): string + return summary.id .. ":" .. helpers.status_name(summary.last_status) +end) + +if label_result.ok then + local label: string = label_result.value +end + +local last_plugin_result = result.and_then(summary_result, function(summary: protocol.RuntimeSummary): StringResult + if summary.total_steps == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected steps", + retryable = false, + }, + } + end + + local last_plugin = summary.last_plugin or "none" + return { + ok = true, + value = last_plugin, + } +end) + +if last_plugin_result.ok then + local plugin_name: string = last_plugin_result.value +end + +local cached_search = store:lookup("search") +if cached_search then + local cached_content: string = cached_search.content + local cached_flag: boolean = cached_search.cached + local tags = cached_search.tags + if tags then + local source = tags["source"] + if source then + local stable_source: string = source + end + end +end + +for plugin_name, content in pairs(observed_plugin_contents) do + local stable_plugin: string = plugin_name + local stable_content: string = content +end + +for _, detail in ipairs(retry_notes) do + local retry_note: string = detail +end + +if last_runtime_id ~= nil then + local stable_runtime_id: string = last_runtime_id +end + +local search_tags = search_call.meta.tags +if search_tags then + local source = search_tags["source"] + if source then + local source_name: string = source + end +end + +local last_seen = store.state.last_seen or store.state.started_at +local elapsed = now:sub(last_seen) +local seconds: number = elapsed:seconds() diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/manifest.json b/testdata/fixtures/realworld/plugin-runtime-pipeline/manifest.json new file mode 100644 index 00000000..782ced16 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Plugin runtime pipeline mixing generic results, fluent retry-policy and plugin builders, closure-built handlers, metatable-backed runtime/store objects, callback hooks, dynamic registries, optional caches, and time-based summaries.", + "files": [ + "result.lua", + "protocol.lua", + "policy_builder.lua", + "plugin_builder.lua", + "helpers.lua", + "runtime_store.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/plugin_builder.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/plugin_builder.lua new file mode 100644 index 00000000..ec15b357 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/plugin_builder.lua @@ -0,0 +1,128 @@ +local result = require("result") +local protocol = require("protocol") + +type Decorator = (string, protocol.RuntimeState, protocol.PluginCall) -> string +type Tagger = (protocol.RetryPolicy, protocol.RuntimeState, protocol.PluginCall) -> {[string]: string}? + +type PluginBuilder = { + name: string, + arg_key: string, + prefix: string, + remember_flag: string?, + decorator: Decorator?, + tagger: Tagger?, + named: (self: PluginBuilder, name: string) -> PluginBuilder, + arg: (self: PluginBuilder, key: string) -> PluginBuilder, + prefix_with: (self: PluginBuilder, prefix: string) -> PluginBuilder, + remember_when_flag: (self: PluginBuilder, flag: string) -> PluginBuilder, + decorate: (self: PluginBuilder, decorator: Decorator) -> PluginBuilder, + tag_with: (self: PluginBuilder, tagger: Tagger) -> PluginBuilder, + build: (self: PluginBuilder) -> protocol.PluginHandler, +} + +type Builder = PluginBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.PluginBuilder = PluginBuilder + +function M.new(): PluginBuilder + local self: Builder = { + name = "plugin", + arg_key = "value", + prefix = "plugin", + remember_flag = nil, + decorator = nil, + tagger = nil, + named = Builder.named, + arg = Builder.arg, + prefix_with = Builder.prefix_with, + remember_when_flag = Builder.remember_when_flag, + decorate = Builder.decorate, + tag_with = Builder.tag_with, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:arg(key: string): Builder + self.arg_key = key + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:remember_when_flag(flag: string): Builder + self.remember_flag = flag + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +function Builder:tag_with(tagger: Tagger): Builder + self.tagger = tagger + return self +end + +function Builder:build(): protocol.PluginHandler + local name = self.name + local arg_key = self.arg_key + local prefix = self.prefix + local remember_flag = self.remember_flag + local decorator = self.decorator + local tagger = self.tagger + + return function(state: protocol.RuntimeState, call: protocol.PluginCall, policy: protocol.RetryPolicy): protocol.PluginResult + local raw = call.input[arg_key] + if type(raw) ~= "string" then + return { + ok = false, + error = { + code = "invalid", + message = name .. " " .. arg_key .. " must be string", + retryable = false, + }, + } + end + + local content = prefix .. ":" .. raw + if decorator then + content = decorator(content, state, call) + end + if remember_flag and state.flags[remember_flag] then + content = content .. ":flag" + end + + local delay = policy.compute_delay(1) + local tags = nil + if tagger then + tags = tagger(policy, state, call) + end + + return { + ok = true, + value = { + plugin = call.plugin, + content = content .. ":" .. tostring(delay), + cached = false, + tags = tags, + }, + } + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/policy_builder.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/policy_builder.lua new file mode 100644 index 00000000..c9bac24f --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/policy_builder.lua @@ -0,0 +1,103 @@ +local result = require("result") +local protocol = require("protocol") + +type RetryPolicy = protocol.RetryPolicy + +type PolicyBuilder = { + label: string, + attempts: integer, + factor: number, + retryable_codes: {result.ErrorCode}, + backoff: (integer) -> number, + named: (self: PolicyBuilder, label: string) -> PolicyBuilder, + max_attempts: (self: PolicyBuilder, attempts: integer) -> PolicyBuilder, + scale_by: (self: PolicyBuilder, factor: number) -> PolicyBuilder, + retry_on: (self: PolicyBuilder, code: result.ErrorCode) -> PolicyBuilder, + with_backoff: (self: PolicyBuilder, backoff: (integer) -> number) -> PolicyBuilder, + build: (self: PolicyBuilder) -> RetryPolicy, +} + +type Builder = PolicyBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.PolicyBuilder = PolicyBuilder + +function M.new(): PolicyBuilder + local self: Builder = { + label = "default", + attempts = 1, + factor = 1, + retryable_codes = {}, + backoff = function(attempt: integer): number + return attempt + end, + named = Builder.named, + max_attempts = Builder.max_attempts, + scale_by = Builder.scale_by, + retry_on = Builder.retry_on, + with_backoff = Builder.with_backoff, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(label: string): Builder + self.label = label + return self +end + +function Builder:max_attempts(attempts: integer): Builder + self.attempts = attempts + return self +end + +function Builder:scale_by(factor: number): Builder + self.factor = factor + return self +end + +function Builder:retry_on(code: result.ErrorCode): Builder + table.insert(self.retryable_codes, code) + return self +end + +function Builder:with_backoff(backoff: (integer) -> number): Builder + self.backoff = backoff + return self +end + +function Builder:build(): RetryPolicy + local label = self.label + local attempts = self.attempts + local factor = self.factor + local retryable_codes = self.retryable_codes + local backoff = self.backoff + + return { + label = label, + max_attempts = attempts, + compute_delay = function(attempt: integer): number + return backoff(attempt) * factor + end, + should_retry = function(err: result.AppError, attempt: integer): boolean + if attempt >= attempts then + return false + end + if err.retryable then + return true + end + for _, code in ipairs(retryable_codes) do + if code == err.code then + return true + end + end + return false + end, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/protocol.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/protocol.lua new file mode 100644 index 00000000..04d32108 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/protocol.lua @@ -0,0 +1,114 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type RetryPolicy = { + label: string, + max_attempts: integer, + compute_delay: (integer) -> number, + should_retry: (AppError, integer) -> boolean, +} + +type RequestMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type PluginCall = { + kind: "plugin_call", + id: string, + plugin: string, + input: {[string]: any}, + meta: RequestMeta, +} + +type HeartbeatEvent = { + kind: "heartbeat", + at: time.Time, +} + +type DoneEvent = { + kind: "done", + status: "ok" | "failed", + reason: string?, + meta: RequestMeta, +} + +type Event = PluginCall | HeartbeatEvent | DoneEvent + +type PluginOutput = { + plugin: string, + content: string, + cached: boolean, + tags: {[string]: string}?, +} + +type PluginStep = { + kind: "plugin", + plugin: string, + output: PluginOutput, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type HookStep = { + kind: "hook", + name: string, + detail: string, +} + +type RuntimeStep = PluginStep | AuditStep | HookStep + +type RuntimeState = { + id: string, + started_at: time.Time, + last_seen: time.Time?, + steps: {RuntimeStep}, + cache: {[string]: PluginOutput}, + flags: {[string]: boolean}, +} + +type RuntimeSummary = { + id: string, + total_steps: number, + cached_count: number, + last_status: string?, + elapsed_seconds: number, + last_plugin: string?, +} + +type Hook = (RuntimeStep, RuntimeState) -> () +type PluginResult = {ok: true, value: PluginOutput} | {ok: false, error: AppError} +type RuntimeResult = {ok: true, value: RuntimeSummary} | {ok: false, error: AppError} +type PluginHandler = (RuntimeState, PluginCall, RetryPolicy) -> PluginResult + +local M = {} +M.AppError = AppError +M.RetryPolicy = RetryPolicy +M.RequestMeta = RequestMeta +M.PluginCall = PluginCall +M.HeartbeatEvent = HeartbeatEvent +M.DoneEvent = DoneEvent +M.Event = Event +M.PluginOutput = PluginOutput +M.RuntimeStep = RuntimeStep +M.RuntimeState = RuntimeState +M.RuntimeSummary = RuntimeSummary +M.Hook = Hook +M.PluginResult = PluginResult +M.RuntimeResult = RuntimeResult +M.PluginHandler = PluginHandler + +function M.meta(trace_id: string, tags: {[string]: string}?): RequestMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/result.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/result.lua new file mode 100644 index 00000000..d0593064 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/result.lua @@ -0,0 +1,52 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "rate_limited" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +function M.tap_error(r: Result, fn: (AppError) -> ()): Result + if not r.ok then + fn(r.error) + end + return r +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/runtime.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/runtime.lua new file mode 100644 index 00000000..533c35db --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/runtime.lua @@ -0,0 +1,145 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local runtime_store = require("runtime_store") + +type EventResult = {ok: true, value: string?} | {ok: false, error: result.AppError} + +type Runtime = { + handlers: {[string]: protocol.PluginHandler}, + hooks: {protocol.Hook}, + default_policy: protocol.RetryPolicy, + register_plugin: (self: Runtime, name: string, handler: protocol.PluginHandler) -> Runtime, + on_step: (self: Runtime, hook: protocol.Hook) -> Runtime, + new_store: (self: Runtime, id: string, now: time.Time) -> runtime_store.RuntimeStore, + emit: (self: Runtime, store: runtime_store.RuntimeStore, step: protocol.RuntimeStep, at: time.Time) -> (), + handle_event: (self: Runtime, store: runtime_store.RuntimeStore, event: protocol.Event, at: time.Time) -> EventResult, + run: (self: Runtime, store: runtime_store.RuntimeStore, events: {protocol.Event}, now: time.Time) -> protocol.RuntimeResult, +} + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.Runtime = Runtime + +function M.new(policy: protocol.RetryPolicy): Runtime + local self: Runtime = { + handlers = {}, + hooks = {}, + default_policy = policy, + register_plugin = Runtime.register_plugin, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + handle_event = Runtime.handle_event, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:register_plugin(name: string, handler: protocol.PluginHandler): Runtime + self.handlers[name] = handler + return self +end + +function Runtime:on_step(hook: protocol.Hook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): runtime_store.RuntimeStore + return runtime_store.new(id, now) +end + +function Runtime:emit(store: runtime_store.RuntimeStore, step: protocol.RuntimeStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:handle_event(store: runtime_store.RuntimeStore, event: protocol.Event, at: time.Time): EventResult + if event.kind == "heartbeat" then + self:emit(store, {kind = "audit", note = "heartbeat", at = event.at}, at) + return {ok = true, value = nil} + end + + if event.kind == "done" then + local reason = event.reason or event.status + self:emit(store, {kind = "audit", note = "done:" .. reason, at = at}, at) + return {ok = true, value = event.status} + end + + local cached = store:lookup(event.plugin) + if cached then + self:emit(store, { + kind = "plugin", + plugin = event.plugin, + output = { + plugin = cached.plugin, + content = cached.content, + cached = true, + tags = cached.tags, + }, + }, at) + store:set_flag("cache_hit") + return {ok = true, value = nil} + end + + local handler = self.handlers[event.plugin] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing plugin: " .. event.plugin, + retryable = false, + }, + } + end + + local plugin_result = handler(store.state, event, self.default_policy) + if plugin_result.ok then + store:remember(plugin_result.value) + self:emit(store, { + kind = "plugin", + plugin = event.plugin, + output = plugin_result.value, + }, at) + return {ok = true, value = nil} + end + + local detail = "stop" + if self.default_policy.should_retry(plugin_result.error, 1) then + detail = "retry" + end + self:emit(store, {kind = "hook", name = "retry", detail = detail}, at) + return {ok = false, error = plugin_result.error} +end + +function Runtime:run( + store: runtime_store.RuntimeStore, + events: {protocol.Event}, + now: time.Time +): protocol.RuntimeResult + local last_status: string? = nil + + for _, event in ipairs(events) do + local event_result = self:handle_event(store, event, now) + if not event_result.ok then + return {ok = false, error = event_result.error} + end + if event_result.value ~= nil then + last_status = event_result.value + end + end + + return { + ok = true, + value = store:summarize(now, last_status), + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-runtime-pipeline/runtime_store.lua b/testdata/fixtures/realworld/plugin-runtime-pipeline/runtime_store.lua new file mode 100644 index 00000000..30af3d42 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-runtime-pipeline/runtime_store.lua @@ -0,0 +1,86 @@ +local time = require("time") +local protocol = require("protocol") + +type RuntimeStore = { + state: protocol.RuntimeState, + touch: (self: RuntimeStore, at: time.Time) -> RuntimeStore, + push_step: (self: RuntimeStore, step: protocol.RuntimeStep, at: time.Time) -> RuntimeStore, + remember: (self: RuntimeStore, output: protocol.PluginOutput) -> (), + lookup: (self: RuntimeStore, name: string) -> protocol.PluginOutput?, + set_flag: (self: RuntimeStore, name: string) -> (), + summarize: (self: RuntimeStore, now: time.Time, last_status: string?) -> protocol.RuntimeSummary, +} + +type Store = RuntimeStore + +local Store = {} +Store.__index = Store + +local M = {} +M.RuntimeStore = RuntimeStore + +function M.new(id: string, now: time.Time): RuntimeStore + local self: Store = { + state = { + id = id, + started_at = now, + last_seen = nil, + steps = {}, + cache = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + remember = Store.remember, + lookup = Store.lookup, + set_flag = Store.set_flag, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_seen = at + return self +end + +function Store:push_step(step: protocol.RuntimeStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:remember(output: protocol.PluginOutput) + self.state.cache[output.plugin] = output +end + +function Store:lookup(name: string): protocol.PluginOutput? + return self.state.cache[name] +end + +function Store:set_flag(name: string) + self.state.flags[name] = true +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RuntimeSummary + local cached_count = 0 + local last_plugin: string? = nil + for name, _ in pairs(self.state.cache) do + cached_count = cached_count + 1 + last_plugin = name + end + + local since = self.state.last_seen or self.state.started_at + local elapsed = now:sub(since) + + return { + id = self.state.id, + total_steps = #self.state.steps, + cached_count = cached_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + last_plugin = last_plugin, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/fallback_builder.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/fallback_builder.lua new file mode 100644 index 00000000..a8992e84 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/fallback_builder.lua @@ -0,0 +1,100 @@ +local protocol = require("protocol") +local result = require("result") + +type NoteDecorator = (string, protocol.DispatchRequest, result.AppError) -> string + +type FallbackBuilder = { + plugin_name: string?, + retry_code: result.ErrorCode?, + queue_name: string?, + decorator: NoteDecorator?, + for_plugin: (self: FallbackBuilder, plugin_name: string) -> FallbackBuilder, + retry_on: (self: FallbackBuilder, code: result.ErrorCode) -> FallbackBuilder, + queue_named: (self: FallbackBuilder, queue_name: string) -> FallbackBuilder, + decorate_note: (self: FallbackBuilder, fn: NoteDecorator) -> FallbackBuilder, + build: (self: FallbackBuilder) -> protocol.FallbackHandler, +} + +type Builder = FallbackBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.FallbackBuilder = FallbackBuilder + +function M.new(): FallbackBuilder + local self: Builder = { + plugin_name = nil, + retry_code = nil, + queue_name = nil, + decorator = nil, + for_plugin = Builder.for_plugin, + retry_on = Builder.retry_on, + queue_named = Builder.queue_named, + decorate_note = Builder.decorate_note, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_plugin(plugin_name: string): Builder + self.plugin_name = plugin_name + return self +end + +function Builder:retry_on(code: result.ErrorCode): Builder + self.retry_code = code + return self +end + +function Builder:queue_named(queue_name: string): Builder + self.queue_name = queue_name + return self +end + +function Builder:decorate_note(fn: NoteDecorator): Builder + self.decorator = fn + return self +end + +function Builder:build(): protocol.FallbackHandler + local plugin_name = self.plugin_name + local retry_code = self.retry_code + local queue_name = self.queue_name or "retry" + local decorator = self.decorator + + return function( + state: protocol.StoreState, + request: protocol.DispatchRequest, + err: result.AppError, + at: time.Time + ): protocol.FallbackResult + if plugin_name and request.plugin ~= plugin_name then + return {ok = true, value = nil} + end + + if retry_code and err.code ~= retry_code then + return {ok = true, value = nil} + end + + state.flags["saw_fallback"] = true + + local note = queue_name .. ":" .. request.plugin .. ":" .. request.envelope.id + if decorator then + note = decorator(note, request, err) + end + + return { + ok = true, + value = { + queue = queue_name, + note = note, + retry_at = at, + }, + } + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/handler_builder.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/handler_builder.lua new file mode 100644 index 00000000..f29b06bd --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/handler_builder.lua @@ -0,0 +1,196 @@ +local time = require("time") +local protocol = require("protocol") +local result = require("result") + +type LabelDecorator = (string, protocol.PayloadEnvelope, protocol.StoreState) -> string + +type HandlerBuilder = { + name: string?, + prefix: string?, + required_tag: string?, + remembered_flag: string?, + failure_tag: string?, + failure_code: result.ErrorCode?, + decorator: LabelDecorator?, + named: (self: HandlerBuilder, name: string) -> HandlerBuilder, + prefix_with: (self: HandlerBuilder, prefix: string) -> HandlerBuilder, + require_tag: (self: HandlerBuilder, key: string) -> HandlerBuilder, + remember_flag: (self: HandlerBuilder, flag: string) -> HandlerBuilder, + fail_on_tag: (self: HandlerBuilder, key: string, code: result.ErrorCode) -> HandlerBuilder, + decorate: (self: HandlerBuilder, fn: LabelDecorator) -> HandlerBuilder, + build: (self: HandlerBuilder) -> protocol.PluginHandler, +} + +type Builder = HandlerBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.HandlerBuilder = HandlerBuilder + +function M.new(): HandlerBuilder + local self: Builder = { + name = nil, + prefix = nil, + required_tag = nil, + remembered_flag = nil, + failure_tag = nil, + failure_code = nil, + decorator = nil, + named = Builder.named, + prefix_with = Builder.prefix_with, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + fail_on_tag = Builder.fail_on_tag, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:require_tag(key: string): Builder + self.required_tag = key + return self +end + +function Builder:remember_flag(flag: string): Builder + self.remembered_flag = flag + return self +end + +function Builder:fail_on_tag(key: string, code: result.ErrorCode): Builder + self.failure_tag = key + self.failure_code = code + return self +end + +function Builder:decorate(fn: LabelDecorator): Builder + self.decorator = fn + return self +end + +function Builder:build(): protocol.PluginHandler + local name = self.name or "plugin" + local prefix = self.prefix or name + local required_tag = self.required_tag + local remembered_flag = self.remembered_flag + local failure_tag = self.failure_tag + local failure_code = self.failure_code + local decorator = self.decorator + + return function( + state: protocol.StoreState, + envelope: protocol.PayloadEnvelope, + at: time.Time + ): protocol.DispatchResult + local tags = envelope.meta.tags + + if required_tag then + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tag " .. required_tag, + retryable = false, + }, + } + end + end + + if failure_tag and tags then + local value = tags[failure_tag] + if value then + return { + ok = false, + error = { + code = failure_code or "busy", + message = name .. ": tag requested retry", + retryable = true, + }, + } + end + end + + if remembered_flag then + state.flags[remembered_flag] = true + end + + local payload = envelope.payload + local receipt: protocol.OutputReceipt + + if payload.kind == "render" then + local subject = payload.values["subject"] or payload.template + local body = prefix .. ":" .. payload.template .. ":" .. subject + if decorator then + body = decorator(body, envelope, state) + end + receipt = { + plugin = name, + envelope_id = envelope.id, + output = { + kind = "rendered", + body = body, + label = prefix, + }, + emitted_at = at, + cached = false, + } + elseif payload.kind == "index" then + receipt = { + plugin = name, + envelope_id = envelope.id, + output = { + kind = "indexed", + count = #payload.terms, + }, + emitted_at = at, + cached = false, + } + else + local note = prefix .. ":" .. payload.action .. ":" .. payload.actor_id + if decorator then + note = decorator(note, envelope, state) + end + receipt = { + plugin = name, + envelope_id = envelope.id, + output = { + kind = "audited", + note = note, + retry_after = nil, + }, + emitted_at = at, + cached = false, + } + end + + return {ok = true, value = receipt} + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/helpers.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/helpers.lua new file mode 100644 index 00000000..5d73df61 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/helpers.lua @@ -0,0 +1,35 @@ +local protocol = require("protocol") + +local M = {} + +function M.cache_key(request: protocol.DispatchRequest): string + return request.plugin .. ":" .. request.envelope.id +end + +function M.source_tag(envelope: protocol.PayloadEnvelope): string + local tags = envelope.meta.tags + if not tags then + return "unknown" + end + + local source = tags["source"] + if not source then + return "unknown" + end + + return source +end + +function M.output_label(output: protocol.Output): string + if output.kind == "rendered" then + return output.body + end + + if output.kind == "indexed" then + return tostring(output.count) + end + + return output.note +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/main.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/main.lua new file mode 100644 index 00000000..65f09c05 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/main.lua @@ -0,0 +1,83 @@ +local time = require("time") +local protocol = require("protocol") +local validator_builder = require("validator_builder") +local handler_builder = require("handler_builder") +local fallback_builder = require("fallback_builder") +local runtime = require("runtime") + +local now = time.now() + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :build() + +local render_handler = handler_builder.new() + :named("render") + :prefix_with("render") + :remember_flag("did_render") + :build() + +local audit_handler = handler_builder.new() + :named("audit") + :prefix_with("audit") + :fail_on_tag("retry", "busy") + :build() + +local retry_fallback = fallback_builder.new() + :for_plugin("audit") + :retry_on("busy") + :queue_named("audit-retry") + :build() + +local app = runtime.new() + :use_validator(source_validator) + :register_handler("render", render_handler) + :register_handler("audit", audit_handler) + :use_fallback(retry_fallback) + +local render_request: protocol.DispatchRequest = { + kind = "dispatch", + plugin = "render", + envelope = { + id = "req-1", + tenant_id = "tenant-a", + payload = { + kind = "render", + template = "welcome", + values = {subject = "hello"}, + }, + meta = protocol.meta("trace-1", {source = "api"}), + }, +} + +local audit_request: protocol.DispatchRequest = { + kind = "dispatch", + plugin = "audit", + envelope = { + id = "req-2", + tenant_id = "tenant-a", + payload = { + kind = "audit", + action = "login", + actor_id = "actor-7", + }, + meta = protocol.meta("trace-2", {source = "worker", retry = "true"}), + }, +} + +local store = app:new_store("supervisor-1", now) +local _rendered = app:dispatch(store, render_request, now) +local _ = app:dispatch(store, audit_request, now) + +local missing_handler = app.handlers["missing"] +local missing_result = missing_handler(store.state, render_request.envelope, now) -- expect-error + +local cached_render = store:lookup_cached("render:req-1") +if cached_render then + local seen_seconds: integer = cached_render.seen_at -- expect-error + local cached_flag: boolean = store.state.flags["did_index"] -- expect-error +end +local fallback_total: integer = store.state.plugin_counts["fallback"] -- expect-error +local last_audit: string = store.state.audit_tags["last"] -- expect-error +local elapsed = now:sub(store.state.last_tick) -- expect-error diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/manifest.json b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/manifest.json new file mode 100644 index 00000000..ed0d6db1 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/manifest.json @@ -0,0 +1,18 @@ +{ + "description": "Soundness coverage for plugin supervisor runtime: optional cached receipts, optional maps, optional times, missing handler lookup, and union output fields must not be treated as definite.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "plugin_store.lua", + "validator_builder.lua", + "handler_builder.lua", + "fallback_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 6 + } +} diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/plugin_store.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/plugin_store.lua new file mode 100644 index 00000000..67515053 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/plugin_store.lua @@ -0,0 +1,98 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type PluginStore = { + state: protocol.StoreState, + cache_receipt: (self: PluginStore, request: protocol.DispatchRequest, receipt: protocol.OutputReceipt, at: time.Time) -> (), + lookup_cached: (self: PluginStore, key: string) -> protocol.CachedReceipt?, + push_step: (self: PluginStore, step: protocol.RuntimeStep, at: time.Time) -> (), + set_flag: (self: PluginStore, flag: string) -> (), + summarize: (self: PluginStore, now: time.Time, last_output_kind: string?) -> protocol.RunSummary, +} + +type Store = PluginStore + +local Store = {} +Store.__index = Store + +local M = {} +M.PluginStore = PluginStore + +function M.new(id: string, now: time.Time): PluginStore + local self: Store = { + state = { + id = id, + started_at = now, + last_tick = nil, + last_dispatch_at = nil, + cached_receipts = {}, + plugin_counts = {}, + flags = {}, + audit_tags = {}, + steps = {}, + }, + cache_receipt = Store.cache_receipt, + lookup_cached = Store.lookup_cached, + push_step = Store.push_step, + set_flag = Store.set_flag, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:cache_receipt( + request: protocol.DispatchRequest, + receipt: protocol.OutputReceipt, + at: time.Time +) + self.state.cached_receipts[helpers.cache_key(request)] = { + value = receipt, + seen_at = at, + } + self.state.last_dispatch_at = at +end + +function Store:lookup_cached(key: string): protocol.CachedReceipt? + return self.state.cached_receipts[key] +end + +function Store:push_step(step: protocol.RuntimeStep, at: time.Time) + table.insert(self.state.steps, step) + self.state.last_dispatch_at = at + + if step.kind == "dispatch" then + local current = self.state.plugin_counts[step.plugin] or 0 + self.state.plugin_counts[step.plugin] = current + 1 + elseif step.kind == "cached" then + local current = self.state.plugin_counts["cached"] or 0 + self.state.plugin_counts["cached"] = current + 1 + elseif step.kind == "fallback" then + local current = self.state.plugin_counts["fallback"] or 0 + self.state.plugin_counts["fallback"] = current + 1 + self.state.audit_tags["last"] = step.note + else + self.state.audit_tags["last"] = step.note + end +end + +function Store:set_flag(flag: string) + self.state.flags[flag] = true +end + +function Store:summarize(now: time.Time, last_output_kind: string?): protocol.RunSummary + local cached_hits = self.state.plugin_counts["cached"] or 0 + local fallback_count = self.state.plugin_counts["fallback"] or 0 + + return { + id = self.state.id, + processed = #self.state.steps, + cached_hits = cached_hits, + fallback_count = fallback_count, + elapsed_seconds = now:sub(self.state.started_at), + last_output_kind = last_output_kind, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/protocol.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/protocol.lua new file mode 100644 index 00000000..1e1cacf8 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/protocol.lua @@ -0,0 +1,200 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError +type ErrorCode = result.ErrorCode + +type RequestMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type Envelope = { + id: string, + tenant_id: string, + payload: T, + meta: RequestMeta, +} + +type RenderPayload = { + kind: "render", + template: string, + values: {[string]: string}, +} + +type IndexPayload = { + kind: "index", + document_id: string, + terms: {string}, +} + +type AuditPayload = { + kind: "audit", + action: string, + actor_id: string, +} + +type Payload = RenderPayload | IndexPayload | AuditPayload +type PayloadEnvelope = Envelope + +type DispatchRequest = { + kind: "dispatch", + plugin: string, + envelope: PayloadEnvelope, +} + +type TickRequest = { + kind: "tick", + at: time.Time, +} + +type Request = DispatchRequest | TickRequest + +type RenderOutput = { + kind: "rendered", + body: string, + label: string?, +} + +type IndexOutput = { + kind: "indexed", + count: integer, +} + +type AuditOutput = { + kind: "audited", + note: string, + retry_after: time.Time?, +} + +type Output = RenderOutput | IndexOutput | AuditOutput + +type Receipt = { + plugin: string, + envelope_id: string, + output: T, + emitted_at: time.Time, + cached: boolean, +} + +type Cached = { + value: T, + seen_at: time.Time, +} + +type OutputReceipt = Receipt +type CachedReceipt = Cached + +type DispatchStep = { + kind: "dispatch", + plugin: string, + output: Output, + cached: boolean, +} + +type CachedStep = { + kind: "cached", + plugin: string, + envelope_id: string, + at: time.Time, +} + +type FallbackStep = { + kind: "fallback", + plugin: string, + queue: string, + note: string, + retry_at: time.Time, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type RuntimeStep = DispatchStep | CachedStep | FallbackStep | AuditStep + +type FallbackPlan = { + queue: string, + note: string, + retry_at: time.Time, +} + +type StoreState = { + id: string, + started_at: time.Time, + last_tick: time.Time?, + last_dispatch_at: time.Time?, + cached_receipts: {[string]: CachedReceipt}, + plugin_counts: {[string]: integer}, + flags: {[string]: boolean}, + audit_tags: {[string]: string}, + steps: {RuntimeStep}, +} + +type RunSummary = { + id: string, + processed: number, + cached_hits: number, + fallback_count: number, + elapsed_seconds: time.Duration, + last_output_kind: string?, +} + +type ValidationResult = {ok: true, value: DispatchRequest} | {ok: false, error: AppError} +type DispatchResult = {ok: true, value: OutputReceipt?} | {ok: false, error: AppError} +type FallbackResult = {ok: true, value: FallbackPlan?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type RequestValidator = (StoreState, DispatchRequest) -> ValidationResult +type PluginHandler = (StoreState, PayloadEnvelope, time.Time) -> DispatchResult +type FallbackHandler = (StoreState, DispatchRequest, AppError, time.Time) -> FallbackResult +type StepHook = (RuntimeStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.ErrorCode = ErrorCode +M.RequestMeta = RequestMeta +M.Envelope = Envelope +M.RenderPayload = RenderPayload +M.IndexPayload = IndexPayload +M.AuditPayload = AuditPayload +M.Payload = Payload +M.PayloadEnvelope = PayloadEnvelope +M.DispatchRequest = DispatchRequest +M.TickRequest = TickRequest +M.Request = Request +M.RenderOutput = RenderOutput +M.IndexOutput = IndexOutput +M.AuditOutput = AuditOutput +M.Output = Output +M.Receipt = Receipt +M.Cached = Cached +M.OutputReceipt = OutputReceipt +M.CachedReceipt = CachedReceipt +M.DispatchStep = DispatchStep +M.CachedStep = CachedStep +M.FallbackStep = FallbackStep +M.AuditStep = AuditStep +M.RuntimeStep = RuntimeStep +M.FallbackPlan = FallbackPlan +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.DispatchResult = DispatchResult +M.FallbackResult = FallbackResult +M.RunResult = RunResult +M.RequestValidator = RequestValidator +M.PluginHandler = PluginHandler +M.FallbackHandler = FallbackHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): RequestMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/result.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/runtime.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/runtime.lua new file mode 100644 index 00000000..b250c299 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/runtime.lua @@ -0,0 +1,200 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local plugin_store = require("plugin_store") + +type SupervisorRuntime = { + validators: {protocol.RequestValidator}, + handlers: {[string]: protocol.PluginHandler}, + fallbacks: {protocol.FallbackHandler}, + hooks: {protocol.StepHook}, + use_validator: (self: SupervisorRuntime, validator: protocol.RequestValidator) -> SupervisorRuntime, + register_handler: (self: SupervisorRuntime, plugin_name: string, handler: protocol.PluginHandler) -> SupervisorRuntime, + use_fallback: (self: SupervisorRuntime, fallback: protocol.FallbackHandler) -> SupervisorRuntime, + on_step: (self: SupervisorRuntime, hook: protocol.StepHook) -> SupervisorRuntime, + new_store: (self: SupervisorRuntime, id: string, now: time.Time) -> plugin_store.PluginStore, + emit: (self: SupervisorRuntime, store: plugin_store.PluginStore, step: protocol.RuntimeStep, at: time.Time) -> (), + dispatch: (self: SupervisorRuntime, store: plugin_store.PluginStore, request: protocol.Request, at: time.Time) -> protocol.DispatchResult, + run: (self: SupervisorRuntime, store: plugin_store.PluginStore, requests: {protocol.Request}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = SupervisorRuntime + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.SupervisorRuntime = SupervisorRuntime + +function M.new(): SupervisorRuntime + local self: Runtime = { + validators = {}, + handlers = {}, + fallbacks = {}, + hooks = {}, + use_validator = Runtime.use_validator, + register_handler = Runtime.register_handler, + use_fallback = Runtime.use_fallback, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + dispatch = Runtime.dispatch, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:use_validator(validator: protocol.RequestValidator): Runtime + table.insert(self.validators, validator) + return self +end + +function Runtime:register_handler(plugin_name: string, handler: protocol.PluginHandler): Runtime + self.handlers[plugin_name] = handler + return self +end + +function Runtime:use_fallback(fallback: protocol.FallbackHandler): Runtime + table.insert(self.fallbacks, fallback) + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): plugin_store.PluginStore + return plugin_store.new(id, now) +end + +function Runtime:emit(store: plugin_store.PluginStore, step: protocol.RuntimeStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:dispatch( + store: plugin_store.PluginStore, + request: protocol.Request, + at: time.Time +): protocol.DispatchResult + if request.kind == "tick" then + store.state.last_tick = request.at + local audit_step: protocol.AuditStep = { + kind = "audit", + note = "tick", + at = request.at, + } + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + for _, validator in ipairs(self.validators) do + local validation = validator(store.state, request) + if not validation.ok then + return {ok = false, error = validation.error} + end + end + + local cached = store:lookup_cached(helpers.cache_key(request)) + if cached then + local receipt: protocol.OutputReceipt = { + plugin = cached.value.plugin, + envelope_id = cached.value.envelope_id, + output = cached.value.output, + emitted_at = cached.value.emitted_at, + cached = true, + } + local cached_step: protocol.CachedStep = { + kind = "cached", + plugin = request.plugin, + envelope_id = request.envelope.id, + at = at, + } + self:emit(store, cached_step, at) + return {ok = true, value = receipt} + end + + local handler = self.handlers[request.plugin] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing handler: " .. request.plugin, + retryable = false, + }, + } + end + + local handled = handler(store.state, request.envelope, at) + if not handled.ok then + for _, fallback in ipairs(self.fallbacks) do + local fallback_result = fallback(store.state, request, handled.error, at) + if not fallback_result.ok then + return {ok = false, error = fallback_result.error} + end + + local plan = fallback_result.value + if plan then + local fallback_step: protocol.FallbackStep = { + kind = "fallback", + plugin = request.plugin, + queue = plan.queue, + note = plan.note, + retry_at = plan.retry_at, + } + self:emit(store, fallback_step, at) + return {ok = true, value = nil} + end + end + + return {ok = false, error = handled.error} + end + + local receipt = handled.value + if not receipt then + return {ok = true, value = nil} + end + + store:cache_receipt(request, receipt, at) + + local dispatch_step: protocol.DispatchStep = { + kind = "dispatch", + plugin = request.plugin, + output = receipt.output, + cached = receipt.cached, + } + self:emit(store, dispatch_step, at) + return {ok = true, value = receipt} +end + +function Runtime:run( + store: plugin_store.PluginStore, + requests: {protocol.Request}, + now: time.Time +): protocol.RunResult + local last_output_kind: string? = nil + + for _, request in ipairs(requests) do + local dispatch_result = self:dispatch(store, request, now) + if not dispatch_result.ok then + return {ok = false, error = dispatch_result.error} + end + + local receipt = dispatch_result.value + if receipt then + last_output_kind = receipt.output.kind + end + end + + return { + ok = true, + value = store:summarize(now, last_output_kind), + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/validator_builder.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/validator_builder.lua new file mode 100644 index 00000000..4c603c84 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime-soundness/validator_builder.lua @@ -0,0 +1,90 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string?, + required_tag: string?, + remembered_flag: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, key: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.RequestValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ValidatorBuilder = ValidatorBuilder + +function M.new(): ValidatorBuilder + local self: Builder = { + name = nil, + required_tag = nil, + remembered_flag = nil, + named = Builder.named, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(key: string): Builder + self.required_tag = key + return self +end + +function Builder:remember_flag(flag: string): Builder + self.remembered_flag = flag + return self +end + +function Builder:build(): protocol.RequestValidator + local name = self.name or "validator" + local required_tag = self.required_tag + local remembered_flag = self.remembered_flag + + return function(state: protocol.StoreState, request: protocol.DispatchRequest): protocol.ValidationResult + if required_tag then + local tags = request.envelope.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tag " .. required_tag, + retryable = false, + }, + } + end + end + + if remembered_flag then + state.flags[remembered_flag] = true + end + + return {ok = true, value = request} + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/fallback_builder.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/fallback_builder.lua new file mode 100644 index 00000000..a8992e84 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/fallback_builder.lua @@ -0,0 +1,100 @@ +local protocol = require("protocol") +local result = require("result") + +type NoteDecorator = (string, protocol.DispatchRequest, result.AppError) -> string + +type FallbackBuilder = { + plugin_name: string?, + retry_code: result.ErrorCode?, + queue_name: string?, + decorator: NoteDecorator?, + for_plugin: (self: FallbackBuilder, plugin_name: string) -> FallbackBuilder, + retry_on: (self: FallbackBuilder, code: result.ErrorCode) -> FallbackBuilder, + queue_named: (self: FallbackBuilder, queue_name: string) -> FallbackBuilder, + decorate_note: (self: FallbackBuilder, fn: NoteDecorator) -> FallbackBuilder, + build: (self: FallbackBuilder) -> protocol.FallbackHandler, +} + +type Builder = FallbackBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.FallbackBuilder = FallbackBuilder + +function M.new(): FallbackBuilder + local self: Builder = { + plugin_name = nil, + retry_code = nil, + queue_name = nil, + decorator = nil, + for_plugin = Builder.for_plugin, + retry_on = Builder.retry_on, + queue_named = Builder.queue_named, + decorate_note = Builder.decorate_note, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_plugin(plugin_name: string): Builder + self.plugin_name = plugin_name + return self +end + +function Builder:retry_on(code: result.ErrorCode): Builder + self.retry_code = code + return self +end + +function Builder:queue_named(queue_name: string): Builder + self.queue_name = queue_name + return self +end + +function Builder:decorate_note(fn: NoteDecorator): Builder + self.decorator = fn + return self +end + +function Builder:build(): protocol.FallbackHandler + local plugin_name = self.plugin_name + local retry_code = self.retry_code + local queue_name = self.queue_name or "retry" + local decorator = self.decorator + + return function( + state: protocol.StoreState, + request: protocol.DispatchRequest, + err: result.AppError, + at: time.Time + ): protocol.FallbackResult + if plugin_name and request.plugin ~= plugin_name then + return {ok = true, value = nil} + end + + if retry_code and err.code ~= retry_code then + return {ok = true, value = nil} + end + + state.flags["saw_fallback"] = true + + local note = queue_name .. ":" .. request.plugin .. ":" .. request.envelope.id + if decorator then + note = decorator(note, request, err) + end + + return { + ok = true, + value = { + queue = queue_name, + note = note, + retry_at = at, + }, + } + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/handler_builder.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/handler_builder.lua new file mode 100644 index 00000000..f29b06bd --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/handler_builder.lua @@ -0,0 +1,196 @@ +local time = require("time") +local protocol = require("protocol") +local result = require("result") + +type LabelDecorator = (string, protocol.PayloadEnvelope, protocol.StoreState) -> string + +type HandlerBuilder = { + name: string?, + prefix: string?, + required_tag: string?, + remembered_flag: string?, + failure_tag: string?, + failure_code: result.ErrorCode?, + decorator: LabelDecorator?, + named: (self: HandlerBuilder, name: string) -> HandlerBuilder, + prefix_with: (self: HandlerBuilder, prefix: string) -> HandlerBuilder, + require_tag: (self: HandlerBuilder, key: string) -> HandlerBuilder, + remember_flag: (self: HandlerBuilder, flag: string) -> HandlerBuilder, + fail_on_tag: (self: HandlerBuilder, key: string, code: result.ErrorCode) -> HandlerBuilder, + decorate: (self: HandlerBuilder, fn: LabelDecorator) -> HandlerBuilder, + build: (self: HandlerBuilder) -> protocol.PluginHandler, +} + +type Builder = HandlerBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.HandlerBuilder = HandlerBuilder + +function M.new(): HandlerBuilder + local self: Builder = { + name = nil, + prefix = nil, + required_tag = nil, + remembered_flag = nil, + failure_tag = nil, + failure_code = nil, + decorator = nil, + named = Builder.named, + prefix_with = Builder.prefix_with, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + fail_on_tag = Builder.fail_on_tag, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self +end + +function Builder:require_tag(key: string): Builder + self.required_tag = key + return self +end + +function Builder:remember_flag(flag: string): Builder + self.remembered_flag = flag + return self +end + +function Builder:fail_on_tag(key: string, code: result.ErrorCode): Builder + self.failure_tag = key + self.failure_code = code + return self +end + +function Builder:decorate(fn: LabelDecorator): Builder + self.decorator = fn + return self +end + +function Builder:build(): protocol.PluginHandler + local name = self.name or "plugin" + local prefix = self.prefix or name + local required_tag = self.required_tag + local remembered_flag = self.remembered_flag + local failure_tag = self.failure_tag + local failure_code = self.failure_code + local decorator = self.decorator + + return function( + state: protocol.StoreState, + envelope: protocol.PayloadEnvelope, + at: time.Time + ): protocol.DispatchResult + local tags = envelope.meta.tags + + if required_tag then + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tag " .. required_tag, + retryable = false, + }, + } + end + end + + if failure_tag and tags then + local value = tags[failure_tag] + if value then + return { + ok = false, + error = { + code = failure_code or "busy", + message = name .. ": tag requested retry", + retryable = true, + }, + } + end + end + + if remembered_flag then + state.flags[remembered_flag] = true + end + + local payload = envelope.payload + local receipt: protocol.OutputReceipt + + if payload.kind == "render" then + local subject = payload.values["subject"] or payload.template + local body = prefix .. ":" .. payload.template .. ":" .. subject + if decorator then + body = decorator(body, envelope, state) + end + receipt = { + plugin = name, + envelope_id = envelope.id, + output = { + kind = "rendered", + body = body, + label = prefix, + }, + emitted_at = at, + cached = false, + } + elseif payload.kind == "index" then + receipt = { + plugin = name, + envelope_id = envelope.id, + output = { + kind = "indexed", + count = #payload.terms, + }, + emitted_at = at, + cached = false, + } + else + local note = prefix .. ":" .. payload.action .. ":" .. payload.actor_id + if decorator then + note = decorator(note, envelope, state) + end + receipt = { + plugin = name, + envelope_id = envelope.id, + output = { + kind = "audited", + note = note, + retry_after = nil, + }, + emitted_at = at, + cached = false, + } + end + + return {ok = true, value = receipt} + end +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/helpers.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/helpers.lua new file mode 100644 index 00000000..5d73df61 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/helpers.lua @@ -0,0 +1,35 @@ +local protocol = require("protocol") + +local M = {} + +function M.cache_key(request: protocol.DispatchRequest): string + return request.plugin .. ":" .. request.envelope.id +end + +function M.source_tag(envelope: protocol.PayloadEnvelope): string + local tags = envelope.meta.tags + if not tags then + return "unknown" + end + + local source = tags["source"] + if not source then + return "unknown" + end + + return source +end + +function M.output_label(output: protocol.Output): string + if output.kind == "rendered" then + return output.body + end + + if output.kind == "indexed" then + return tostring(output.count) + end + + return output.note +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/main.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/main.lua new file mode 100644 index 00000000..8eedfc6a --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/main.lua @@ -0,0 +1,288 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local validator_builder = require("validator_builder") +local handler_builder = require("handler_builder") +local fallback_builder = require("fallback_builder") +local runtime = require("runtime") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_outputs: {[string]: string} = {} +local observed_fallbacks: {string} = {} +local observed_audits: {string} = {} +local last_runtime_id: string? = nil + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :remember_flag("saw_source") + :build() + +local scope_validator = validator_builder.new() + :named("scope") + :require_tag("scope") + :remember_flag("saw_scope") + :build() + +local render_handler = handler_builder.new() + :named("render") + :prefix_with("render") + :require_tag("source") + :remember_flag("did_render") + :decorate(function(label: string, envelope: protocol.PayloadEnvelope, _state: protocol.StoreState): string + if envelope.payload.kind == "render" then + return label .. ":" .. envelope.tenant_id + end + return label + end) + :build() + +local index_handler = handler_builder.new() + :named("index") + :prefix_with("index") + :require_tag("scope") + :remember_flag("did_index") + :decorate(function(label: string, envelope: protocol.PayloadEnvelope, _state: protocol.StoreState): string + if envelope.payload.kind == "index" then + return label .. ":" .. envelope.payload.document_id + end + return label + end) + :build() + +local audit_handler = handler_builder.new() + :named("audit") + :prefix_with("audit") + :remember_flag("did_audit") + :fail_on_tag("retry", "busy") + :decorate(function(label: string, envelope: protocol.PayloadEnvelope, _state: protocol.StoreState): string + if envelope.payload.kind == "audit" then + return label .. ":" .. envelope.payload.actor_id + end + return label + end) + :build() + +local retry_fallback = fallback_builder.new() + :for_plugin("audit") + :retry_on("busy") + :queue_named("audit-retry") + :decorate_note(function(note: string, request: protocol.DispatchRequest, err: result.AppError): string + return note .. ":" .. request.envelope.id .. ":" .. err.code + end) + :build() + +local app = runtime.new() + :use_validator(source_validator) + :use_validator(scope_validator) + :register_handler("render", render_handler) + :register_handler("index", index_handler) + :register_handler("audit", audit_handler) + :use_fallback(retry_fallback) + +app:on_step(function(step: protocol.RuntimeStep, state: protocol.StoreState) + last_runtime_id = state.id + + if step.kind == "dispatch" then + observed_outputs[step.plugin] = helpers.output_label(step.output) + elseif step.kind == "fallback" then + table.insert(observed_fallbacks, step.note) + local retry_seconds: integer = step.retry_at:unix() + else + table.insert(observed_audits, step.note) + local at_seconds: integer = step.at:unix() + end +end) + +local render_request: protocol.DispatchRequest = { + kind = "dispatch", + plugin = "render", + envelope = { + id = "req-1", + tenant_id = "tenant-a", + payload = { + kind = "render", + template = "welcome", + values = {subject = "hello"}, + }, + meta = protocol.meta("trace-1", {source = "api", scope = "render"}), + }, +} + +local repeat_render_request: protocol.DispatchRequest = { + kind = "dispatch", + plugin = "render", + envelope = { + id = "req-1", + tenant_id = "tenant-a", + payload = { + kind = "render", + template = "welcome", + values = {subject = "hello"}, + }, + meta = protocol.meta("trace-2", {source = "api", scope = "render"}), + }, +} + +local index_request: protocol.DispatchRequest = { + kind = "dispatch", + plugin = "index", + envelope = { + id = "req-2", + tenant_id = "tenant-a", + payload = { + kind = "index", + document_id = "doc-1", + terms = {"lua", "types", "cache"}, + }, + meta = protocol.meta("trace-3", {source = "worker", scope = "index"}), + }, +} + +local audit_request: protocol.DispatchRequest = { + kind = "dispatch", + plugin = "audit", + envelope = { + id = "req-3", + tenant_id = "tenant-b", + payload = { + kind = "audit", + action = "login", + actor_id = "actor-7", + }, + meta = protocol.meta("trace-4", {source = "worker", scope = "audit", retry = "true"}), + }, +} + +local tick_request: protocol.TickRequest = { + kind = "tick", + at = now, +} + +local requests: {protocol.Request} = { + render_request, + repeat_render_request, + index_request, + audit_request, + tick_request, +} + +local store = app:new_store("supervisor-1", now) +local summary_result = app:run(store, requests, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local runtime_id: string = summary.id + local processed: number = summary.processed + local cached_hits: number = summary.cached_hits + local fallback_count: number = summary.fallback_count + local elapsed_seconds: time.Duration = summary.elapsed_seconds + local last_output_kind: string? = summary.last_output_kind +end + +local label_result = result.map(summary_result, function(summary: protocol.RunSummary): string + return summary.id .. ":" .. tostring(summary.cached_hits + summary.fallback_count) +end) + +if label_result.ok then + local label: string = label_result.value +end + +local checked_result = result.and_then(summary_result, function(summary: protocol.RunSummary): StringResult + if summary.processed < 4 then + return { + ok = false, + error = { + code = "invalid", + message = "expected processed steps", + retryable = false, + }, + } + end + + return {ok = true, value = summary.id} +end) + +if checked_result.ok then + local checked_id: string = checked_result.value +end + +local cache_key: string = helpers.cache_key(render_request) +local cached_render = store:lookup_cached(cache_key) +if cached_render then + local seen_at: integer = cached_render.seen_at:unix() + local receipt = cached_render.value + local plugin_name: string = receipt.plugin + if receipt.output.kind == "rendered" then + local rendered: protocol.RenderOutput = receipt.output + local body: string = rendered.body + local label = rendered.label + if label then + local stable_label: string = label + end + end +end + +local render_count = store.state.plugin_counts["render"] +if render_count then + local count: integer = render_count +end + +local cached_count = store.state.plugin_counts["cached"] +if cached_count then + local count: integer = cached_count +end + +local fallback_count = store.state.plugin_counts["fallback"] +if fallback_count then + local count: integer = fallback_count +end + +local saw_source = store.state.flags["saw_source"] +if saw_source then + local flag: boolean = saw_source +end + +local saw_fallback = store.state.flags["saw_fallback"] +if saw_fallback then + local flag: boolean = saw_fallback +end + +local last_tick = store.state.last_tick +if last_tick then + local tick_seconds: integer = last_tick:unix() +end + +local last_dispatch = store.state.last_dispatch_at +if last_dispatch then + local elapsed = now:sub(last_dispatch) + local seconds: number = elapsed:seconds() +end + +local last_audit = store.state.audit_tags["last"] +if last_audit then + local note: string = last_audit +end + +for plugin_name, output in pairs(observed_outputs) do + local stable_plugin: string = plugin_name + local stable_output: string = output +end + +for _, note in ipairs(observed_fallbacks) do + local stable_note: string = note +end + +for _, note in ipairs(observed_audits) do + local stable_note: string = note +end + +if last_runtime_id then + local runtime_id: string = last_runtime_id +end diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/manifest.json b/testdata/fixtures/realworld/plugin-supervisor-runtime/manifest.json new file mode 100644 index 00000000..38fb74fa --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/manifest.json @@ -0,0 +1,18 @@ +{ + "description": "Application-shaped plugin supervisor runtime mixing generic envelopes and cached receipts, staged validators, closure-built handlers and fallback policies, metatable-backed store/runtime objects, dynamic registries, callbacks, optional maps, cache hits, and time-based summaries.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "plugin_store.lua", + "validator_builder.lua", + "handler_builder.lua", + "fallback_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/plugin_store.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/plugin_store.lua new file mode 100644 index 00000000..67515053 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/plugin_store.lua @@ -0,0 +1,98 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type PluginStore = { + state: protocol.StoreState, + cache_receipt: (self: PluginStore, request: protocol.DispatchRequest, receipt: protocol.OutputReceipt, at: time.Time) -> (), + lookup_cached: (self: PluginStore, key: string) -> protocol.CachedReceipt?, + push_step: (self: PluginStore, step: protocol.RuntimeStep, at: time.Time) -> (), + set_flag: (self: PluginStore, flag: string) -> (), + summarize: (self: PluginStore, now: time.Time, last_output_kind: string?) -> protocol.RunSummary, +} + +type Store = PluginStore + +local Store = {} +Store.__index = Store + +local M = {} +M.PluginStore = PluginStore + +function M.new(id: string, now: time.Time): PluginStore + local self: Store = { + state = { + id = id, + started_at = now, + last_tick = nil, + last_dispatch_at = nil, + cached_receipts = {}, + plugin_counts = {}, + flags = {}, + audit_tags = {}, + steps = {}, + }, + cache_receipt = Store.cache_receipt, + lookup_cached = Store.lookup_cached, + push_step = Store.push_step, + set_flag = Store.set_flag, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:cache_receipt( + request: protocol.DispatchRequest, + receipt: protocol.OutputReceipt, + at: time.Time +) + self.state.cached_receipts[helpers.cache_key(request)] = { + value = receipt, + seen_at = at, + } + self.state.last_dispatch_at = at +end + +function Store:lookup_cached(key: string): protocol.CachedReceipt? + return self.state.cached_receipts[key] +end + +function Store:push_step(step: protocol.RuntimeStep, at: time.Time) + table.insert(self.state.steps, step) + self.state.last_dispatch_at = at + + if step.kind == "dispatch" then + local current = self.state.plugin_counts[step.plugin] or 0 + self.state.plugin_counts[step.plugin] = current + 1 + elseif step.kind == "cached" then + local current = self.state.plugin_counts["cached"] or 0 + self.state.plugin_counts["cached"] = current + 1 + elseif step.kind == "fallback" then + local current = self.state.plugin_counts["fallback"] or 0 + self.state.plugin_counts["fallback"] = current + 1 + self.state.audit_tags["last"] = step.note + else + self.state.audit_tags["last"] = step.note + end +end + +function Store:set_flag(flag: string) + self.state.flags[flag] = true +end + +function Store:summarize(now: time.Time, last_output_kind: string?): protocol.RunSummary + local cached_hits = self.state.plugin_counts["cached"] or 0 + local fallback_count = self.state.plugin_counts["fallback"] or 0 + + return { + id = self.state.id, + processed = #self.state.steps, + cached_hits = cached_hits, + fallback_count = fallback_count, + elapsed_seconds = now:sub(self.state.started_at), + last_output_kind = last_output_kind, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/protocol.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/protocol.lua new file mode 100644 index 00000000..1e1cacf8 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/protocol.lua @@ -0,0 +1,200 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError +type ErrorCode = result.ErrorCode + +type RequestMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type Envelope = { + id: string, + tenant_id: string, + payload: T, + meta: RequestMeta, +} + +type RenderPayload = { + kind: "render", + template: string, + values: {[string]: string}, +} + +type IndexPayload = { + kind: "index", + document_id: string, + terms: {string}, +} + +type AuditPayload = { + kind: "audit", + action: string, + actor_id: string, +} + +type Payload = RenderPayload | IndexPayload | AuditPayload +type PayloadEnvelope = Envelope + +type DispatchRequest = { + kind: "dispatch", + plugin: string, + envelope: PayloadEnvelope, +} + +type TickRequest = { + kind: "tick", + at: time.Time, +} + +type Request = DispatchRequest | TickRequest + +type RenderOutput = { + kind: "rendered", + body: string, + label: string?, +} + +type IndexOutput = { + kind: "indexed", + count: integer, +} + +type AuditOutput = { + kind: "audited", + note: string, + retry_after: time.Time?, +} + +type Output = RenderOutput | IndexOutput | AuditOutput + +type Receipt = { + plugin: string, + envelope_id: string, + output: T, + emitted_at: time.Time, + cached: boolean, +} + +type Cached = { + value: T, + seen_at: time.Time, +} + +type OutputReceipt = Receipt +type CachedReceipt = Cached + +type DispatchStep = { + kind: "dispatch", + plugin: string, + output: Output, + cached: boolean, +} + +type CachedStep = { + kind: "cached", + plugin: string, + envelope_id: string, + at: time.Time, +} + +type FallbackStep = { + kind: "fallback", + plugin: string, + queue: string, + note: string, + retry_at: time.Time, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type RuntimeStep = DispatchStep | CachedStep | FallbackStep | AuditStep + +type FallbackPlan = { + queue: string, + note: string, + retry_at: time.Time, +} + +type StoreState = { + id: string, + started_at: time.Time, + last_tick: time.Time?, + last_dispatch_at: time.Time?, + cached_receipts: {[string]: CachedReceipt}, + plugin_counts: {[string]: integer}, + flags: {[string]: boolean}, + audit_tags: {[string]: string}, + steps: {RuntimeStep}, +} + +type RunSummary = { + id: string, + processed: number, + cached_hits: number, + fallback_count: number, + elapsed_seconds: time.Duration, + last_output_kind: string?, +} + +type ValidationResult = {ok: true, value: DispatchRequest} | {ok: false, error: AppError} +type DispatchResult = {ok: true, value: OutputReceipt?} | {ok: false, error: AppError} +type FallbackResult = {ok: true, value: FallbackPlan?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type RequestValidator = (StoreState, DispatchRequest) -> ValidationResult +type PluginHandler = (StoreState, PayloadEnvelope, time.Time) -> DispatchResult +type FallbackHandler = (StoreState, DispatchRequest, AppError, time.Time) -> FallbackResult +type StepHook = (RuntimeStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.ErrorCode = ErrorCode +M.RequestMeta = RequestMeta +M.Envelope = Envelope +M.RenderPayload = RenderPayload +M.IndexPayload = IndexPayload +M.AuditPayload = AuditPayload +M.Payload = Payload +M.PayloadEnvelope = PayloadEnvelope +M.DispatchRequest = DispatchRequest +M.TickRequest = TickRequest +M.Request = Request +M.RenderOutput = RenderOutput +M.IndexOutput = IndexOutput +M.AuditOutput = AuditOutput +M.Output = Output +M.Receipt = Receipt +M.Cached = Cached +M.OutputReceipt = OutputReceipt +M.CachedReceipt = CachedReceipt +M.DispatchStep = DispatchStep +M.CachedStep = CachedStep +M.FallbackStep = FallbackStep +M.AuditStep = AuditStep +M.RuntimeStep = RuntimeStep +M.FallbackPlan = FallbackPlan +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.DispatchResult = DispatchResult +M.FallbackResult = FallbackResult +M.RunResult = RunResult +M.RequestValidator = RequestValidator +M.PluginHandler = PluginHandler +M.FallbackHandler = FallbackHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): RequestMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/result.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/runtime.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/runtime.lua new file mode 100644 index 00000000..b250c299 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/runtime.lua @@ -0,0 +1,200 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local plugin_store = require("plugin_store") + +type SupervisorRuntime = { + validators: {protocol.RequestValidator}, + handlers: {[string]: protocol.PluginHandler}, + fallbacks: {protocol.FallbackHandler}, + hooks: {protocol.StepHook}, + use_validator: (self: SupervisorRuntime, validator: protocol.RequestValidator) -> SupervisorRuntime, + register_handler: (self: SupervisorRuntime, plugin_name: string, handler: protocol.PluginHandler) -> SupervisorRuntime, + use_fallback: (self: SupervisorRuntime, fallback: protocol.FallbackHandler) -> SupervisorRuntime, + on_step: (self: SupervisorRuntime, hook: protocol.StepHook) -> SupervisorRuntime, + new_store: (self: SupervisorRuntime, id: string, now: time.Time) -> plugin_store.PluginStore, + emit: (self: SupervisorRuntime, store: plugin_store.PluginStore, step: protocol.RuntimeStep, at: time.Time) -> (), + dispatch: (self: SupervisorRuntime, store: plugin_store.PluginStore, request: protocol.Request, at: time.Time) -> protocol.DispatchResult, + run: (self: SupervisorRuntime, store: plugin_store.PluginStore, requests: {protocol.Request}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = SupervisorRuntime + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.SupervisorRuntime = SupervisorRuntime + +function M.new(): SupervisorRuntime + local self: Runtime = { + validators = {}, + handlers = {}, + fallbacks = {}, + hooks = {}, + use_validator = Runtime.use_validator, + register_handler = Runtime.register_handler, + use_fallback = Runtime.use_fallback, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + dispatch = Runtime.dispatch, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:use_validator(validator: protocol.RequestValidator): Runtime + table.insert(self.validators, validator) + return self +end + +function Runtime:register_handler(plugin_name: string, handler: protocol.PluginHandler): Runtime + self.handlers[plugin_name] = handler + return self +end + +function Runtime:use_fallback(fallback: protocol.FallbackHandler): Runtime + table.insert(self.fallbacks, fallback) + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): plugin_store.PluginStore + return plugin_store.new(id, now) +end + +function Runtime:emit(store: plugin_store.PluginStore, step: protocol.RuntimeStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:dispatch( + store: plugin_store.PluginStore, + request: protocol.Request, + at: time.Time +): protocol.DispatchResult + if request.kind == "tick" then + store.state.last_tick = request.at + local audit_step: protocol.AuditStep = { + kind = "audit", + note = "tick", + at = request.at, + } + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + for _, validator in ipairs(self.validators) do + local validation = validator(store.state, request) + if not validation.ok then + return {ok = false, error = validation.error} + end + end + + local cached = store:lookup_cached(helpers.cache_key(request)) + if cached then + local receipt: protocol.OutputReceipt = { + plugin = cached.value.plugin, + envelope_id = cached.value.envelope_id, + output = cached.value.output, + emitted_at = cached.value.emitted_at, + cached = true, + } + local cached_step: protocol.CachedStep = { + kind = "cached", + plugin = request.plugin, + envelope_id = request.envelope.id, + at = at, + } + self:emit(store, cached_step, at) + return {ok = true, value = receipt} + end + + local handler = self.handlers[request.plugin] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing handler: " .. request.plugin, + retryable = false, + }, + } + end + + local handled = handler(store.state, request.envelope, at) + if not handled.ok then + for _, fallback in ipairs(self.fallbacks) do + local fallback_result = fallback(store.state, request, handled.error, at) + if not fallback_result.ok then + return {ok = false, error = fallback_result.error} + end + + local plan = fallback_result.value + if plan then + local fallback_step: protocol.FallbackStep = { + kind = "fallback", + plugin = request.plugin, + queue = plan.queue, + note = plan.note, + retry_at = plan.retry_at, + } + self:emit(store, fallback_step, at) + return {ok = true, value = nil} + end + end + + return {ok = false, error = handled.error} + end + + local receipt = handled.value + if not receipt then + return {ok = true, value = nil} + end + + store:cache_receipt(request, receipt, at) + + local dispatch_step: protocol.DispatchStep = { + kind = "dispatch", + plugin = request.plugin, + output = receipt.output, + cached = receipt.cached, + } + self:emit(store, dispatch_step, at) + return {ok = true, value = receipt} +end + +function Runtime:run( + store: plugin_store.PluginStore, + requests: {protocol.Request}, + now: time.Time +): protocol.RunResult + local last_output_kind: string? = nil + + for _, request in ipairs(requests) do + local dispatch_result = self:dispatch(store, request, now) + if not dispatch_result.ok then + return {ok = false, error = dispatch_result.error} + end + + local receipt = dispatch_result.value + if receipt then + last_output_kind = receipt.output.kind + end + end + + return { + ok = true, + value = store:summarize(now, last_output_kind), + } +end + +return M diff --git a/testdata/fixtures/realworld/plugin-supervisor-runtime/validator_builder.lua b/testdata/fixtures/realworld/plugin-supervisor-runtime/validator_builder.lua new file mode 100644 index 00000000..4c603c84 --- /dev/null +++ b/testdata/fixtures/realworld/plugin-supervisor-runtime/validator_builder.lua @@ -0,0 +1,90 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string?, + required_tag: string?, + remembered_flag: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, key: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.RequestValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ValidatorBuilder = ValidatorBuilder + +function M.new(): ValidatorBuilder + local self: Builder = { + name = nil, + required_tag = nil, + remembered_flag = nil, + named = Builder.named, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(key: string): Builder + self.required_tag = key + return self +end + +function Builder:remember_flag(flag: string): Builder + self.remembered_flag = flag + return self +end + +function Builder:build(): protocol.RequestValidator + local name = self.name or "validator" + local required_tag = self.required_tag + local remembered_flag = self.remembered_flag + + return function(state: protocol.StoreState, request: protocol.DispatchRequest): protocol.ValidationResult + if required_tag then + local tags = request.envelope.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. ": missing tag " .. required_tag, + retryable = false, + }, + } + end + end + + if remembered_flag then + state.flags[remembered_flag] = true + end + + return {ok = true, value = request} + end +end + +return M diff --git a/testdata/fixtures/realworld/recursive-alias-array-index/main.lua b/testdata/fixtures/realworld/recursive-alias-array-index/main.lua new file mode 100644 index 00000000..67a65109 --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-array-index/main.lua @@ -0,0 +1,15 @@ +type Message = { + _topic: string, + topic: (self: Message) -> string, +} + +local function make(): {Message} + return {{ + _topic = "indexed", + topic = function(self: Message): string + return self._topic + end, + }} +end + +local topic: string = make()[1]:topic() diff --git a/testdata/fixtures/realworld/recursive-alias-array-index/manifest.json b/testdata/fixtures/realworld/recursive-alias-array-index/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-array-index/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/recursive-alias-map-index/main.lua b/testdata/fixtures/realworld/recursive-alias-map-index/main.lua new file mode 100644 index 00000000..5607c527 --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-map-index/main.lua @@ -0,0 +1,20 @@ +type Message = { + _topic: string, + topic: (self: Message) -> string, +} + +local function make(): {[string]: Message} + return { + ["root"] = { + _topic = "mapped", + topic = function(self: Message): string + return self._topic + end, + }, + } +end + +local root = make()["root"] +if root then + local topic: string = root:topic() +end diff --git a/testdata/fixtures/realworld/recursive-alias-map-index/manifest.json b/testdata/fixtures/realworld/recursive-alias-map-index/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-map-index/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/recursive-alias-method-chain/main.lua b/testdata/fixtures/realworld/recursive-alias-method-chain/main.lua new file mode 100644 index 00000000..0a4311ae --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-method-chain/main.lua @@ -0,0 +1,15 @@ +type Builder = { + f: (self: Builder) -> Builder, + g: (self: Builder) -> number, +} + +local b: Builder = { + f = function(self: Builder): Builder + return self + end, + g = function(self: Builder): number + return 1 + end, +} + +local n: number = b:f():g() diff --git a/testdata/fixtures/realworld/recursive-alias-method-chain/manifest.json b/testdata/fixtures/realworld/recursive-alias-method-chain/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-method-chain/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/recursive-alias-module-return/builder.lua b/testdata/fixtures/realworld/recursive-alias-module-return/builder.lua new file mode 100644 index 00000000..228ca77f --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-module-return/builder.lua @@ -0,0 +1,23 @@ +type Message = { + _topic: string, + topic: (self: Message) -> string, +} + +type Result = { + message: Message, +} + +local M = {} + +function M.make(): Result + return { + message = { + _topic = "exported", + topic = function(self: Message): string + return self._topic + end, + }, + } +end + +return M diff --git a/testdata/fixtures/realworld/recursive-alias-module-return/main.lua b/testdata/fixtures/realworld/recursive-alias-module-return/main.lua new file mode 100644 index 00000000..2bd43714 --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-module-return/main.lua @@ -0,0 +1,3 @@ +local builder = require("builder") + +local topic: string = builder.make().message:topic() diff --git a/testdata/fixtures/realworld/recursive-alias-module-return/manifest.json b/testdata/fixtures/realworld/recursive-alias-module-return/manifest.json new file mode 100644 index 00000000..a6474a2b --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-module-return/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["builder.lua", "main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/recursive-alias-nested-field/main.lua b/testdata/fixtures/realworld/recursive-alias-nested-field/main.lua new file mode 100644 index 00000000..b2c21509 --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-nested-field/main.lua @@ -0,0 +1,21 @@ +type Message = { + _topic: string, + topic: (self: Message) -> string, +} + +type Result = { + message: Message, +} + +local function make(): Result + return { + message = { + _topic = "test", + topic = function(self: Message): string + return self._topic + end, + }, + } +end + +local topic: string = make().message:topic() diff --git a/testdata/fixtures/realworld/recursive-alias-nested-field/manifest.json b/testdata/fixtures/realworld/recursive-alias-nested-field/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-nested-field/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/recursive-alias-optional-narrowing/main.lua b/testdata/fixtures/realworld/recursive-alias-optional-narrowing/main.lua new file mode 100644 index 00000000..94e22de6 --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-optional-narrowing/main.lua @@ -0,0 +1,27 @@ +type Message = { + _topic: string, + topic: (self: Message) -> string, +} + +type Result = { + message: Message?, +} + +local function make(ok: boolean): Result + if ok then + return { + message = { + _topic = "optional", + topic = function(self: Message): string + return self._topic + end, + }, + } + end + return {message = nil} +end + +local result = make(true) +if result.message then + local topic: string = result.message:topic() +end diff --git a/testdata/fixtures/realworld/recursive-alias-optional-narrowing/manifest.json b/testdata/fixtures/realworld/recursive-alias-optional-narrowing/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/realworld/recursive-alias-optional-narrowing/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/realworld/result-type-narrowing/manifest.json b/testdata/fixtures/realworld/result-type-narrowing/manifest.json index 6d396875..90eea4be 100644 --- a/testdata/fixtures/realworld/result-type-narrowing/manifest.json +++ b/testdata/fixtures/realworld/result-type-narrowing/manifest.json @@ -1,4 +1,4 @@ { "files": ["result.lua", "repo.lua", "service.lua", "main.lua"], - "check": {"errors": 4} + "check": {"errors": 0} } diff --git a/testdata/fixtures/realworld/service-locator/locator.lua b/testdata/fixtures/realworld/service-locator/locator.lua index 82f11484..ea125726 100644 --- a/testdata/fixtures/realworld/service-locator/locator.lua +++ b/testdata/fixtures/realworld/service-locator/locator.lua @@ -2,15 +2,15 @@ local logger = require("logger") local cache = require("cache") type Services = { - logger: Logger, - cache: Cache, + logger: logger.Logger, + cache: cache.Cache, } local M = {} local _services: Services? = nil -function M.init(log_level: LogLevel?): Services +function M.init(log_level: logger.LogLevel?): Services local s: Services = { logger = logger.new(log_level), cache = cache.new(), @@ -26,11 +26,11 @@ function M.get(): Services return _services end -function M.logger(): Logger +function M.logger(): logger.Logger return M.get().logger end -function M.cache(): Cache +function M.cache(): cache.Cache return M.get().cache end diff --git a/testdata/fixtures/realworld/sql-repository/manifest.json b/testdata/fixtures/realworld/sql-repository/manifest.json index 8179b4af..1a877772 100644 --- a/testdata/fixtures/realworld/sql-repository/manifest.json +++ b/testdata/fixtures/realworld/sql-repository/manifest.json @@ -1 +1 @@ -{"files": ["db.lua", "repository.lua", "main.lua"], "check": {"errors": 4}} +{"files": ["db.lua", "repository.lua", "main.lua"], "check": {"errors": 0}} diff --git a/testdata/fixtures/realworld/sql-repository/repository.lua b/testdata/fixtures/realworld/sql-repository/repository.lua index 0cc71dc8..3fedc990 100644 --- a/testdata/fixtures/realworld/sql-repository/repository.lua +++ b/testdata/fixtures/realworld/sql-repository/repository.lua @@ -32,7 +32,7 @@ M.table_exists_queries = { mysql = [[SELECT COUNT(*) AS count FROM information_schema.tables WHERE table_name = '_migrations']], } -function M.table_exists(database: Database): (boolean?, string?) +function M.table_exists(database: db.Database): (boolean?, string?) local db_type, err = database:type() if err then return nil, "Failed to determine database type: " .. tostring(err) @@ -51,7 +51,7 @@ function M.table_exists(database: Database): (boolean?, string?) return false, nil end -function M.init(database: Database): (boolean, string?) +function M.init(database: db.Database): (boolean, string?) local exists, err = M.table_exists(database) if err then return false, err end if exists then return true, nil end @@ -66,14 +66,14 @@ function M.init(database: Database): (boolean, string?) return database:execute(schema) end -function M.record(database: Database, id: string, description: string?): (boolean, string?) +function M.record(database: db.Database, id: string, description: string?): (boolean, string?) return database:execute( "INSERT INTO _migrations (id, description) VALUES (?, ?)", {id, description or ""} ) end -function M.is_applied(database: Database, id: string): (boolean, string?) +function M.is_applied(database: db.Database, id: string): (boolean, string?) local result, err = database:query( "SELECT id FROM _migrations WHERE id = ?", {id} diff --git a/testdata/fixtures/realworld/table-builder-pattern/manifest.json b/testdata/fixtures/realworld/table-builder-pattern/manifest.json index 325eda3c..7d1adba5 100644 --- a/testdata/fixtures/realworld/table-builder-pattern/manifest.json +++ b/testdata/fixtures/realworld/table-builder-pattern/manifest.json @@ -1,4 +1,4 @@ { "files": ["config.lua", "main.lua"], - "check": {"errors": 3} + "check": {"errors": 0} } diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/helpers.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/helpers.lua new file mode 100644 index 00000000..526d6b52 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/helpers.lua @@ -0,0 +1,49 @@ +local protocol = require("protocol") + +local M = {} + +function M.request_label(request: protocol.Request): string + if request.kind == "auth" then + return "auth:" .. request.scope + end + if request.kind == "query" then + return "query:" .. request.resource + end + if request.kind == "update" then + return "update:" .. request.resource + end + return "tick" +end + +function M.tag_value(request: protocol.Request, key: string): string? + if request.kind == "tick" then + return nil + end + + local tags = request.meta.tags + if not tags then + return nil + end + return tags[key] +end + +function M.decision_note(decision: protocol.Decision): string + if decision.kind == "allow" then + return "allow:" .. decision.reason + end + if decision.kind == "deny" then + return "deny:" .. decision.reason + end + return "defer:" .. decision.queue +end + +function M.bump_counter(counters: {[string]: integer}, key: string) + local current = counters[key] + if current then + counters[key] = current + 1 + return + end + counters[key] = 1 +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/main.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/main.lua new file mode 100644 index 00000000..9634f9dd --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/main.lua @@ -0,0 +1,66 @@ +local time = require("time") +local protocol = require("protocol") +local validator_builder = require("validator_builder") +local rule_builder = require("rule_builder") +local runtime = require("runtime") + +local now = time.now() + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :build() + +local auth_rule = rule_builder.new() + :named("auth") + :for_kind("auth") + :cache_with("auth") + :build() + +local app = runtime.new() + :use_validator(source_validator) + :register_evaluator("auth", auth_rule) + +local store = app:new_store("policy-soundness", now) + +store:save_policy({ + tenant_id = "tenant-a", + scopes = {["policy.read"] = true}, + allowed_resources = {}, + fallback_queue = nil, + tags = nil, + last_checked = nil, +}) + +local auth_one: protocol.AuthRequest = { + kind = "auth", + tenant_id = "tenant-a", + actor_id = "actor-1", + scope = "policy.read", + meta = protocol.meta("trace-1", {source = "api"}), +} + +local tick: protocol.TickRequest = { + kind = "tick", + at = now, +} + +local run_result = app:run(store, {auth_one, tick}, now) + +if run_result.ok then + local last_kind: string = run_result.value.last_kind -- expect-error +end + +local missing_policy: protocol.TenantPolicy = store:lookup_policy("missing") -- expect-error +local missing_decision: protocol.Decision = store.state.cached_decisions["missing"] -- expect-error +local evaluator: protocol.PolicyEvaluator = app.evaluators["auth"] -- expect-error + +local elapsed = now:sub(store.state.last_eval_at) -- expect-error +local source: string = protocol.meta("trace-2", nil).tags["source"] -- expect-error + +local cached = store:lookup_decision("auth:tenant-a:actor-1") +if cached then + local queue: string = cached.queue -- expect-error +end + +local missing_status: string = store:lookup_policy("tenant-a").tags["source"] -- expect-error diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/manifest.json b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/manifest.json new file mode 100644 index 00000000..0e3009c0 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Soundness counterpart for tenant policy runtime, pinning unsafe optional cache, evaluator registry, tag, union-field, and time uses in the same multi-module application shape.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "policy_store.lua", + "validator_builder.lua", + "rule_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 8 + } +} diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/policy_store.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/policy_store.lua new file mode 100644 index 00000000..c19ffd2e --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/policy_store.lua @@ -0,0 +1,109 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type PolicyStore = { + state: protocol.StoreState, + save_policy: (self: PolicyStore, policy: protocol.TenantPolicy) -> (), + lookup_policy: (self: PolicyStore, tenant_id: string) -> protocol.TenantPolicy?, + lookup_decision: (self: PolicyStore, cache_key: string) -> protocol.Decision?, + push_step: (self: PolicyStore, step: protocol.PolicyStep, at: time.Time) -> (), + cache_decision: (self: PolicyStore, request: protocol.Request, decision: protocol.Decision, at: time.Time) -> (), + summarize: (self: PolicyStore, now: time.Time, last_kind: string?) -> protocol.RunSummary, +} + +type Store = PolicyStore + +local Store = {} +Store.__index = Store + +local M = {} +M.PolicyStore = PolicyStore + +function M.new(id: string, now: time.Time): PolicyStore + local self: Store = { + state = { + id = id, + started_at = now, + last_eval_at = nil, + policies = {}, + cached_decisions = {}, + steps = {}, + counters = {}, + flags = {}, + }, + save_policy = Store.save_policy, + lookup_policy = Store.lookup_policy, + lookup_decision = Store.lookup_decision, + push_step = Store.push_step, + cache_decision = Store.cache_decision, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:save_policy(policy: protocol.TenantPolicy) + self.state.policies[policy.tenant_id] = policy +end + +function Store:lookup_policy(tenant_id: string): protocol.TenantPolicy? + return self.state.policies[tenant_id] +end + +function Store:lookup_decision(cache_key: string): protocol.Decision? + return self.state.cached_decisions[cache_key] +end + +function Store:push_step(step: protocol.PolicyStep, at: time.Time) + table.insert(self.state.steps, step) + self.state.last_eval_at = at +end + +function Store:cache_decision(request: protocol.Request, decision: protocol.Decision, at: time.Time) + if request.kind == "tick" then + return + end + + local current = self.state.policies[request.tenant_id] + if current then + local updated_policy: protocol.TenantPolicy = { + tenant_id = current.tenant_id, + scopes = current.scopes, + allowed_resources = current.allowed_resources, + fallback_queue = current.fallback_queue, + tags = current.tags, + last_checked = at, + } + self.state.policies[current.tenant_id] = updated_policy + end + + if decision.kind == "allow" then + local key = decision.cache_key + if key then + self.state.cached_decisions[key] = decision + end + elseif decision.kind == "defer" then + self.state.cached_decisions[request.tenant_id .. ":defer:" .. request.kind] = decision + end + + helpers.bump_counter(self.state.counters, decision.kind) +end + +function Store:summarize(now: time.Time, last_kind: string?): protocol.RunSummary + local allowed_count = self.state.counters["allow"] or 0 + local denied_count = self.state.counters["deny"] or 0 + local deferred_count = self.state.counters["defer"] or 0 + + return { + id = self.state.id, + total_processed = #self.state.steps, + allowed_count = allowed_count, + denied_count = denied_count, + deferred_count = deferred_count, + elapsed_seconds = now:sub(self.state.started_at), + last_kind = last_kind, + } +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/protocol.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/protocol.lua new file mode 100644 index 00000000..ded174f1 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/protocol.lua @@ -0,0 +1,158 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type PolicyMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type AuthRequest = { + kind: "auth", + tenant_id: string, + actor_id: string, + scope: string, + meta: PolicyMeta, +} + +type QueryRequest = { + kind: "query", + tenant_id: string, + actor_id: string, + resource: string, + meta: PolicyMeta, +} + +type UpdateRequest = { + kind: "update", + tenant_id: string, + actor_id: string, + resource: string, + change_set: {[string]: string}, + meta: PolicyMeta, +} + +type TickRequest = { + kind: "tick", + at: time.Time, +} + +type Request = AuthRequest | QueryRequest | UpdateRequest | TickRequest + +type AllowDecision = { + kind: "allow", + reason: string, + cache_key: string?, + expires_at: time.Time?, +} + +type DenyDecision = { + kind: "deny", + reason: string, + retryable: boolean, +} + +type DeferDecision = { + kind: "defer", + queue: string, + retry_at: time.Time, +} + +type Decision = AllowDecision | DenyDecision | DeferDecision + +type TenantPolicy = { + tenant_id: string, + scopes: {[string]: boolean}, + allowed_resources: {[string]: boolean}, + fallback_queue: string?, + tags: {[string]: string}?, + last_checked: time.Time?, +} + +type DecisionStep = { + kind: "decision", + request_kind: "auth" | "query" | "update", + tenant_id: string, + note: string, +} + +type DeferStep = { + kind: "defer", + tenant_id: string, + queue: string, + note: string, + retry_at: time.Time, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type PolicyStep = DecisionStep | DeferStep | AuditStep + +type StoreState = { + id: string, + started_at: time.Time, + last_eval_at: time.Time?, + policies: {[string]: TenantPolicy}, + cached_decisions: {[string]: Decision}, + steps: {PolicyStep}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_processed: number, + allowed_count: number, + denied_count: number, + deferred_count: number, + elapsed_seconds: time.Duration, + last_kind: string?, +} + +type ValidationResult = {ok: true, value: Request} | {ok: false, error: AppError} +type EvaluateResult = {ok: true, value: Decision?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type RequestValidator = (StoreState, Request) -> ValidationResult +type PolicyEvaluator = (StoreState, Request, time.Time) -> EvaluateResult +type StepHook = (PolicyStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.PolicyMeta = PolicyMeta +M.AuthRequest = AuthRequest +M.QueryRequest = QueryRequest +M.UpdateRequest = UpdateRequest +M.TickRequest = TickRequest +M.Request = Request +M.AllowDecision = AllowDecision +M.DenyDecision = DenyDecision +M.DeferDecision = DeferDecision +M.Decision = Decision +M.TenantPolicy = TenantPolicy +M.DecisionStep = DecisionStep +M.DeferStep = DeferStep +M.AuditStep = AuditStep +M.PolicyStep = PolicyStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.EvaluateResult = EvaluateResult +M.RunResult = RunResult +M.RequestValidator = RequestValidator +M.PolicyEvaluator = PolicyEvaluator +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): PolicyMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/result.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/rule_builder.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/rule_builder.lua new file mode 100644 index 00000000..7884ae63 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/rule_builder.lua @@ -0,0 +1,191 @@ +local protocol = require("protocol") +local helpers = require("helpers") + +type RuleDecorator = (string, protocol.TenantPolicy, protocol.Request) -> string + +type RuleBuilder = { + name: string, + request_kind: "auth" | "query" | "update", + required_scope: string?, + fallback_queue: string?, + cache_prefix: string?, + decorator: RuleDecorator?, + named: (self: RuleBuilder, name: string) -> RuleBuilder, + for_kind: (self: RuleBuilder, request_kind: "auth" | "query" | "update") -> RuleBuilder, + require_scope: (self: RuleBuilder, scope: string) -> RuleBuilder, + fallback_to: (self: RuleBuilder, queue: string) -> RuleBuilder, + cache_with: (self: RuleBuilder, prefix: string) -> RuleBuilder, + decorate: (self: RuleBuilder, decorator: RuleDecorator) -> RuleBuilder, + build: (self: RuleBuilder) -> protocol.PolicyEvaluator, +} + +type Builder = RuleBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): RuleBuilder + local self: Builder = { + name = "rule", + request_kind = "auth", + required_scope = nil, + fallback_queue = nil, + cache_prefix = nil, + decorator = nil, + named = Builder.named, + for_kind = Builder.for_kind, + require_scope = Builder.require_scope, + fallback_to = Builder.fallback_to, + cache_with = Builder.cache_with, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:for_kind(request_kind: "auth" | "query" | "update"): Builder + self.request_kind = request_kind + return self +end + +function Builder:require_scope(scope: string): Builder + self.required_scope = scope + return self +end + +function Builder:fallback_to(queue: string): Builder + self.fallback_queue = queue + return self +end + +function Builder:cache_with(prefix: string): Builder + self.cache_prefix = prefix + return self +end + +function Builder:decorate(decorator: RuleDecorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.PolicyEvaluator + local name = self.name + local request_kind = self.request_kind + local required_scope = self.required_scope + local fallback_queue = self.fallback_queue + local cache_prefix = self.cache_prefix + local decorator = self.decorator + + return function(state: protocol.StoreState, request: protocol.Request, at: time.Time): protocol.EvaluateResult + if request.kind == "tick" then + return { + ok = false, + error = { + code = "invalid", + message = name .. " cannot evaluate ticks", + retryable = false, + }, + } + end + + if request.kind ~= request_kind then + return { + ok = false, + error = { + code = "invalid", + message = name .. " wrong request kind: " .. helpers.request_label(request), + retryable = false, + }, + } + end + + local policy = state.policies[request.tenant_id] + if not policy then + return { + ok = false, + error = { + code = "not_found", + message = name .. " missing tenant policy", + retryable = false, + }, + } + end + + local reason = name .. ":" .. request.tenant_id + if decorator then + reason = decorator(reason, policy, request) + end + + if request.kind == "auth" then + local scope_key = required_scope or request.scope + if policy.scopes[scope_key] then + local cache_key: string? = nil + if cache_prefix then + cache_key = cache_prefix .. ":" .. request.tenant_id .. ":" .. request.actor_id + end + local allow: protocol.AllowDecision = { + kind = "allow", + reason = reason, + cache_key = cache_key, + expires_at = at, + } + return {ok = true, value = allow} + end + + local deny: protocol.DenyDecision = { + kind = "deny", + reason = reason, + retryable = false, + } + return {ok = true, value = deny} + end + + local resource: string + if request.kind == "query" then + resource = request.resource + else + resource = request.resource + end + + if policy.allowed_resources[resource] then + local cache_key: string? = nil + if cache_prefix then + cache_key = cache_prefix .. ":" .. request.tenant_id .. ":" .. resource + end + local allow: protocol.AllowDecision = { + kind = "allow", + reason = reason, + cache_key = cache_key, + expires_at = at, + } + return {ok = true, value = allow} + end + + local queue = fallback_queue or policy.fallback_queue + if queue then + local defer: protocol.DeferDecision = { + kind = "defer", + queue = queue, + retry_at = at, + } + return {ok = true, value = defer} + end + + local deny: protocol.DenyDecision = { + kind = "deny", + reason = reason, + retryable = false, + } + return {ok = true, value = deny} + end +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/runtime.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/runtime.lua new file mode 100644 index 00000000..99ea7701 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/runtime.lua @@ -0,0 +1,157 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local policy_store = require("policy_store") + +type PolicyRuntime = { + validators: {protocol.RequestValidator}, + evaluators: {[string]: protocol.PolicyEvaluator}, + hooks: {protocol.StepHook}, + use_validator: (self: PolicyRuntime, validator: protocol.RequestValidator) -> PolicyRuntime, + register_evaluator: (self: PolicyRuntime, request_kind: string, evaluator: protocol.PolicyEvaluator) -> PolicyRuntime, + on_step: (self: PolicyRuntime, hook: protocol.StepHook) -> PolicyRuntime, + new_store: (self: PolicyRuntime, id: string, now: time.Time) -> policy_store.PolicyStore, + emit: (self: PolicyRuntime, store: policy_store.PolicyStore, step: protocol.PolicyStep, at: time.Time) -> (), + evaluate: (self: PolicyRuntime, store: policy_store.PolicyStore, request: protocol.Request, at: time.Time) -> protocol.EvaluateResult, + run: (self: PolicyRuntime, store: policy_store.PolicyStore, requests: {protocol.Request}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = PolicyRuntime + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.PolicyRuntime = PolicyRuntime + +function M.new(): PolicyRuntime + local self: Runtime = { + validators = {}, + evaluators = {}, + hooks = {}, + use_validator = Runtime.use_validator, + register_evaluator = Runtime.register_evaluator, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + evaluate = Runtime.evaluate, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:use_validator(validator: protocol.RequestValidator): Runtime + table.insert(self.validators, validator) + return self +end + +function Runtime:register_evaluator(request_kind: string, evaluator: protocol.PolicyEvaluator): Runtime + self.evaluators[request_kind] = evaluator + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): policy_store.PolicyStore + return policy_store.new(id, now) +end + +function Runtime:emit(store: policy_store.PolicyStore, step: protocol.PolicyStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:evaluate( + store: policy_store.PolicyStore, + request: protocol.Request, + at: time.Time +): protocol.EvaluateResult + if request.kind == "tick" then + local audit_step: protocol.AuditStep = {kind = "audit", note = "tick", at = request.at} + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + for _, validator in ipairs(self.validators) do + local validation: protocol.ValidationResult = validator(store.state, request) + if not validation.ok then + return {ok = false, error = validation.error} + end + end + + local evaluator = self.evaluators[request.kind] + if not evaluator then + return { + ok = false, + error = { + code = "not_found", + message = "missing evaluator: " .. request.kind, + retryable = false, + }, + } + end + + local evaluation = evaluator(store.state, request, at) + if not evaluation.ok then + return {ok = false, error = evaluation.error} + end + + local decision = evaluation.value + if not decision then + return {ok = true, value = nil} + end + + store:cache_decision(request, decision, at) + + if decision.kind == "defer" then + local defer_decision: protocol.DeferDecision = decision + local defer_step: protocol.DeferStep = { + kind = "defer", + tenant_id = request.tenant_id, + queue = defer_decision.queue, + note = helpers.decision_note(defer_decision), + retry_at = defer_decision.retry_at, + } + self:emit(store, defer_step, at) + return {ok = true, value = defer_decision} + end + + local decision_step: protocol.DecisionStep = { + kind = "decision", + request_kind = request.kind, + tenant_id = request.tenant_id, + note = helpers.decision_note(decision), + } + self:emit(store, decision_step, at) + return {ok = true, value = decision} +end + +function Runtime:run( + store: policy_store.PolicyStore, + requests: {protocol.Request}, + now: time.Time +): protocol.RunResult + local last_kind: string? = nil + + for _, request in ipairs(requests) do + local evaluation: protocol.EvaluateResult = self:evaluate(store, request, now) + if not evaluation.ok then + return {ok = false, error = evaluation.error} + end + + local decision = evaluation.value + if decision then + last_kind = decision.kind + end + end + + return {ok = true, value = store:summarize(now, last_kind)} +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime-soundness/validator_builder.lua b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/validator_builder.lua new file mode 100644 index 00000000..1c71ae26 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime-soundness/validator_builder.lua @@ -0,0 +1,131 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string, + required_tag: string?, + required_scope: string?, + required_resource_prefix: string?, + flag_name: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, tag: string) -> ValidatorBuilder, + require_scope: (self: ValidatorBuilder, scope: string) -> ValidatorBuilder, + require_resource_prefix: (self: ValidatorBuilder, prefix: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag_name: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.RequestValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): ValidatorBuilder + local self: Builder = { + name = "validator", + required_tag = nil, + required_scope = nil, + required_resource_prefix = nil, + flag_name = nil, + named = Builder.named, + require_tag = Builder.require_tag, + require_scope = Builder.require_scope, + require_resource_prefix = Builder.require_resource_prefix, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(tag: string): Builder + self.required_tag = tag + return self +end + +function Builder:require_scope(scope: string): Builder + self.required_scope = scope + return self +end + +function Builder:require_resource_prefix(prefix: string): Builder + self.required_resource_prefix = prefix + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:build(): protocol.RequestValidator + local name = self.name + local required_tag = self.required_tag + local required_scope = self.required_scope + local required_resource_prefix = self.required_resource_prefix + local flag_name = self.flag_name + + return function(state: protocol.StoreState, request: protocol.Request): protocol.ValidationResult + if request.kind == "tick" then + return {ok = true, value = request} + end + + if required_tag then + local tags = request.meta.tags + if not tags or not tags[required_tag] then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + if required_scope and request.kind == "auth" and request.scope ~= required_scope then + return { + ok = false, + error = { + code = "invalid", + message = name .. " wrong scope: " .. request.scope, + retryable = false, + }, + } + end + + if required_resource_prefix and request.kind ~= "auth" then + local resource: string + if request.kind == "query" then + resource = request.resource + else + resource = request.resource + end + if string.sub(resource, 1, #required_resource_prefix) ~= required_resource_prefix then + return { + ok = false, + error = { + code = "invalid", + message = name .. " wrong resource prefix", + retryable = false, + }, + } + end + end + + if flag_name then + state.flags[flag_name] = true + end + + return {ok = true, value = request} + end +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/helpers.lua b/testdata/fixtures/realworld/tenant-policy-runtime/helpers.lua new file mode 100644 index 00000000..526d6b52 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/helpers.lua @@ -0,0 +1,49 @@ +local protocol = require("protocol") + +local M = {} + +function M.request_label(request: protocol.Request): string + if request.kind == "auth" then + return "auth:" .. request.scope + end + if request.kind == "query" then + return "query:" .. request.resource + end + if request.kind == "update" then + return "update:" .. request.resource + end + return "tick" +end + +function M.tag_value(request: protocol.Request, key: string): string? + if request.kind == "tick" then + return nil + end + + local tags = request.meta.tags + if not tags then + return nil + end + return tags[key] +end + +function M.decision_note(decision: protocol.Decision): string + if decision.kind == "allow" then + return "allow:" .. decision.reason + end + if decision.kind == "deny" then + return "deny:" .. decision.reason + end + return "defer:" .. decision.queue +end + +function M.bump_counter(counters: {[string]: integer}, key: string) + local current = counters[key] + if current then + counters[key] = current + 1 + return + end + counters[key] = 1 +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/main.lua b/testdata/fixtures/realworld/tenant-policy-runtime/main.lua new file mode 100644 index 00000000..849206f3 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/main.lua @@ -0,0 +1,268 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local validator_builder = require("validator_builder") +local rule_builder = require("rule_builder") +local runtime = require("runtime") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_decisions: {[string]: string} = {} +local observed_defers: {string} = {} +local observed_audits: {string} = {} +local last_runtime_id: string? = nil + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :remember_flag("saw_source") + :build() + +local auth_validator = validator_builder.new() + :named("auth") + :require_scope("policy.read") + :build() + +local resource_validator = validator_builder.new() + :named("resource") + :require_resource_prefix("doc/") + :build() + +local auth_rule = rule_builder.new() + :named("auth") + :for_kind("auth") + :cache_with("auth") + :decorate(function(reason: string, _policy: protocol.TenantPolicy, request: protocol.Request): string + if request.kind == "auth" then + return reason .. ":" .. request.actor_id + end + return reason + end) + :build() + +local query_rule = rule_builder.new() + :named("query") + :for_kind("query") + :fallback_to("policy-review") + :cache_with("query") + :decorate(function(reason: string, _policy: protocol.TenantPolicy, request: protocol.Request): string + if request.kind == "query" then + return reason .. ":" .. request.resource + end + return reason + end) + :build() + +local update_rule = rule_builder.new() + :named("update") + :for_kind("update") + :fallback_to("manual-approval") + :cache_with("update") + :decorate(function(reason: string, _policy: protocol.TenantPolicy, request: protocol.Request): string + if request.kind == "update" then + return reason .. ":" .. request.resource + end + return reason + end) + :build() + +local app = runtime.new() + :use_validator(source_validator) + :use_validator(auth_validator) + :use_validator(resource_validator) + :register_evaluator("auth", auth_rule) + :register_evaluator("query", query_rule) + :register_evaluator("update", update_rule) + +app:on_step(function(step: protocol.PolicyStep, state: protocol.StoreState) + last_runtime_id = state.id + if step.kind == "decision" then + observed_decisions[step.tenant_id .. ":" .. step.request_kind] = step.note + elseif step.kind == "defer" then + table.insert(observed_defers, step.note) + local retry_seconds: integer = step.retry_at:unix() + local queue: string = step.queue + else + table.insert(observed_audits, step.note) + local at_seconds: integer = step.at:unix() + end +end) + +local store = app:new_store("policy-1", now) + +store:save_policy({ + tenant_id = "tenant-a", + scopes = {["policy.read"] = true}, + allowed_resources = {["doc/alpha"] = true, ["doc/beta"] = true}, + fallback_queue = "policy-review", + tags = {source = "api"}, + last_checked = nil, +}) + +store:save_policy({ + tenant_id = "tenant-b", + scopes = {["policy.read"] = true}, + allowed_resources = {}, + fallback_queue = "manual-approval", + tags = {source = "worker"}, + last_checked = nil, +}) + +local auth_one: protocol.AuthRequest = { + kind = "auth", + tenant_id = "tenant-a", + actor_id = "actor-1", + scope = "policy.read", + meta = protocol.meta("trace-1", {source = "api"}), +} + +local query_one: protocol.QueryRequest = { + kind = "query", + tenant_id = "tenant-a", + actor_id = "actor-1", + resource = "doc/alpha", + meta = protocol.meta("trace-2", {source = "api", priority = "high"}), +} + +local update_two: protocol.UpdateRequest = { + kind = "update", + tenant_id = "tenant-b", + actor_id = "actor-2", + resource = "doc/review", + change_set = {status = "submitted"}, + meta = protocol.meta("trace-3", {source = "worker"}), +} + +local tick: protocol.TickRequest = { + kind = "tick", + at = now, +} + +local requests: {protocol.Request} = { + auth_one, + query_one, + update_two, + tick, +} + +local summary_result = app:run(store, requests, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local runtime_id: string = summary.id + local total_processed: number = summary.total_processed + local allowed_count: number = summary.allowed_count + local denied_count: number = summary.denied_count + local deferred_count: number = summary.deferred_count + local elapsed_seconds: time.Duration = summary.elapsed_seconds + local last_kind: string? = summary.last_kind +end + +local summary_label = result.map(summary_result, function(summary: protocol.RunSummary): string + return summary.id .. ":" .. tostring(summary.allowed_count + summary.deferred_count) +end) + +if summary_label.ok then + local label: string = summary_label.value +end + +local checked_result = result.and_then(summary_result, function(summary: protocol.RunSummary): StringResult + if summary.deferred_count == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected deferred decision", + retryable = false, + }, + } + end + return {ok = true, value = summary.id} +end) + +if checked_result.ok then + local checked_id: string = checked_result.value +end + +local policy_one = store:lookup_policy("tenant-a") +if policy_one then + local tenant_id: string = policy_one.tenant_id + local tags = policy_one.tags + if tags then + local origin = tags["source"] + if origin then + local source: string = origin + end + end + local last_checked = policy_one.last_checked + if last_checked then + local checked_at: integer = last_checked:unix() + end +end + +local cached_auth = store:lookup_decision("auth:tenant-a:actor-1") +if cached_auth then + if cached_auth.kind == "allow" then + local allow_decision: protocol.AllowDecision = cached_auth + local reason: string = allow_decision.reason + local cache_key = allow_decision.cache_key + if cache_key then + local key: string = cache_key + end + local expires_at = allow_decision.expires_at + if expires_at then + local expires_seconds: integer = expires_at:unix() + end + end +end + +local cached_query = store:lookup_decision("query:tenant-a:doc/alpha") +if cached_query then + if cached_query.kind == "allow" then + local allow_decision: protocol.AllowDecision = cached_query + local reason: string = allow_decision.reason + end +end + +local cached_defer = store:lookup_decision("tenant-b:defer:update") +if cached_defer then + if cached_defer.kind == "defer" then + local defer_decision: protocol.DeferDecision = cached_defer + local queue: string = defer_decision.queue + local retry_seconds: integer = defer_decision.retry_at:unix() + end +end + +local allowed_counter = store.state.counters["allow"] +if allowed_counter then + local allowed_value: integer = allowed_counter +end + +local deferred_counter = store.state.counters["defer"] +if deferred_counter then + local deferred_value: integer = deferred_counter +end + +local saw_source = store.state.flags["saw_source"] +if saw_source then + local flag: boolean = saw_source +end + +local last_eval_at = store.state.last_eval_at +if last_eval_at then + local last_eval_seconds: integer = last_eval_at:unix() +end + +local source_tag = helpers.tag_value(query_one, "source") +if source_tag then + local source: string = source_tag +end + +if last_runtime_id then + local runtime_id: string = last_runtime_id +end diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/manifest.json b/testdata/fixtures/realworld/tenant-policy-runtime/manifest.json new file mode 100644 index 00000000..d54145b5 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Application-shaped tenant policy runtime mixing staged validators, closure-built evaluators, metatable-backed store/runtime objects, discriminated requests and decisions, dynamic evaluator registries, optional cached decisions, callbacks, counters, tag maps, and time-based summaries.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "policy_store.lua", + "validator_builder.lua", + "rule_builder.lua", + "runtime.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/policy_store.lua b/testdata/fixtures/realworld/tenant-policy-runtime/policy_store.lua new file mode 100644 index 00000000..c19ffd2e --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/policy_store.lua @@ -0,0 +1,109 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type PolicyStore = { + state: protocol.StoreState, + save_policy: (self: PolicyStore, policy: protocol.TenantPolicy) -> (), + lookup_policy: (self: PolicyStore, tenant_id: string) -> protocol.TenantPolicy?, + lookup_decision: (self: PolicyStore, cache_key: string) -> protocol.Decision?, + push_step: (self: PolicyStore, step: protocol.PolicyStep, at: time.Time) -> (), + cache_decision: (self: PolicyStore, request: protocol.Request, decision: protocol.Decision, at: time.Time) -> (), + summarize: (self: PolicyStore, now: time.Time, last_kind: string?) -> protocol.RunSummary, +} + +type Store = PolicyStore + +local Store = {} +Store.__index = Store + +local M = {} +M.PolicyStore = PolicyStore + +function M.new(id: string, now: time.Time): PolicyStore + local self: Store = { + state = { + id = id, + started_at = now, + last_eval_at = nil, + policies = {}, + cached_decisions = {}, + steps = {}, + counters = {}, + flags = {}, + }, + save_policy = Store.save_policy, + lookup_policy = Store.lookup_policy, + lookup_decision = Store.lookup_decision, + push_step = Store.push_step, + cache_decision = Store.cache_decision, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:save_policy(policy: protocol.TenantPolicy) + self.state.policies[policy.tenant_id] = policy +end + +function Store:lookup_policy(tenant_id: string): protocol.TenantPolicy? + return self.state.policies[tenant_id] +end + +function Store:lookup_decision(cache_key: string): protocol.Decision? + return self.state.cached_decisions[cache_key] +end + +function Store:push_step(step: protocol.PolicyStep, at: time.Time) + table.insert(self.state.steps, step) + self.state.last_eval_at = at +end + +function Store:cache_decision(request: protocol.Request, decision: protocol.Decision, at: time.Time) + if request.kind == "tick" then + return + end + + local current = self.state.policies[request.tenant_id] + if current then + local updated_policy: protocol.TenantPolicy = { + tenant_id = current.tenant_id, + scopes = current.scopes, + allowed_resources = current.allowed_resources, + fallback_queue = current.fallback_queue, + tags = current.tags, + last_checked = at, + } + self.state.policies[current.tenant_id] = updated_policy + end + + if decision.kind == "allow" then + local key = decision.cache_key + if key then + self.state.cached_decisions[key] = decision + end + elseif decision.kind == "defer" then + self.state.cached_decisions[request.tenant_id .. ":defer:" .. request.kind] = decision + end + + helpers.bump_counter(self.state.counters, decision.kind) +end + +function Store:summarize(now: time.Time, last_kind: string?): protocol.RunSummary + local allowed_count = self.state.counters["allow"] or 0 + local denied_count = self.state.counters["deny"] or 0 + local deferred_count = self.state.counters["defer"] or 0 + + return { + id = self.state.id, + total_processed = #self.state.steps, + allowed_count = allowed_count, + denied_count = denied_count, + deferred_count = deferred_count, + elapsed_seconds = now:sub(self.state.started_at), + last_kind = last_kind, + } +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/protocol.lua b/testdata/fixtures/realworld/tenant-policy-runtime/protocol.lua new file mode 100644 index 00000000..ded174f1 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/protocol.lua @@ -0,0 +1,158 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type PolicyMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type AuthRequest = { + kind: "auth", + tenant_id: string, + actor_id: string, + scope: string, + meta: PolicyMeta, +} + +type QueryRequest = { + kind: "query", + tenant_id: string, + actor_id: string, + resource: string, + meta: PolicyMeta, +} + +type UpdateRequest = { + kind: "update", + tenant_id: string, + actor_id: string, + resource: string, + change_set: {[string]: string}, + meta: PolicyMeta, +} + +type TickRequest = { + kind: "tick", + at: time.Time, +} + +type Request = AuthRequest | QueryRequest | UpdateRequest | TickRequest + +type AllowDecision = { + kind: "allow", + reason: string, + cache_key: string?, + expires_at: time.Time?, +} + +type DenyDecision = { + kind: "deny", + reason: string, + retryable: boolean, +} + +type DeferDecision = { + kind: "defer", + queue: string, + retry_at: time.Time, +} + +type Decision = AllowDecision | DenyDecision | DeferDecision + +type TenantPolicy = { + tenant_id: string, + scopes: {[string]: boolean}, + allowed_resources: {[string]: boolean}, + fallback_queue: string?, + tags: {[string]: string}?, + last_checked: time.Time?, +} + +type DecisionStep = { + kind: "decision", + request_kind: "auth" | "query" | "update", + tenant_id: string, + note: string, +} + +type DeferStep = { + kind: "defer", + tenant_id: string, + queue: string, + note: string, + retry_at: time.Time, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type PolicyStep = DecisionStep | DeferStep | AuditStep + +type StoreState = { + id: string, + started_at: time.Time, + last_eval_at: time.Time?, + policies: {[string]: TenantPolicy}, + cached_decisions: {[string]: Decision}, + steps: {PolicyStep}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_processed: number, + allowed_count: number, + denied_count: number, + deferred_count: number, + elapsed_seconds: time.Duration, + last_kind: string?, +} + +type ValidationResult = {ok: true, value: Request} | {ok: false, error: AppError} +type EvaluateResult = {ok: true, value: Decision?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type RequestValidator = (StoreState, Request) -> ValidationResult +type PolicyEvaluator = (StoreState, Request, time.Time) -> EvaluateResult +type StepHook = (PolicyStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.PolicyMeta = PolicyMeta +M.AuthRequest = AuthRequest +M.QueryRequest = QueryRequest +M.UpdateRequest = UpdateRequest +M.TickRequest = TickRequest +M.Request = Request +M.AllowDecision = AllowDecision +M.DenyDecision = DenyDecision +M.DeferDecision = DeferDecision +M.Decision = Decision +M.TenantPolicy = TenantPolicy +M.DecisionStep = DecisionStep +M.DeferStep = DeferStep +M.AuditStep = AuditStep +M.PolicyStep = PolicyStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.EvaluateResult = EvaluateResult +M.RunResult = RunResult +M.RequestValidator = RequestValidator +M.PolicyEvaluator = PolicyEvaluator +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): PolicyMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/result.lua b/testdata/fixtures/realworld/tenant-policy-runtime/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/rule_builder.lua b/testdata/fixtures/realworld/tenant-policy-runtime/rule_builder.lua new file mode 100644 index 00000000..7884ae63 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/rule_builder.lua @@ -0,0 +1,191 @@ +local protocol = require("protocol") +local helpers = require("helpers") + +type RuleDecorator = (string, protocol.TenantPolicy, protocol.Request) -> string + +type RuleBuilder = { + name: string, + request_kind: "auth" | "query" | "update", + required_scope: string?, + fallback_queue: string?, + cache_prefix: string?, + decorator: RuleDecorator?, + named: (self: RuleBuilder, name: string) -> RuleBuilder, + for_kind: (self: RuleBuilder, request_kind: "auth" | "query" | "update") -> RuleBuilder, + require_scope: (self: RuleBuilder, scope: string) -> RuleBuilder, + fallback_to: (self: RuleBuilder, queue: string) -> RuleBuilder, + cache_with: (self: RuleBuilder, prefix: string) -> RuleBuilder, + decorate: (self: RuleBuilder, decorator: RuleDecorator) -> RuleBuilder, + build: (self: RuleBuilder) -> protocol.PolicyEvaluator, +} + +type Builder = RuleBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): RuleBuilder + local self: Builder = { + name = "rule", + request_kind = "auth", + required_scope = nil, + fallback_queue = nil, + cache_prefix = nil, + decorator = nil, + named = Builder.named, + for_kind = Builder.for_kind, + require_scope = Builder.require_scope, + fallback_to = Builder.fallback_to, + cache_with = Builder.cache_with, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:for_kind(request_kind: "auth" | "query" | "update"): Builder + self.request_kind = request_kind + return self +end + +function Builder:require_scope(scope: string): Builder + self.required_scope = scope + return self +end + +function Builder:fallback_to(queue: string): Builder + self.fallback_queue = queue + return self +end + +function Builder:cache_with(prefix: string): Builder + self.cache_prefix = prefix + return self +end + +function Builder:decorate(decorator: RuleDecorator): Builder + self.decorator = decorator + return self +end + +function Builder:build(): protocol.PolicyEvaluator + local name = self.name + local request_kind = self.request_kind + local required_scope = self.required_scope + local fallback_queue = self.fallback_queue + local cache_prefix = self.cache_prefix + local decorator = self.decorator + + return function(state: protocol.StoreState, request: protocol.Request, at: time.Time): protocol.EvaluateResult + if request.kind == "tick" then + return { + ok = false, + error = { + code = "invalid", + message = name .. " cannot evaluate ticks", + retryable = false, + }, + } + end + + if request.kind ~= request_kind then + return { + ok = false, + error = { + code = "invalid", + message = name .. " wrong request kind: " .. helpers.request_label(request), + retryable = false, + }, + } + end + + local policy = state.policies[request.tenant_id] + if not policy then + return { + ok = false, + error = { + code = "not_found", + message = name .. " missing tenant policy", + retryable = false, + }, + } + end + + local reason = name .. ":" .. request.tenant_id + if decorator then + reason = decorator(reason, policy, request) + end + + if request.kind == "auth" then + local scope_key = required_scope or request.scope + if policy.scopes[scope_key] then + local cache_key: string? = nil + if cache_prefix then + cache_key = cache_prefix .. ":" .. request.tenant_id .. ":" .. request.actor_id + end + local allow: protocol.AllowDecision = { + kind = "allow", + reason = reason, + cache_key = cache_key, + expires_at = at, + } + return {ok = true, value = allow} + end + + local deny: protocol.DenyDecision = { + kind = "deny", + reason = reason, + retryable = false, + } + return {ok = true, value = deny} + end + + local resource: string + if request.kind == "query" then + resource = request.resource + else + resource = request.resource + end + + if policy.allowed_resources[resource] then + local cache_key: string? = nil + if cache_prefix then + cache_key = cache_prefix .. ":" .. request.tenant_id .. ":" .. resource + end + local allow: protocol.AllowDecision = { + kind = "allow", + reason = reason, + cache_key = cache_key, + expires_at = at, + } + return {ok = true, value = allow} + end + + local queue = fallback_queue or policy.fallback_queue + if queue then + local defer: protocol.DeferDecision = { + kind = "defer", + queue = queue, + retry_at = at, + } + return {ok = true, value = defer} + end + + local deny: protocol.DenyDecision = { + kind = "deny", + reason = reason, + retryable = false, + } + return {ok = true, value = deny} + end +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/runtime.lua b/testdata/fixtures/realworld/tenant-policy-runtime/runtime.lua new file mode 100644 index 00000000..99ea7701 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/runtime.lua @@ -0,0 +1,157 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local policy_store = require("policy_store") + +type PolicyRuntime = { + validators: {protocol.RequestValidator}, + evaluators: {[string]: protocol.PolicyEvaluator}, + hooks: {protocol.StepHook}, + use_validator: (self: PolicyRuntime, validator: protocol.RequestValidator) -> PolicyRuntime, + register_evaluator: (self: PolicyRuntime, request_kind: string, evaluator: protocol.PolicyEvaluator) -> PolicyRuntime, + on_step: (self: PolicyRuntime, hook: protocol.StepHook) -> PolicyRuntime, + new_store: (self: PolicyRuntime, id: string, now: time.Time) -> policy_store.PolicyStore, + emit: (self: PolicyRuntime, store: policy_store.PolicyStore, step: protocol.PolicyStep, at: time.Time) -> (), + evaluate: (self: PolicyRuntime, store: policy_store.PolicyStore, request: protocol.Request, at: time.Time) -> protocol.EvaluateResult, + run: (self: PolicyRuntime, store: policy_store.PolicyStore, requests: {protocol.Request}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = PolicyRuntime + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.PolicyRuntime = PolicyRuntime + +function M.new(): PolicyRuntime + local self: Runtime = { + validators = {}, + evaluators = {}, + hooks = {}, + use_validator = Runtime.use_validator, + register_evaluator = Runtime.register_evaluator, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + evaluate = Runtime.evaluate, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:use_validator(validator: protocol.RequestValidator): Runtime + table.insert(self.validators, validator) + return self +end + +function Runtime:register_evaluator(request_kind: string, evaluator: protocol.PolicyEvaluator): Runtime + self.evaluators[request_kind] = evaluator + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): policy_store.PolicyStore + return policy_store.new(id, now) +end + +function Runtime:emit(store: policy_store.PolicyStore, step: protocol.PolicyStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:evaluate( + store: policy_store.PolicyStore, + request: protocol.Request, + at: time.Time +): protocol.EvaluateResult + if request.kind == "tick" then + local audit_step: protocol.AuditStep = {kind = "audit", note = "tick", at = request.at} + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + for _, validator in ipairs(self.validators) do + local validation: protocol.ValidationResult = validator(store.state, request) + if not validation.ok then + return {ok = false, error = validation.error} + end + end + + local evaluator = self.evaluators[request.kind] + if not evaluator then + return { + ok = false, + error = { + code = "not_found", + message = "missing evaluator: " .. request.kind, + retryable = false, + }, + } + end + + local evaluation = evaluator(store.state, request, at) + if not evaluation.ok then + return {ok = false, error = evaluation.error} + end + + local decision = evaluation.value + if not decision then + return {ok = true, value = nil} + end + + store:cache_decision(request, decision, at) + + if decision.kind == "defer" then + local defer_decision: protocol.DeferDecision = decision + local defer_step: protocol.DeferStep = { + kind = "defer", + tenant_id = request.tenant_id, + queue = defer_decision.queue, + note = helpers.decision_note(defer_decision), + retry_at = defer_decision.retry_at, + } + self:emit(store, defer_step, at) + return {ok = true, value = defer_decision} + end + + local decision_step: protocol.DecisionStep = { + kind = "decision", + request_kind = request.kind, + tenant_id = request.tenant_id, + note = helpers.decision_note(decision), + } + self:emit(store, decision_step, at) + return {ok = true, value = decision} +end + +function Runtime:run( + store: policy_store.PolicyStore, + requests: {protocol.Request}, + now: time.Time +): protocol.RunResult + local last_kind: string? = nil + + for _, request in ipairs(requests) do + local evaluation: protocol.EvaluateResult = self:evaluate(store, request, now) + if not evaluation.ok then + return {ok = false, error = evaluation.error} + end + + local decision = evaluation.value + if decision then + last_kind = decision.kind + end + end + + return {ok = true, value = store:summarize(now, last_kind)} +end + +return M diff --git a/testdata/fixtures/realworld/tenant-policy-runtime/validator_builder.lua b/testdata/fixtures/realworld/tenant-policy-runtime/validator_builder.lua new file mode 100644 index 00000000..1c71ae26 --- /dev/null +++ b/testdata/fixtures/realworld/tenant-policy-runtime/validator_builder.lua @@ -0,0 +1,131 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string, + required_tag: string?, + required_scope: string?, + required_resource_prefix: string?, + flag_name: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, tag: string) -> ValidatorBuilder, + require_scope: (self: ValidatorBuilder, scope: string) -> ValidatorBuilder, + require_resource_prefix: (self: ValidatorBuilder, prefix: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag_name: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.RequestValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} + +function M.new(): ValidatorBuilder + local self: Builder = { + name = "validator", + required_tag = nil, + required_scope = nil, + required_resource_prefix = nil, + flag_name = nil, + named = Builder.named, + require_tag = Builder.require_tag, + require_scope = Builder.require_scope, + require_resource_prefix = Builder.require_resource_prefix, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(tag: string): Builder + self.required_tag = tag + return self +end + +function Builder:require_scope(scope: string): Builder + self.required_scope = scope + return self +end + +function Builder:require_resource_prefix(prefix: string): Builder + self.required_resource_prefix = prefix + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:build(): protocol.RequestValidator + local name = self.name + local required_tag = self.required_tag + local required_scope = self.required_scope + local required_resource_prefix = self.required_resource_prefix + local flag_name = self.flag_name + + return function(state: protocol.StoreState, request: protocol.Request): protocol.ValidationResult + if request.kind == "tick" then + return {ok = true, value = request} + end + + if required_tag then + local tags = request.meta.tags + if not tags or not tags[required_tag] then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + if required_scope and request.kind == "auth" and request.scope ~= required_scope then + return { + ok = false, + error = { + code = "invalid", + message = name .. " wrong scope: " .. request.scope, + retryable = false, + }, + } + end + + if required_resource_prefix and request.kind ~= "auth" then + local resource: string + if request.kind == "query" then + resource = request.resource + else + resource = request.resource + end + if string.sub(resource, 1, #required_resource_prefix) ~= required_resource_prefix then + return { + ok = false, + error = { + code = "invalid", + message = name .. " wrong resource prefix", + retryable = false, + }, + } + end + end + + if flag_name then + state.flags[flag_name] = true + end + + return {ok = true, value = request} + end +end + +return M diff --git a/testdata/fixtures/realworld/trait-registry/main.lua b/testdata/fixtures/realworld/trait-registry/main.lua index 4fe3058d..5dc72e5b 100644 --- a/testdata/fixtures/realworld/trait-registry/main.lua +++ b/testdata/fixtures/realworld/trait-registry/main.lua @@ -1,7 +1,7 @@ local types = require("types") local processor = require("processor") -local entry: TraitRegistryEntry = { +local entry: types.TraitRegistryEntry = { id = "search-trait", meta = {type = types.TRAIT_TYPE, name = "Search", comment = "Web search capability"}, data = { diff --git a/testdata/fixtures/realworld/trait-registry/processor.lua b/testdata/fixtures/realworld/trait-registry/processor.lua index 69759c91..5db8d577 100644 --- a/testdata/fixtures/realworld/trait-registry/processor.lua +++ b/testdata/fixtures/realworld/trait-registry/processor.lua @@ -2,11 +2,11 @@ local types = require("types") local M = {} -function M.normalize_tool(tool_def: TraitToolDef): TraitToolEntry +function M.normalize_tool(tool_def: types.TraitToolDef): types.TraitToolEntry if type(tool_def) == "string" then return {id = tool_def} end - local entry: TraitToolEntry = { + local entry: types.TraitToolEntry = { id = tool_def.id, context = tool_def.context, description = tool_def.description, @@ -15,23 +15,23 @@ function M.normalize_tool(tool_def: TraitToolDef): TraitToolEntry return entry end -function M.normalize_tools(tools_data: {TraitToolDef}?): {TraitToolEntry} +function M.normalize_tools(tools_data: {types.TraitToolDef}?): {types.TraitToolEntry} if not tools_data or #tools_data == 0 then return {} end - local result: {TraitToolEntry} = {} + local result: {types.TraitToolEntry} = {} for _, tool_def in ipairs(tools_data) do table.insert(result, M.normalize_tool(tool_def)) end return result end -function M.build_trait(entry: TraitRegistryEntry): (TraitSpec?, string?) +function M.build_trait(entry: types.TraitRegistryEntry): (types.TraitSpec?, string?) if not entry.data then return nil, "trait has no data: " .. entry.id end local data = entry.data - local spec: TraitSpec = { + local spec: types.TraitSpec = { id = entry.id, name = entry.meta and entry.meta.name or entry.id, description = entry.meta and entry.meta.comment or "", diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/action_builder.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/action_builder.lua new file mode 100644 index 00000000..94f7f7cd --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/action_builder.lua @@ -0,0 +1,308 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type Decorator = (string, protocol.SagaAggregate, protocol.Action) -> string + +type ActionBuilder = { + action_kind: "begin" | "reserve" | "charge" | "commit" | "cancel", + note_prefix: string, + counter_name: string?, + source_key: string?, + decorator: Decorator?, + for_kind: (self: ActionBuilder, kind: "begin" | "reserve" | "charge" | "commit" | "cancel") -> ActionBuilder, + prefix_with: (self: ActionBuilder, prefix: string) -> ActionBuilder, + count_as: (self: ActionBuilder, counter_name: string) -> ActionBuilder, + capture_source: (self: ActionBuilder, source_key: string) -> ActionBuilder, + decorate: (self: ActionBuilder, decorator: Decorator) -> ActionBuilder, + build: (self: ActionBuilder) -> protocol.ActionHandler, +} + +type Builder = ActionBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ActionBuilder = ActionBuilder + +function M.new(): ActionBuilder + local self: Builder = { + action_kind = "begin", + note_prefix = "saga", + counter_name = nil, + source_key = nil, + decorator = nil, + for_kind = Builder.for_kind, + prefix_with = Builder.prefix_with, + count_as = Builder.count_as, + capture_source = Builder.capture_source, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_kind(kind: "begin" | "reserve" | "charge" | "commit" | "cancel"): Builder + self.action_kind = kind + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.note_prefix = prefix + return self +end + +function Builder:count_as(counter_name: string): Builder + self.counter_name = counter_name + return self +end + +function Builder:capture_source(source_key: string): Builder + self.source_key = source_key + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +local function ensure_view(state: protocol.StoreState, saga: protocol.SagaAggregate): protocol.SagaView + local current = state.views[saga.order_id] + if current then + return current + end + + local created: protocol.SagaView = { + order_id = saga.order_id, + status = saga.status, + version = saga.version, + reservation_token = saga.reservation_token, + payment_id = saga.payment_id, + source = saga.source, + committed_at = nil, + rolled_back_at = nil, + last_error = saga.last_error, + } + state.views[saga.order_id] = created + return created +end + +function Builder:build(): protocol.ActionHandler + local action_kind = self.action_kind + local note_prefix = self.note_prefix + local counter_name = self.counter_name + local source_key = self.source_key + local decorator = self.decorator + + return function(state: protocol.StoreState, action: protocol.Action, at: time.Time): protocol.HandlerResult + if action.kind == "tick" then + return {ok = true, value = nil} + end + + local saga: protocol.SagaAggregate + + if action_kind == "begin" then + if action.kind ~= "begin" then + return { + ok = false, + error = { + code = "invalid", + message = "expected begin", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if existing then + return { + ok = false, + error = { + code = "conflict", + message = "saga exists: " .. action.order_id, + retryable = false, + }, + } + end + + saga = { + order_id = action.order_id, + customer_id = action.customer_id, + version = 1, + status = "open", + reservation_token = nil, + payment_id = nil, + last_error = nil, + source = nil, + updated_at = at, + compensations = {}, + } + state.sagas[action.order_id] = saga + elseif action_kind == "reserve" then + if action.kind ~= "reserve" then + return { + ok = false, + error = { + code = "invalid", + message = "expected reserve", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "reserved" + saga.reservation_token = "res:" .. action.sku .. ":" .. tostring(action.qty) + table.insert(saga.compensations, { + kind = "release", + reservation_token = saga.reservation_token, + }) + saga.updated_at = at + elseif action_kind == "charge" then + if action.kind ~= "charge" then + return { + ok = false, + error = { + code = "invalid", + message = "expected charge", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "charged" + saga.payment_id = "pay:" .. tostring(action.cents) + table.insert(saga.compensations, { + kind = "refund", + payment_id = saga.payment_id, + }) + saga.updated_at = at + elseif action_kind == "commit" then + if action.kind ~= "commit" then + return { + ok = false, + error = { + code = "invalid", + message = "expected commit", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "committed" + saga.updated_at = at + else + if action.kind ~= "cancel" then + return { + ok = false, + error = { + code = "invalid", + message = "expected cancel", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "rolled_back" + saga.last_error = action.reason + saga.updated_at = at + end + + if source_key then + local tags = action.meta.tags + if tags then + local source = tags[source_key] + if source then + saga.source = source + end + end + end + + local view = ensure_view(state, saga) + view.status = saga.status + view.version = saga.version + view.reservation_token = saga.reservation_token + view.payment_id = saga.payment_id + view.source = saga.source + view.last_error = saga.last_error + if saga.status == "committed" then + view.committed_at = at + elseif saga.status == "rolled_back" then + view.rolled_back_at = at + end + + if counter_name then + local current = state.counters[counter_name] or 0 + state.counters[counter_name] = current + 1 + end + + local note = note_prefix .. ":" .. saga.order_id .. ":" .. saga.status .. ":" .. tostring(saga.version) + if decorator then + note = decorator(note, saga, action) + else + note = note .. ":" .. helpers.action_label(action) + end + + return {ok = true, value = note} + end +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/helpers.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/helpers.lua new file mode 100644 index 00000000..5d2064ec --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/helpers.lua @@ -0,0 +1,46 @@ +local protocol = require("protocol") + +local M = {} + +function M.action_id(action: protocol.Action): string? + if action.kind == "tick" then + return nil + end + return action.order_id +end + +function M.action_label(action: protocol.Action): string + if action.kind == "begin" then + return "begin:" .. action.customer_id + end + if action.kind == "reserve" then + return "reserve:" .. action.sku + end + if action.kind == "charge" then + return "charge:" .. tostring(action.cents) + end + if action.kind == "commit" then + return "commit" + end + if action.kind == "cancel" then + return "cancel:" .. action.reason + end + return "tick" +end + +function M.source_tag(action: protocol.Action): string? + if action.kind == "tick" then + return nil + end + local tags = action.meta.tags + if not tags then + return nil + end + return tags["source"] +end + +function M.status_name(status: string?): string + return status or "unknown" +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/main.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/main.lua new file mode 100644 index 00000000..8e4d61cd --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/main.lua @@ -0,0 +1,91 @@ +local time = require("time") +local protocol = require("protocol") +local validator_builder = require("validator_builder") +local action_builder = require("action_builder") +local orchestrator = require("orchestrator") + +local now = time.now() + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :remember_flag("validated_source") + :build() + +local begin_handler = action_builder.new() + :for_kind("begin") + :prefix_with("begin") + :count_as("begun") + :capture_source("source") + :build() + +local reserve_handler = action_builder.new() + :for_kind("reserve") + :prefix_with("reserve") + :count_as("reserved") + :capture_source("source") + :build() + +local cancel_handler = action_builder.new() + :for_kind("cancel") + :prefix_with("cancel") + :count_as("cancelled") + :capture_source("source") + :build() + +local app = orchestrator.new() + :use_validator(source_validator) + :register_handler("begin", begin_handler) + :register_handler("reserve", reserve_handler) + :register_handler("cancel", cancel_handler) + +local begin_one: protocol.BeginAction = { + kind = "begin", + order_id = "ord-1", + customer_id = "alice", + meta = protocol.meta("trace-1", {source = "api"}), +} + +local reserve_one: protocol.ReserveAction = { + kind = "reserve", + order_id = "ord-1", + sku = "sku-1", + qty = 1, + meta = protocol.meta("trace-2", {source = "worker"}), +} + +local cancel_one: protocol.CancelAction = { + kind = "cancel", + order_id = "ord-1", + reason = "manual_stop", + meta = protocol.meta("trace-3", {source = "worker"}), +} + +local tick: protocol.TickAction = { + kind = "tick", + at = now, +} + +local store = app:new_store("saga-1", now) +local run_result = app:run(store, {begin_one, reserve_one, cancel_one, tick}, now) + +if run_result.ok then + local last_status: string = run_result.value.last_status -- expect-error +end + +local saga = store:lookup_saga("ord-1") +local reservation_token: string = saga.reservation_token -- expect-error +local payment_id: string = saga.payment_id -- expect-error + +for _, comp in ipairs(saga.compensations) do + if comp.kind == "release" then + local token: string = comp.reservation_token + else + local payment: string = comp.payment_id + end +end + +local view = store:lookup_view("ord-1") +local committed_seconds = now:sub(view.committed_at) -- expect-error +local trace_source: string = begin_one.meta.tags["source"] -- expect-error +local last_seen = now:sub(store.state.last_action_at) -- expect-error diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/manifest.json b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/manifest.json new file mode 100644 index 00000000..b6cb9e6c --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/manifest.json @@ -0,0 +1,14 @@ +{ + "description": "Soundness checks for the transactional saga domain: unsafe optional saga, view, compensation, tag, and time uses must be rejected.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "saga_store.lua", + "validator_builder.lua", + "action_builder.lua", + "orchestrator.lua", + "main.lua" + ], + "packages": ["time"] +} diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/orchestrator.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/orchestrator.lua new file mode 100644 index 00000000..0d78c36b --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/orchestrator.lua @@ -0,0 +1,169 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local saga_store = require("saga_store") + +type Orchestrator = { + validators: {protocol.ActionValidator}, + handlers: {[string]: protocol.ActionHandler}, + hooks: {protocol.StepHook}, + use_validator: (self: Orchestrator, validator: protocol.ActionValidator) -> Orchestrator, + register_handler: (self: Orchestrator, kind: string, handler: protocol.ActionHandler) -> Orchestrator, + on_step: (self: Orchestrator, hook: protocol.StepHook) -> Orchestrator, + new_store: (self: Orchestrator, id: string, now: time.Time) -> saga_store.SagaStore, + emit: (self: Orchestrator, store: saga_store.SagaStore, step: protocol.SagaStep, at: time.Time) -> (), + execute: (self: Orchestrator, store: saga_store.SagaStore, action: protocol.Action, at: time.Time) -> protocol.ExecuteResult, + run: (self: Orchestrator, store: saga_store.SagaStore, actions: {protocol.Action}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = Orchestrator + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.Orchestrator = Orchestrator + +function M.new(): Orchestrator + local self: Runtime = { + validators = {}, + handlers = {}, + hooks = {}, + use_validator = Runtime.use_validator, + register_handler = Runtime.register_handler, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + execute = Runtime.execute, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:use_validator(validator: protocol.ActionValidator): Runtime + table.insert(self.validators, validator) + return self +end + +function Runtime:register_handler(kind: string, handler: protocol.ActionHandler): Runtime + self.handlers[kind] = handler + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): saga_store.SagaStore + return saga_store.new(id, now) +end + +function Runtime:emit(store: saga_store.SagaStore, step: protocol.SagaStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:execute( + store: saga_store.SagaStore, + action: protocol.Action, + at: time.Time +): protocol.ExecuteResult + if action.kind == "tick" then + local audit_step: protocol.SagaStep = { + kind = "audit", + note = "tick", + at = action.at, + } + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local current: protocol.Action = action + for _, validator in ipairs(self.validators) do + local validation_result: protocol.ValidationResult = validator(store.state, current) + if not validation_result.ok then + return {ok = false, error = validation_result.error} + end + current = validation_result.value + end + + local handler = self.handlers[current.kind] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing handler: " .. current.kind, + retryable = false, + }, + } + end + + local handler_value: protocol.ActionHandler = handler + local handler_result: protocol.HandlerResult = handler_value(store.state, current, at) + if not handler_result.ok then + return {ok = false, error = handler_result.error} + end + + local note = handler_result.value or helpers.action_label(current) + local action_step: protocol.SagaStep = { + kind = "action", + name = current.kind, + note = note, + order_id = current.order_id, + } + self:emit(store, action_step, at) + + if action.kind == "cancel" then + local saga = store:lookup_saga(action.order_id) + if saga then + for i = #saga.compensations, 1, -1 do + local comp = saga.compensations[i] + local comp_note: string + local comp_name: string + if comp.kind == "release" then + comp_name = "release" + comp_note = "release:" .. comp.reservation_token + else + comp_name = "refund" + comp_note = "refund:" .. comp.payment_id + end + local comp_step: protocol.SagaStep = { + kind = "compensation", + name = comp_name, + note = comp_note, + order_id = saga.order_id, + } + self:emit(store, comp_step, at) + end + end + end + + return {ok = true, value = current.kind} +end + +function Runtime:run( + store: saga_store.SagaStore, + actions: {protocol.Action}, + now: time.Time +): protocol.RunResult + local last_status: string? = nil + + for _, action in ipairs(actions) do + local execute_result: protocol.ExecuteResult = self:execute(store, action, now) + if not execute_result.ok then + return {ok = false, error = execute_result.error} + end + if execute_result.value then + last_status = execute_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/protocol.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/protocol.lua new file mode 100644 index 00000000..3b1a54a4 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/protocol.lua @@ -0,0 +1,173 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type ActionMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type BeginAction = { + kind: "begin", + order_id: string, + customer_id: string, + meta: ActionMeta, +} + +type ReserveAction = { + kind: "reserve", + order_id: string, + sku: string, + qty: integer, + meta: ActionMeta, +} + +type ChargeAction = { + kind: "charge", + order_id: string, + cents: integer, + meta: ActionMeta, +} + +type CommitAction = { + kind: "commit", + order_id: string, + meta: ActionMeta, +} + +type CancelAction = { + kind: "cancel", + order_id: string, + reason: string, + meta: ActionMeta, +} + +type TickAction = { + kind: "tick", + at: time.Time, +} + +type Action = BeginAction | ReserveAction | ChargeAction | CommitAction | CancelAction | TickAction + +type ReleaseInventory = { + kind: "release", + reservation_token: string, +} + +type RefundPayment = { + kind: "refund", + payment_id: string, +} + +type Compensation = ReleaseInventory | RefundPayment + +type SagaAggregate = { + order_id: string, + customer_id: string, + version: integer, + status: "open" | "reserved" | "charged" | "committed" | "rolled_back", + reservation_token: string?, + payment_id: string?, + last_error: string?, + source: string?, + updated_at: time.Time?, + compensations: {Compensation}, +} + +type SagaView = { + order_id: string, + status: "open" | "reserved" | "charged" | "committed" | "rolled_back", + version: integer, + reservation_token: string?, + payment_id: string?, + source: string?, + committed_at: time.Time?, + rolled_back_at: time.Time?, + last_error: string?, +} + +type ActionStep = { + kind: "action", + name: string, + note: string, + order_id: string?, +} + +type CompensationStep = { + kind: "compensation", + name: string, + note: string, + order_id: string?, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type SagaStep = ActionStep | CompensationStep | AuditStep + +type StoreState = { + id: string, + started_at: time.Time, + last_action_at: time.Time?, + steps: {SagaStep}, + sagas: {[string]: SagaAggregate}, + views: {[string]: SagaView}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_steps: number, + saga_count: number, + committed_count: number, + rolled_back_count: number, + last_status: string?, + elapsed_seconds: number, +} + +type ValidationResult = {ok: true, value: Action} | {ok: false, error: AppError} +type HandlerResult = {ok: true, value: string?} | {ok: false, error: AppError} +type ExecuteResult = {ok: true, value: string?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type ActionValidator = (StoreState, Action) -> ValidationResult +type ActionHandler = (StoreState, Action, time.Time) -> HandlerResult +type StepHook = (SagaStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.ActionMeta = ActionMeta +M.BeginAction = BeginAction +M.ReserveAction = ReserveAction +M.ChargeAction = ChargeAction +M.CommitAction = CommitAction +M.CancelAction = CancelAction +M.TickAction = TickAction +M.Action = Action +M.Compensation = Compensation +M.SagaAggregate = SagaAggregate +M.SagaView = SagaView +M.SagaStep = SagaStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.HandlerResult = HandlerResult +M.ExecuteResult = ExecuteResult +M.RunResult = RunResult +M.ActionValidator = ActionValidator +M.ActionHandler = ActionHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): ActionMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/result.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/saga_store.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/saga_store.lua new file mode 100644 index 00000000..7b603d35 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/saga_store.lua @@ -0,0 +1,97 @@ +local time = require("time") +local protocol = require("protocol") + +type SagaStore = { + state: protocol.StoreState, + touch: (self: SagaStore, at: time.Time) -> SagaStore, + push_step: (self: SagaStore, step: protocol.SagaStep, at: time.Time) -> SagaStore, + lookup_saga: (self: SagaStore, order_id: string) -> protocol.SagaAggregate?, + lookup_view: (self: SagaStore, order_id: string) -> protocol.SagaView?, + increment: (self: SagaStore, name: string) -> integer, + summarize: (self: SagaStore, now: time.Time, last_status: string?) -> protocol.RunSummary, +} + +type Store = SagaStore + +local Store = {} +Store.__index = Store + +local M = {} +M.SagaStore = SagaStore + +function M.new(id: string, now: time.Time): SagaStore + local self: Store = { + state = { + id = id, + started_at = now, + last_action_at = nil, + steps = {}, + sagas = {}, + views = {}, + counters = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + lookup_saga = Store.lookup_saga, + lookup_view = Store.lookup_view, + increment = Store.increment, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_action_at = at + return self +end + +function Store:push_step(step: protocol.SagaStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:lookup_saga(order_id: string): protocol.SagaAggregate? + return self.state.sagas[order_id] +end + +function Store:lookup_view(order_id: string): protocol.SagaView? + return self.state.views[order_id] +end + +function Store:increment(name: string): integer + local current = self.state.counters[name] or 0 + local next_value = current + 1 + self.state.counters[name] = next_value + return next_value +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RunSummary + local saga_count = 0 + local committed_count = 0 + local rolled_back_count = 0 + for _, view in pairs(self.state.views) do + saga_count = saga_count + 1 + if view.status == "committed" then + committed_count = committed_count + 1 + elseif view.status == "rolled_back" then + rolled_back_count = rolled_back_count + 1 + end + end + + local seen_at = self.state.last_action_at or self.state.started_at + local elapsed = now:sub(seen_at) + + return { + id = self.state.id, + total_steps = #self.state.steps, + saga_count = saga_count, + committed_count = committed_count, + rolled_back_count = rolled_back_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + } +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/validator_builder.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/validator_builder.lua new file mode 100644 index 00000000..594f4ad2 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator-soundness/validator_builder.lua @@ -0,0 +1,94 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string, + required_tag: string?, + flag_name: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, tag_name: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag_name: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.ActionValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ValidatorBuilder = ValidatorBuilder + +function M.new(): ValidatorBuilder + local self: Builder = { + name = "validator", + required_tag = nil, + flag_name = nil, + named = Builder.named, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(tag_name: string): Builder + self.required_tag = tag_name + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:build(): protocol.ActionValidator + local name = self.name + local required_tag = self.required_tag + local flag_name = self.flag_name + + return function(state: protocol.StoreState, action: protocol.Action): protocol.ValidationResult + if action.kind == "tick" then + return {ok = true, value = action} + end + + if required_tag then + local tags = action.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + if flag_name then + state.flags[flag_name] = true + end + + return {ok = true, value = action} + end +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/action_builder.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/action_builder.lua new file mode 100644 index 00000000..94f7f7cd --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/action_builder.lua @@ -0,0 +1,308 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") + +type Decorator = (string, protocol.SagaAggregate, protocol.Action) -> string + +type ActionBuilder = { + action_kind: "begin" | "reserve" | "charge" | "commit" | "cancel", + note_prefix: string, + counter_name: string?, + source_key: string?, + decorator: Decorator?, + for_kind: (self: ActionBuilder, kind: "begin" | "reserve" | "charge" | "commit" | "cancel") -> ActionBuilder, + prefix_with: (self: ActionBuilder, prefix: string) -> ActionBuilder, + count_as: (self: ActionBuilder, counter_name: string) -> ActionBuilder, + capture_source: (self: ActionBuilder, source_key: string) -> ActionBuilder, + decorate: (self: ActionBuilder, decorator: Decorator) -> ActionBuilder, + build: (self: ActionBuilder) -> protocol.ActionHandler, +} + +type Builder = ActionBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ActionBuilder = ActionBuilder + +function M.new(): ActionBuilder + local self: Builder = { + action_kind = "begin", + note_prefix = "saga", + counter_name = nil, + source_key = nil, + decorator = nil, + for_kind = Builder.for_kind, + prefix_with = Builder.prefix_with, + count_as = Builder.count_as, + capture_source = Builder.capture_source, + decorate = Builder.decorate, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:for_kind(kind: "begin" | "reserve" | "charge" | "commit" | "cancel"): Builder + self.action_kind = kind + return self +end + +function Builder:prefix_with(prefix: string): Builder + self.note_prefix = prefix + return self +end + +function Builder:count_as(counter_name: string): Builder + self.counter_name = counter_name + return self +end + +function Builder:capture_source(source_key: string): Builder + self.source_key = source_key + return self +end + +function Builder:decorate(decorator: Decorator): Builder + self.decorator = decorator + return self +end + +local function ensure_view(state: protocol.StoreState, saga: protocol.SagaAggregate): protocol.SagaView + local current = state.views[saga.order_id] + if current then + return current + end + + local created: protocol.SagaView = { + order_id = saga.order_id, + status = saga.status, + version = saga.version, + reservation_token = saga.reservation_token, + payment_id = saga.payment_id, + source = saga.source, + committed_at = nil, + rolled_back_at = nil, + last_error = saga.last_error, + } + state.views[saga.order_id] = created + return created +end + +function Builder:build(): protocol.ActionHandler + local action_kind = self.action_kind + local note_prefix = self.note_prefix + local counter_name = self.counter_name + local source_key = self.source_key + local decorator = self.decorator + + return function(state: protocol.StoreState, action: protocol.Action, at: time.Time): protocol.HandlerResult + if action.kind == "tick" then + return {ok = true, value = nil} + end + + local saga: protocol.SagaAggregate + + if action_kind == "begin" then + if action.kind ~= "begin" then + return { + ok = false, + error = { + code = "invalid", + message = "expected begin", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if existing then + return { + ok = false, + error = { + code = "conflict", + message = "saga exists: " .. action.order_id, + retryable = false, + }, + } + end + + saga = { + order_id = action.order_id, + customer_id = action.customer_id, + version = 1, + status = "open", + reservation_token = nil, + payment_id = nil, + last_error = nil, + source = nil, + updated_at = at, + compensations = {}, + } + state.sagas[action.order_id] = saga + elseif action_kind == "reserve" then + if action.kind ~= "reserve" then + return { + ok = false, + error = { + code = "invalid", + message = "expected reserve", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "reserved" + saga.reservation_token = "res:" .. action.sku .. ":" .. tostring(action.qty) + table.insert(saga.compensations, { + kind = "release", + reservation_token = saga.reservation_token, + }) + saga.updated_at = at + elseif action_kind == "charge" then + if action.kind ~= "charge" then + return { + ok = false, + error = { + code = "invalid", + message = "expected charge", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "charged" + saga.payment_id = "pay:" .. tostring(action.cents) + table.insert(saga.compensations, { + kind = "refund", + payment_id = saga.payment_id, + }) + saga.updated_at = at + elseif action_kind == "commit" then + if action.kind ~= "commit" then + return { + ok = false, + error = { + code = "invalid", + message = "expected commit", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "committed" + saga.updated_at = at + else + if action.kind ~= "cancel" then + return { + ok = false, + error = { + code = "invalid", + message = "expected cancel", + retryable = false, + }, + } + end + + local existing = state.sagas[action.order_id] + if not existing then + return { + ok = false, + error = { + code = "not_found", + message = "missing saga: " .. action.order_id, + retryable = false, + }, + } + end + + saga = existing + saga.version = saga.version + 1 + saga.status = "rolled_back" + saga.last_error = action.reason + saga.updated_at = at + end + + if source_key then + local tags = action.meta.tags + if tags then + local source = tags[source_key] + if source then + saga.source = source + end + end + end + + local view = ensure_view(state, saga) + view.status = saga.status + view.version = saga.version + view.reservation_token = saga.reservation_token + view.payment_id = saga.payment_id + view.source = saga.source + view.last_error = saga.last_error + if saga.status == "committed" then + view.committed_at = at + elseif saga.status == "rolled_back" then + view.rolled_back_at = at + end + + if counter_name then + local current = state.counters[counter_name] or 0 + state.counters[counter_name] = current + 1 + end + + local note = note_prefix .. ":" .. saga.order_id .. ":" .. saga.status .. ":" .. tostring(saga.version) + if decorator then + note = decorator(note, saga, action) + else + note = note .. ":" .. helpers.action_label(action) + end + + return {ok = true, value = note} + end +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/helpers.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/helpers.lua new file mode 100644 index 00000000..5d2064ec --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/helpers.lua @@ -0,0 +1,46 @@ +local protocol = require("protocol") + +local M = {} + +function M.action_id(action: protocol.Action): string? + if action.kind == "tick" then + return nil + end + return action.order_id +end + +function M.action_label(action: protocol.Action): string + if action.kind == "begin" then + return "begin:" .. action.customer_id + end + if action.kind == "reserve" then + return "reserve:" .. action.sku + end + if action.kind == "charge" then + return "charge:" .. tostring(action.cents) + end + if action.kind == "commit" then + return "commit" + end + if action.kind == "cancel" then + return "cancel:" .. action.reason + end + return "tick" +end + +function M.source_tag(action: protocol.Action): string? + if action.kind == "tick" then + return nil + end + local tags = action.meta.tags + if not tags then + return nil + end + return tags["source"] +end + +function M.status_name(status: string?): string + return status or "unknown" +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/main.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/main.lua new file mode 100644 index 00000000..ef2e19ac --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/main.lua @@ -0,0 +1,304 @@ +local time = require("time") +local result = require("result") +local protocol = require("protocol") +local helpers = require("helpers") +local validator_builder = require("validator_builder") +local action_builder = require("action_builder") +local orchestrator = require("orchestrator") + +type StringResult = {ok: true, value: string} | {ok: false, error: result.AppError} + +local now = time.now() + +local observed_actions: {[string]: string} = {} +local observed_compensations: {string} = {} +local observed_audits: {string} = {} +local last_runtime_id: string? = nil + +local source_validator = validator_builder.new() + :named("source") + :require_tag("source") + :remember_flag("validated_source") + :build() + +local begin_handler = action_builder.new() + :for_kind("begin") + :prefix_with("begin") + :count_as("begun") + :capture_source("source") + :decorate(function(note: string, saga: protocol.SagaAggregate, action: protocol.Action): string + return note .. ":" .. saga.customer_id .. ":" .. helpers.action_label(action) + end) + :build() + +local reserve_handler = action_builder.new() + :for_kind("reserve") + :prefix_with("reserve") + :count_as("reserved") + :capture_source("source") + :decorate(function(note: string, saga: protocol.SagaAggregate, _action: protocol.Action): string + local token = saga.reservation_token or "missing" + return note .. ":" .. token + end) + :build() + +local charge_handler = action_builder.new() + :for_kind("charge") + :prefix_with("charge") + :count_as("charged") + :capture_source("source") + :decorate(function(note: string, saga: protocol.SagaAggregate, _action: protocol.Action): string + local payment = saga.payment_id or "missing" + return note .. ":" .. payment + end) + :build() + +local commit_handler = action_builder.new() + :for_kind("commit") + :prefix_with("commit") + :count_as("committed") + :capture_source("source") + :decorate(function(note: string, saga: protocol.SagaAggregate, _action: protocol.Action): string + local source = saga.source or "unknown" + return note .. ":" .. source + end) + :build() + +local cancel_handler = action_builder.new() + :for_kind("cancel") + :prefix_with("cancel") + :count_as("cancelled") + :capture_source("source") + :decorate(function(note: string, saga: protocol.SagaAggregate, action: protocol.Action): string + local err = saga.last_error or "none" + return note .. ":" .. err .. ":" .. helpers.action_label(action) + end) + :build() + +local app = orchestrator.new() + :use_validator(source_validator) + :register_handler("begin", begin_handler) + :register_handler("reserve", reserve_handler) + :register_handler("charge", charge_handler) + :register_handler("commit", commit_handler) + :register_handler("cancel", cancel_handler) + +app:on_step(function(step: protocol.SagaStep, state: protocol.StoreState) + last_runtime_id = state.id + + if step.kind == "action" then + observed_actions[step.name .. ":" .. tostring(#observed_audits + 1)] = step.note + if step.order_id then + local order_id: string = step.order_id + end + elseif step.kind == "compensation" then + table.insert(observed_compensations, step.note) + if step.order_id then + local order_id: string = step.order_id + end + else + table.insert(observed_audits, step.note) + local at_seconds: integer = step.at:unix() + end +end) + +local begin_one: protocol.BeginAction = { + kind = "begin", + order_id = "ord-1", + customer_id = "alice", + meta = protocol.meta("trace-1", {source = "api"}), +} + +local reserve_one: protocol.ReserveAction = { + kind = "reserve", + order_id = "ord-1", + sku = "sku-1", + qty = 2, + meta = protocol.meta("trace-2", {source = "worker"}), +} + +local charge_one: protocol.ChargeAction = { + kind = "charge", + order_id = "ord-1", + cents = 4200, + meta = protocol.meta("trace-3", {source = "worker"}), +} + +local commit_one: protocol.CommitAction = { + kind = "commit", + order_id = "ord-1", + meta = protocol.meta("trace-4", {source = "worker"}), +} + +local begin_two: protocol.BeginAction = { + kind = "begin", + order_id = "ord-2", + customer_id = "bob", + meta = protocol.meta("trace-5", {source = "api"}), +} + +local reserve_two: protocol.ReserveAction = { + kind = "reserve", + order_id = "ord-2", + sku = "sku-2", + qty = 1, + meta = protocol.meta("trace-6", {source = "worker"}), +} + +local charge_two: protocol.ChargeAction = { + kind = "charge", + order_id = "ord-2", + cents = 9900, + meta = protocol.meta("trace-7", {source = "worker"}), +} + +local cancel_two: protocol.CancelAction = { + kind = "cancel", + order_id = "ord-2", + reason = "payment_declined", + meta = protocol.meta("trace-8", {source = "worker"}), +} + +local tick: protocol.TickAction = { + kind = "tick", + at = now, +} + +local actions: {protocol.Action} = { + begin_one, + reserve_one, + charge_one, + commit_one, + begin_two, + reserve_two, + charge_two, + cancel_two, + tick, +} + +local store = app:new_store("saga-1", now) +local summary_result = app:run(store, actions, now) +if not summary_result.ok then + local message: string = summary_result.error.message + local retryable: boolean = summary_result.error.retryable +else + local summary = summary_result.value + local runtime_id: string = summary.id + local total_steps: number = summary.total_steps + local saga_count: number = summary.saga_count + local committed_count: number = summary.committed_count + local rolled_back_count: number = summary.rolled_back_count + local elapsed_seconds: number = summary.elapsed_seconds + local last_status: string? = summary.last_status +end + +local summary_label = result.map(summary_result, function(summary: protocol.RunSummary): string + return summary.id .. ":" .. tostring(summary.saga_count) +end) + +if summary_label.ok then + local label: string = summary_label.value +end + +local summary_id = result.and_then(summary_result, function(summary: protocol.RunSummary): StringResult + if summary.rolled_back_count == 0 then + return { + ok = false, + error = { + code = "invalid", + message = "expected rollback", + retryable = false, + }, + } + end + return { + ok = true, + value = summary.id, + } +end) + +if summary_id.ok then + local stable_id: string = summary_id.value +end + +local saga_one = store:lookup_saga("ord-1") +if saga_one then + local status: string = saga_one.status + local version: integer = saga_one.version + local reservation = saga_one.reservation_token + if reservation then + local stable_reservation: string = reservation + end + local payment = saga_one.payment_id + if payment then + local stable_payment: string = payment + end + local source = saga_one.source + if source then + local stable_source: string = source + end + local updated = saga_one.updated_at or now + local seconds: number = now:sub(updated):seconds() +end + +local saga_two = store:lookup_saga("ord-2") +if saga_two then + local failed_status: string = saga_two.status + local error_msg = saga_two.last_error + if error_msg then + local stable_error: string = error_msg + end + for _, comp in ipairs(saga_two.compensations) do + if comp.kind == "release" then + local token: string = comp.reservation_token + else + local payment_id: string = comp.payment_id + end + end +end + +local view_one = store:lookup_view("ord-1") +if view_one then + local committed = view_one.committed_at + if committed then + local unix_seconds: integer = committed:unix() + end +end + +local view_two = store:lookup_view("ord-2") +if view_two then + local rolled_back = view_two.rolled_back_at + if rolled_back then + local unix_seconds: integer = rolled_back:unix() + end + local source = view_two.source + if source then + local stable_source: string = source + end +end + +for key, note in pairs(observed_actions) do + local stable_key: string = key + local stable_note: string = note +end + +for _, note in ipairs(observed_compensations) do + local stable_note: string = note +end + +for _, note in ipairs(observed_audits) do + local audit_note: string = note +end + +if last_runtime_id ~= nil then + local stable_runtime_id: string = last_runtime_id +end + +local source = helpers.source_tag(begin_one) +if source then + local stable_source: string = source +end + +local seen_at = store.state.last_action_at or store.state.started_at +local elapsed = now:sub(seen_at) +local seconds: number = elapsed:seconds() diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/manifest.json b/testdata/fixtures/realworld/transactional-saga-orchestrator/manifest.json new file mode 100644 index 00000000..8f24b676 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/manifest.json @@ -0,0 +1,17 @@ +{ + "description": "Application-shaped saga orchestrator mixing staged validation, fluent action handlers, compensation records, rollback emission, metatable-backed stores, discriminated actions, dynamic registries, optional saga/view fields, callbacks, and time-based summaries.", + "files": [ + "result.lua", + "protocol.lua", + "helpers.lua", + "saga_store.lua", + "validator_builder.lua", + "action_builder.lua", + "orchestrator.lua", + "main.lua" + ], + "packages": ["time"], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/orchestrator.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/orchestrator.lua new file mode 100644 index 00000000..0d78c36b --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/orchestrator.lua @@ -0,0 +1,169 @@ +local time = require("time") +local protocol = require("protocol") +local helpers = require("helpers") +local saga_store = require("saga_store") + +type Orchestrator = { + validators: {protocol.ActionValidator}, + handlers: {[string]: protocol.ActionHandler}, + hooks: {protocol.StepHook}, + use_validator: (self: Orchestrator, validator: protocol.ActionValidator) -> Orchestrator, + register_handler: (self: Orchestrator, kind: string, handler: protocol.ActionHandler) -> Orchestrator, + on_step: (self: Orchestrator, hook: protocol.StepHook) -> Orchestrator, + new_store: (self: Orchestrator, id: string, now: time.Time) -> saga_store.SagaStore, + emit: (self: Orchestrator, store: saga_store.SagaStore, step: protocol.SagaStep, at: time.Time) -> (), + execute: (self: Orchestrator, store: saga_store.SagaStore, action: protocol.Action, at: time.Time) -> protocol.ExecuteResult, + run: (self: Orchestrator, store: saga_store.SagaStore, actions: {protocol.Action}, now: time.Time) -> protocol.RunResult, +} + +type Runtime = Orchestrator + +local Runtime = {} +Runtime.__index = Runtime + +local M = {} +M.Orchestrator = Orchestrator + +function M.new(): Orchestrator + local self: Runtime = { + validators = {}, + handlers = {}, + hooks = {}, + use_validator = Runtime.use_validator, + register_handler = Runtime.register_handler, + on_step = Runtime.on_step, + new_store = Runtime.new_store, + emit = Runtime.emit, + execute = Runtime.execute, + run = Runtime.run, + } + setmetatable(self, Runtime) + return self +end + +function Runtime:use_validator(validator: protocol.ActionValidator): Runtime + table.insert(self.validators, validator) + return self +end + +function Runtime:register_handler(kind: string, handler: protocol.ActionHandler): Runtime + self.handlers[kind] = handler + return self +end + +function Runtime:on_step(hook: protocol.StepHook): Runtime + table.insert(self.hooks, hook) + return self +end + +function Runtime:new_store(id: string, now: time.Time): saga_store.SagaStore + return saga_store.new(id, now) +end + +function Runtime:emit(store: saga_store.SagaStore, step: protocol.SagaStep, at: time.Time) + store:push_step(step, at) + for _, hook in ipairs(self.hooks) do + hook(step, store.state) + end +end + +function Runtime:execute( + store: saga_store.SagaStore, + action: protocol.Action, + at: time.Time +): protocol.ExecuteResult + if action.kind == "tick" then + local audit_step: protocol.SagaStep = { + kind = "audit", + note = "tick", + at = action.at, + } + self:emit(store, audit_step, at) + return {ok = true, value = nil} + end + + local current: protocol.Action = action + for _, validator in ipairs(self.validators) do + local validation_result: protocol.ValidationResult = validator(store.state, current) + if not validation_result.ok then + return {ok = false, error = validation_result.error} + end + current = validation_result.value + end + + local handler = self.handlers[current.kind] + if not handler then + return { + ok = false, + error = { + code = "not_found", + message = "missing handler: " .. current.kind, + retryable = false, + }, + } + end + + local handler_value: protocol.ActionHandler = handler + local handler_result: protocol.HandlerResult = handler_value(store.state, current, at) + if not handler_result.ok then + return {ok = false, error = handler_result.error} + end + + local note = handler_result.value or helpers.action_label(current) + local action_step: protocol.SagaStep = { + kind = "action", + name = current.kind, + note = note, + order_id = current.order_id, + } + self:emit(store, action_step, at) + + if action.kind == "cancel" then + local saga = store:lookup_saga(action.order_id) + if saga then + for i = #saga.compensations, 1, -1 do + local comp = saga.compensations[i] + local comp_note: string + local comp_name: string + if comp.kind == "release" then + comp_name = "release" + comp_note = "release:" .. comp.reservation_token + else + comp_name = "refund" + comp_note = "refund:" .. comp.payment_id + end + local comp_step: protocol.SagaStep = { + kind = "compensation", + name = comp_name, + note = comp_note, + order_id = saga.order_id, + } + self:emit(store, comp_step, at) + end + end + end + + return {ok = true, value = current.kind} +end + +function Runtime:run( + store: saga_store.SagaStore, + actions: {protocol.Action}, + now: time.Time +): protocol.RunResult + local last_status: string? = nil + + for _, action in ipairs(actions) do + local execute_result: protocol.ExecuteResult = self:execute(store, action, now) + if not execute_result.ok then + return {ok = false, error = execute_result.error} + end + if execute_result.value then + last_status = execute_result.value + end + end + + return {ok = true, value = store:summarize(now, last_status)} +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/protocol.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/protocol.lua new file mode 100644 index 00000000..3b1a54a4 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/protocol.lua @@ -0,0 +1,173 @@ +local time = require("time") +local result = require("result") + +type AppError = result.AppError + +type ActionMeta = { + trace_id: string, + tags: {[string]: string}?, +} + +type BeginAction = { + kind: "begin", + order_id: string, + customer_id: string, + meta: ActionMeta, +} + +type ReserveAction = { + kind: "reserve", + order_id: string, + sku: string, + qty: integer, + meta: ActionMeta, +} + +type ChargeAction = { + kind: "charge", + order_id: string, + cents: integer, + meta: ActionMeta, +} + +type CommitAction = { + kind: "commit", + order_id: string, + meta: ActionMeta, +} + +type CancelAction = { + kind: "cancel", + order_id: string, + reason: string, + meta: ActionMeta, +} + +type TickAction = { + kind: "tick", + at: time.Time, +} + +type Action = BeginAction | ReserveAction | ChargeAction | CommitAction | CancelAction | TickAction + +type ReleaseInventory = { + kind: "release", + reservation_token: string, +} + +type RefundPayment = { + kind: "refund", + payment_id: string, +} + +type Compensation = ReleaseInventory | RefundPayment + +type SagaAggregate = { + order_id: string, + customer_id: string, + version: integer, + status: "open" | "reserved" | "charged" | "committed" | "rolled_back", + reservation_token: string?, + payment_id: string?, + last_error: string?, + source: string?, + updated_at: time.Time?, + compensations: {Compensation}, +} + +type SagaView = { + order_id: string, + status: "open" | "reserved" | "charged" | "committed" | "rolled_back", + version: integer, + reservation_token: string?, + payment_id: string?, + source: string?, + committed_at: time.Time?, + rolled_back_at: time.Time?, + last_error: string?, +} + +type ActionStep = { + kind: "action", + name: string, + note: string, + order_id: string?, +} + +type CompensationStep = { + kind: "compensation", + name: string, + note: string, + order_id: string?, +} + +type AuditStep = { + kind: "audit", + note: string, + at: time.Time, +} + +type SagaStep = ActionStep | CompensationStep | AuditStep + +type StoreState = { + id: string, + started_at: time.Time, + last_action_at: time.Time?, + steps: {SagaStep}, + sagas: {[string]: SagaAggregate}, + views: {[string]: SagaView}, + counters: {[string]: integer}, + flags: {[string]: boolean}, +} + +type RunSummary = { + id: string, + total_steps: number, + saga_count: number, + committed_count: number, + rolled_back_count: number, + last_status: string?, + elapsed_seconds: number, +} + +type ValidationResult = {ok: true, value: Action} | {ok: false, error: AppError} +type HandlerResult = {ok: true, value: string?} | {ok: false, error: AppError} +type ExecuteResult = {ok: true, value: string?} | {ok: false, error: AppError} +type RunResult = {ok: true, value: RunSummary} | {ok: false, error: AppError} + +type ActionValidator = (StoreState, Action) -> ValidationResult +type ActionHandler = (StoreState, Action, time.Time) -> HandlerResult +type StepHook = (SagaStep, StoreState) -> () + +local M = {} +M.AppError = AppError +M.ActionMeta = ActionMeta +M.BeginAction = BeginAction +M.ReserveAction = ReserveAction +M.ChargeAction = ChargeAction +M.CommitAction = CommitAction +M.CancelAction = CancelAction +M.TickAction = TickAction +M.Action = Action +M.Compensation = Compensation +M.SagaAggregate = SagaAggregate +M.SagaView = SagaView +M.SagaStep = SagaStep +M.StoreState = StoreState +M.RunSummary = RunSummary +M.ValidationResult = ValidationResult +M.HandlerResult = HandlerResult +M.ExecuteResult = ExecuteResult +M.RunResult = RunResult +M.ActionValidator = ActionValidator +M.ActionHandler = ActionHandler +M.StepHook = StepHook + +function M.meta(trace_id: string, tags: {[string]: string}?): ActionMeta + return { + trace_id = trace_id, + tags = tags, + } +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/result.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/result.lua new file mode 100644 index 00000000..9419acf4 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/result.lua @@ -0,0 +1,45 @@ +type ErrorCode = "not_found" | "invalid" | "busy" | "conflict" + +type AppError = { + code: ErrorCode, + message: string, + retryable: boolean, +} + +type Result = {ok: true, value: T} | {ok: false, error: AppError} + +local M = {} +M.ErrorCode = ErrorCode +M.AppError = AppError +M.Result = Result + +function M.ok(value: T): Result + return {ok = true, value = value} +end + +function M.err(code: ErrorCode, message: string, retryable: boolean?): Result + return { + ok = false, + error = { + code = code, + message = message, + retryable = retryable or false, + }, + } +end + +function M.map(r: Result, fn: (T) -> U): Result + if r.ok then + return M.ok(fn(r.value)) + end + return {ok = false, error = r.error} +end + +function M.and_then(r: Result, fn: (T) -> Result): Result + if r.ok then + return fn(r.value) + end + return {ok = false, error = r.error} +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/saga_store.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/saga_store.lua new file mode 100644 index 00000000..7b603d35 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/saga_store.lua @@ -0,0 +1,97 @@ +local time = require("time") +local protocol = require("protocol") + +type SagaStore = { + state: protocol.StoreState, + touch: (self: SagaStore, at: time.Time) -> SagaStore, + push_step: (self: SagaStore, step: protocol.SagaStep, at: time.Time) -> SagaStore, + lookup_saga: (self: SagaStore, order_id: string) -> protocol.SagaAggregate?, + lookup_view: (self: SagaStore, order_id: string) -> protocol.SagaView?, + increment: (self: SagaStore, name: string) -> integer, + summarize: (self: SagaStore, now: time.Time, last_status: string?) -> protocol.RunSummary, +} + +type Store = SagaStore + +local Store = {} +Store.__index = Store + +local M = {} +M.SagaStore = SagaStore + +function M.new(id: string, now: time.Time): SagaStore + local self: Store = { + state = { + id = id, + started_at = now, + last_action_at = nil, + steps = {}, + sagas = {}, + views = {}, + counters = {}, + flags = {}, + }, + touch = Store.touch, + push_step = Store.push_step, + lookup_saga = Store.lookup_saga, + lookup_view = Store.lookup_view, + increment = Store.increment, + summarize = Store.summarize, + } + setmetatable(self, Store) + return self +end + +function Store:touch(at: time.Time): Store + self.state.last_action_at = at + return self +end + +function Store:push_step(step: protocol.SagaStep, at: time.Time): Store + table.insert(self.state.steps, step) + return self:touch(at) +end + +function Store:lookup_saga(order_id: string): protocol.SagaAggregate? + return self.state.sagas[order_id] +end + +function Store:lookup_view(order_id: string): protocol.SagaView? + return self.state.views[order_id] +end + +function Store:increment(name: string): integer + local current = self.state.counters[name] or 0 + local next_value = current + 1 + self.state.counters[name] = next_value + return next_value +end + +function Store:summarize(now: time.Time, last_status: string?): protocol.RunSummary + local saga_count = 0 + local committed_count = 0 + local rolled_back_count = 0 + for _, view in pairs(self.state.views) do + saga_count = saga_count + 1 + if view.status == "committed" then + committed_count = committed_count + 1 + elseif view.status == "rolled_back" then + rolled_back_count = rolled_back_count + 1 + end + end + + local seen_at = self.state.last_action_at or self.state.started_at + local elapsed = now:sub(seen_at) + + return { + id = self.state.id, + total_steps = #self.state.steps, + saga_count = saga_count, + committed_count = committed_count, + rolled_back_count = rolled_back_count, + last_status = last_status, + elapsed_seconds = elapsed:seconds(), + } +end + +return M diff --git a/testdata/fixtures/realworld/transactional-saga-orchestrator/validator_builder.lua b/testdata/fixtures/realworld/transactional-saga-orchestrator/validator_builder.lua new file mode 100644 index 00000000..594f4ad2 --- /dev/null +++ b/testdata/fixtures/realworld/transactional-saga-orchestrator/validator_builder.lua @@ -0,0 +1,94 @@ +local protocol = require("protocol") + +type ValidatorBuilder = { + name: string, + required_tag: string?, + flag_name: string?, + named: (self: ValidatorBuilder, name: string) -> ValidatorBuilder, + require_tag: (self: ValidatorBuilder, tag_name: string) -> ValidatorBuilder, + remember_flag: (self: ValidatorBuilder, flag_name: string) -> ValidatorBuilder, + build: (self: ValidatorBuilder) -> protocol.ActionValidator, +} + +type Builder = ValidatorBuilder + +local Builder = {} +Builder.__index = Builder + +local M = {} +M.ValidatorBuilder = ValidatorBuilder + +function M.new(): ValidatorBuilder + local self: Builder = { + name = "validator", + required_tag = nil, + flag_name = nil, + named = Builder.named, + require_tag = Builder.require_tag, + remember_flag = Builder.remember_flag, + build = Builder.build, + } + setmetatable(self, Builder) + return self +end + +function Builder:named(name: string): Builder + self.name = name + return self +end + +function Builder:require_tag(tag_name: string): Builder + self.required_tag = tag_name + return self +end + +function Builder:remember_flag(flag_name: string): Builder + self.flag_name = flag_name + return self +end + +function Builder:build(): protocol.ActionValidator + local name = self.name + local required_tag = self.required_tag + local flag_name = self.flag_name + + return function(state: protocol.StoreState, action: protocol.Action): protocol.ValidationResult + if action.kind == "tick" then + return {ok = true, value = action} + end + + if required_tag then + local tags = action.meta.tags + if not tags then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tags", + retryable = false, + }, + } + end + + local value = tags[required_tag] + if not value then + return { + ok = false, + error = { + code = "invalid", + message = name .. " missing tag: " .. required_tag, + retryable = false, + }, + } + end + end + + if flag_name then + state.flags[flag_name] = true + end + + return {ok = true, value = action} + end +end + +return M diff --git a/testdata/fixtures/realworld/typed-callback-chain/main.lua b/testdata/fixtures/realworld/typed-callback-chain/main.lua index 3e704f24..8843630b 100644 --- a/testdata/fixtures/realworld/typed-callback-chain/main.lua +++ b/testdata/fixtures/realworld/typed-callback-chain/main.lua @@ -2,10 +2,15 @@ local types = require("types") local stream = require("stream") local collected_chunks: {string} = {} -local collected_tools: {ToolCall} = {} -local final_result: StreamResult? = nil +local collected_tools: {types.ToolCall} = {} +local final_result: types.StreamResult? = nil -local events = { +type ContentEvent = {type: "content", data: string} +type ToolCallEvent = {type: "tool_call", id: string, name: string, arguments: {[string]: any}} +type DoneEvent = {type: "done", reason: string?, usage: types.Usage?} +type StreamEvent = ContentEvent | ToolCallEvent | DoneEvent + +local events: {StreamEvent} = { {type = "content", data = "Hello "}, {type = "content", data = "world"}, {type = "tool_call", id = "t1", name = "search", arguments = {query = "test"}}, @@ -16,12 +21,12 @@ local result, err = stream.process(events, { on_content = function(chunk: string) table.insert(collected_chunks, chunk) end, - on_tool_call = function(call: ToolCall) + on_tool_call = function(call: types.ToolCall) table.insert(collected_tools, call) local name: string = call.name local id: string = call.id end, - on_done = function(result: StreamResult) + on_done = function(result: types.StreamResult) final_result = result local content: string = result.content local tokens: number = result.usage.input_tokens diff --git a/testdata/fixtures/realworld/typed-callback-chain/manifest.json b/testdata/fixtures/realworld/typed-callback-chain/manifest.json index be2c16be..197bb78a 100644 --- a/testdata/fixtures/realworld/typed-callback-chain/manifest.json +++ b/testdata/fixtures/realworld/typed-callback-chain/manifest.json @@ -1,4 +1 @@ -{ - "files": ["types.lua", "stream.lua", "main.lua"], - "check": {"errors": 29} -} +{"files": ["types.lua", "stream.lua", "main.lua"], "check": {"errors": 0}} diff --git a/testdata/fixtures/realworld/typed-callback-chain/stream.lua b/testdata/fixtures/realworld/typed-callback-chain/stream.lua index 67e90c53..ae2c12a4 100644 --- a/testdata/fixtures/realworld/typed-callback-chain/stream.lua +++ b/testdata/fixtures/realworld/typed-callback-chain/stream.lua @@ -1,8 +1,14 @@ local types = require("types") +type ContentEvent = {type: "content", data: string} +type ToolCallEvent = {type: "tool_call", id: string, name: string, arguments: {[string]: any}} +type ErrorEvent = {type: "error", message: string, code: string?} +type DoneEvent = {type: "done", reason: string?, usage: types.Usage?} +type StreamEvent = ContentEvent | ToolCallEvent | ErrorEvent | DoneEvent + local M = {} -function M.process(events: {any}, callbacks: StreamCallbacks?): (StreamResult?, string?) +function M.process(events: {StreamEvent}, callbacks: types.StreamCallbacks?): (types.StreamResult?, string?) callbacks = callbacks or {} local on_content = callbacks.on_content @@ -10,37 +16,47 @@ function M.process(events: {any}, callbacks: StreamCallbacks?): (StreamResult?, local on_error = callbacks.on_error local on_done = callbacks.on_done - local result = types.empty_result() + local content = "" + local tool_calls: {types.ToolCall} = {} + local finish_reason: string? = nil + local usage: types.Usage = {input_tokens = 0, output_tokens = 0} for _, event in ipairs(events) do if event.type == "content" then - local chunk: string = event.data - result.content = result.content .. chunk + content = content .. event.data if on_content then - on_content(chunk) + on_content(event.data) end elseif event.type == "tool_call" then - local call: ToolCall = { + local call: types.ToolCall = { id = event.id, name = event.name, - arguments = event.arguments or {}, + arguments = event.arguments, } - table.insert(result.tool_calls, call) + table.insert(tool_calls, call) if on_tool_call then on_tool_call(call) end elseif event.type == "error" then - local err: ErrorInfo = {message = event.message, code = event.code} if on_error then - on_error(err) + on_error({message = event.message, code = event.code}) end - return nil, err.message + return nil, event.message elseif event.type == "done" then - result.finish_reason = event.reason - result.usage = event.usage or result.usage + finish_reason = event.reason + if event.usage then + usage = event.usage + end end end + local result: types.StreamResult = { + content = content, + tool_calls = tool_calls, + finish_reason = finish_reason, + usage = usage, + } + if on_done then on_done(result) end diff --git a/testdata/fixtures/realworld/typed-callback-chain/types.lua b/testdata/fixtures/realworld/typed-callback-chain/types.lua index 071890af..46f46d23 100644 --- a/testdata/fixtures/realworld/typed-callback-chain/types.lua +++ b/testdata/fixtures/realworld/typed-callback-chain/types.lua @@ -35,12 +35,15 @@ M.ErrorInfo = ErrorInfo M.StreamCallbacks = StreamCallbacks function M.empty_result(): StreamResult - return { + local empty_tools: {ToolCall} = {} + local empty_usage: Usage = {input_tokens = 0, output_tokens = 0} + local result: StreamResult = { content = "", - tool_calls = {}, + tool_calls = empty_tools, finish_reason = nil, - usage = {input_tokens = 0, output_tokens = 0}, + usage = empty_usage, } + return result end return M diff --git a/testdata/fixtures/realworld/typed-enum-constants/handler.lua b/testdata/fixtures/realworld/typed-enum-constants/handler.lua index 38e06278..8fd0b523 100644 --- a/testdata/fixtures/realworld/typed-enum-constants/handler.lua +++ b/testdata/fixtures/realworld/typed-enum-constants/handler.lua @@ -1,15 +1,15 @@ local status = require("status") type Route = { - method: HttpMethod, + method: status.HttpMethod, path: string, - handler: (req: Request) -> Response, + handler: (req: status.Request) -> status.Response, } type Router = { _routes: {Route}, - add: (self: Router, method: HttpMethod, path: string, handler: (req: Request) -> Response) -> Router, - handle: (self: Router, req: Request) -> Response, + add: (self: Router, method: status.HttpMethod, path: string, handler: (req: status.Request) -> status.Response) -> Router, + handle: (self: Router, req: status.Request) -> status.Response, } local M = {} @@ -17,11 +17,11 @@ local M = {} function M.new(): Router local router: Router = { _routes = {}, - add = function(self: Router, method: HttpMethod, path: string, handler: (req: Request) -> Response): Router + add = function(self: Router, method: status.HttpMethod, path: string, handler: (req: status.Request) -> status.Response): Router table.insert(self._routes, {method = method, path = path, handler = handler}) return self end, - handle = function(self: Router, req: Request): Response + handle = function(self: Router, req: status.Request): status.Response for _, route in ipairs(self._routes) do if route.method == req.method and route.path == req.path then return route.handler(req) diff --git a/testdata/fixtures/realworld/typed-enum-constants/main.lua b/testdata/fixtures/realworld/typed-enum-constants/main.lua index 88526721..155cc199 100644 --- a/testdata/fixtures/realworld/typed-enum-constants/main.lua +++ b/testdata/fixtures/realworld/typed-enum-constants/main.lua @@ -2,16 +2,16 @@ local status = require("status") local handler = require("handler") local router = handler.new() - :add("GET", "/users", function(req: Request): Response + :add("GET", "/users", function(req: status.Request): status.Response return status.ok({users = {"Alice", "Bob"}}) end) - :add("POST", "/users", function(req: Request): Response + :add("POST", "/users", function(req: status.Request): status.Response if not req.body then return status.error(400, "Missing body") end return status.created({id = "new-user"}) end) - :add("DELETE", "/users", function(req: Request): Response + :add("DELETE", "/users", function(req: status.Request): status.Response return status.ok() end) diff --git a/testdata/fixtures/realworld/typed-enum-constants/manifest.json b/testdata/fixtures/realworld/typed-enum-constants/manifest.json index 56f0a98a..9bc495be 100644 --- a/testdata/fixtures/realworld/typed-enum-constants/manifest.json +++ b/testdata/fixtures/realworld/typed-enum-constants/manifest.json @@ -1 +1 @@ -{"files": ["status.lua", "handler.lua", "main.lua"], "check": {"errors": 6}} +{"files": ["status.lua", "handler.lua", "main.lua"], "check": {"errors": 0}} diff --git a/testdata/fixtures/regression/deadlock-compiler-lua/main.lua b/testdata/fixtures/regression/deadlock-compiler-lua/main.lua new file mode 100644 index 00000000..bccab3a4 --- /dev/null +++ b/testdata/fixtures/regression/deadlock-compiler-lua/main.lua @@ -0,0 +1,1327 @@ +local json = require("json") +local uuid = require("uuid") +local expr = require("expr") +local consts = require("df_consts") + +local compiler = {} + +compiler.OP_TYPES = { + WITH_INPUT = "with_input", + WITH_DATA = "with_data", + FUNC = "func", + AGENT = "agent", + CYCLE = "cycle", + PARALLEL = "parallel", + STATE = "state", + USE = "use", + AS = "as", + TO = "to", + ERROR_TO = "error_to", + WHEN = "when" +} + +local FlowGraph = {} +local flow_graph_mt = { __index = FlowGraph } + +function FlowGraph.new() + return setmetatable({ + operations = table.create(16, 0), + nodes = table.create(0, 16), + node_order = table.create(16, 0), + edges = table.create(0, 16), + references = table.create(0, 8), + input_data = nil, + input_name = nil, + input_routes = table.create(4, 0), + static_data_sources = table.create(4, 0), + last_node_id = nil, + last_static_id = nil, + last_node_name = nil, + last_route_from_static = false, + pending_routes = table.create(8, 0), + has_explicit_routing = false, + session_parent_id = nil, + forced_success_nodes = table.create(0, 4), + forced_failure_nodes = table.create(0, 4), + auto_chained = table.create(0, 16) + }, flow_graph_mt) +end + +function FlowGraph:add_operation(op_type, config) + table.insert(self.operations, { + type = op_type, + config = config or {} + }) + return self, nil +end + +function FlowGraph:create_node(node_type, config, metadata) + local node_id = uuid.v7() + + self.nodes[node_id] = { + node_id = node_id, + node_type = node_type, + config = config or {}, + metadata = metadata or {}, + status = consts.STATUS.PENDING + } + + table.insert(self.node_order, node_id) + + self.edges[node_id] = { + targets = table.create(4, 0), + error_targets = table.create(2, 0) + } + + self.last_node_id = node_id + self.last_static_id = nil + return node_id, nil +end + +function FlowGraph:create_static_data(data) + local static_id = uuid.v7() + + table.insert(self.static_data_sources, { + static_id = static_id, + data = data, + routes = table.create(4, 0) + }) + + self.last_static_id = static_id + self.last_node_id = nil + return static_id, nil +end + +function FlowGraph:create_template_nodes(template, parent_node_id) + if not template or not template.operations then + return table.create(0, 0) + end + + local template_node_ids = table.create(#template.operations, 0) + local last_template_node_id = nil + + for _, op in ipairs(template.operations) do + if op.type == compiler.OP_TYPES.FUNC then + local template_node_id = uuid.v7() + + local config = { + func_id = op.config.func_id, + args = op.config.args, + inputs = op.config.inputs, + context = op.config.context, + input_transform = op.config.input_transform + } + + if last_template_node_id then + local prev_node = (self.nodes[last_template_node_id] :: any) + if not prev_node.config.data_targets then + prev_node.config.data_targets = table.create(1, 0) + end + table.insert(prev_node.config.data_targets, { + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = template_node_id, + discriminator = "default", + metadata = { + source_node_id = last_template_node_id + } + }) + end + + local metadata = op.config.metadata or {} + + self.nodes[template_node_id] = { + node_id = template_node_id, + node_type = "userspace.dataflow.node.func:node", + config = config, + metadata = metadata, + status = consts.STATUS.TEMPLATE, + parent_node_id = parent_node_id + } + + self.edges[template_node_id] = { + targets = table.create(2, 0), + error_targets = table.create(1, 0) + } + + table.insert(self.node_order, template_node_id) + table.insert(template_node_ids, template_node_id) + last_template_node_id = template_node_id + elseif op.type == compiler.OP_TYPES.AGENT then + local template_node_id = uuid.v7() + + local config = { + agent = op.config.agent_id, + model = op.config.model, + arena = op.config.arena, + inputs = op.config.inputs, + show_tool_calls = op.config.show_tool_calls, + input_transform = op.config.input_transform + } + + if last_template_node_id then + local prev_node = (self.nodes[last_template_node_id] :: any) + if not prev_node.config.data_targets then + prev_node.config.data_targets = table.create(1, 0) + end + table.insert(prev_node.config.data_targets, { + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = template_node_id, + discriminator = "default", + metadata = { + source_node_id = last_template_node_id + } + }) + end + + local metadata = op.config.metadata or {} + + self.nodes[template_node_id] = { + node_id = template_node_id, + node_type = "userspace.dataflow.node.agent:node", + config = config, + metadata = metadata, + status = consts.STATUS.TEMPLATE, + parent_node_id = parent_node_id + } + + self.edges[template_node_id] = { + targets = table.create(2, 0), + error_targets = table.create(1, 0) + } + + table.insert(self.node_order, template_node_id) + table.insert(template_node_ids, template_node_id) + last_template_node_id = template_node_id + elseif op.type == compiler.OP_TYPES.CYCLE then + local template_node_id = uuid.v7() + + local config = { + func_id = op.config.func_id, + args = op.config.args, + continue_condition = op.config.continue_condition, + max_iterations = op.config.max_iterations, + initial_state = op.config.initial_state, + inputs = op.config.inputs, + context = op.config.context, + input_transform = op.config.input_transform + } + + if last_template_node_id then + local prev_node = (self.nodes[last_template_node_id] :: any) + if not prev_node.config.data_targets then + prev_node.config.data_targets = table.create(1, 0) + end + table.insert(prev_node.config.data_targets, { + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = template_node_id, + source_node_id = last_template_node_id, + discriminator = "default" + }) + end + + local metadata = op.config.metadata or {} + + self.nodes[template_node_id] = { + node_id = template_node_id, + node_type = "userspace.dataflow.node.cycle:cycle", + config = config, + metadata = metadata, + status = consts.STATUS.TEMPLATE, + parent_node_id = parent_node_id + } + + self.edges[template_node_id] = { + targets = table.create(2, 0), + error_targets = table.create(1, 0) + } + + table.insert(self.node_order, template_node_id) + table.insert(template_node_ids, template_node_id) + + if op.config.template then + local cycle_template_nodes = self:create_template_nodes(op.config.template, template_node_id) + for _, child_id in ipairs(cycle_template_nodes) do + table.insert(template_node_ids, child_id) + end + end + + last_template_node_id = template_node_id + end + end + + if last_template_node_id then + local last_node = (self.nodes[last_template_node_id] :: any) + if not last_node.config.data_targets then + last_node.config.data_targets = table.create(1, 0) + end + table.insert(last_node.config.data_targets, { + data_type = consts.DATA_TYPE.NODE_OUTPUT, + discriminator = "result", + metadata = { + source_node_id = last_template_node_id + } + }) + + if not last_node.config.error_targets then + last_node.config.error_targets = table.create(1, 0) + end + table.insert(last_node.config.error_targets, { + data_type = consts.DATA_TYPE.NODE_OUTPUT, + discriminator = "error", + metadata = { + source_node_id = last_template_node_id + } + }) + end + + return template_node_ids +end + +function FlowGraph:add_reference(name, node_id) + if self.references[name] then + return nil, "Duplicate node name: " .. name + end + self.references[name] = node_id + self.last_node_name = name + return true, nil +end + +function FlowGraph:resolve_reference(name) + local node_id = self.references[name] + if not node_id then + return nil, "Undefined node reference: " .. name + end + return node_id, nil +end + +function FlowGraph:compute_auto_chain() + for i = 1, #self.node_order - 1 do + local current_node_id = self.node_order[i] + local next_node_id = self.node_order[i + 1] + local current_node = (self.nodes[current_node_id] :: any) + local next_node = (self.nodes[next_node_id] :: any) + + if not current_node.parent_node_id and not next_node.parent_node_id then + local current_edges = (self.edges[current_node_id] :: any) + + local has_any_targets = false + for _, edge in ipairs(current_edges.targets) do + if edge.target_node_id or edge.is_workflow_terminal then + has_any_targets = true + break + end + end + for _, edge in ipairs(current_edges.error_targets) do + if edge.target_node_id or edge.is_workflow_terminal then + has_any_targets = true + break + end + end + + if not has_any_targets then + self.auto_chained[current_node_id] = next_node_id + table.insert(current_edges.targets, { + target_node_id = next_node_id, + discriminator = "default", + is_auto_chain = true, + metadata = { + source_node_id = current_node_id + } + }) + end + end + end +end + +function FlowGraph:detect_cycles() + local visited = table.create(0, 32) + local rec_stack = table.create(0, 32) + + local function dfs(node_id, path) + if rec_stack[node_id] then + local cycle_start = nil + for i, id in ipairs(path) do + if id == node_id then + cycle_start = i + break + end + end + + if cycle_start then + local cycle = table.create(#path - cycle_start + 2, 0) + for i = cycle_start, #path do + table.insert(cycle, path[i]) + end + table.insert(cycle, node_id) + return true, "Cycle detected: " .. table.concat(cycle, " -> ") + end + return true, "Cycle detected at node: " .. node_id + end + + if visited[node_id] then + return false, nil + end + + visited[node_id] = true + rec_stack[node_id] = true + table.insert(path, node_id) + + local edges = (self.edges[node_id] :: any) + if edges then + for _, edge in ipairs(edges.targets) do + if edge.target_node_id then + local has_cycle, cycle_desc = dfs(edge.target_node_id, path) + if has_cycle then + return true, cycle_desc + end + end + end + for _, edge in ipairs(edges.error_targets) do + if edge.target_node_id then + local has_cycle, cycle_desc = dfs(edge.target_node_id, path) + if has_cycle then + return true, cycle_desc + end + end + end + end + + rec_stack[node_id] = false + table.remove(path) + return false, nil + end + + for node_id, _ in pairs(self.nodes) do + if not visited[node_id] then + local has_cycle, cycle_desc = dfs(node_id, table.create(16, 0)) + if has_cycle then + return true, cycle_desc + end + end + end + + return false, nil +end + +function compiler.build_graph(operations, session_context) + if not operations or #operations == 0 then + return nil, "No operations provided" + end + + local graph = FlowGraph.new() :: any + + if session_context and session_context.node_id then + graph.session_parent_id = session_context.node_id + end + + for _, op in ipairs(operations) do + if op.type == compiler.OP_TYPES.WITH_INPUT then + graph.input_data = op.config.data + elseif op.type == compiler.OP_TYPES.WITH_DATA then + local static_id, err = graph:create_static_data(op.config.data) + if err then + return nil, err + end + elseif op.type == compiler.OP_TYPES.FUNC then + local config = { + func_id = op.config.func_id, + args = op.config.args, + inputs = op.config.inputs, + context = op.config.context, + input_transform = op.config.input_transform + } + + local node_id, err = graph:create_node("userspace.dataflow.node.func:node", config, op.config.metadata) + if err then + return nil, err + end + elseif op.type == compiler.OP_TYPES.AGENT then + local config = { + agent = op.config.agent_id, + model = op.config.model, + arena = op.config.arena, + inputs = op.config.inputs, + show_tool_calls = op.config.show_tool_calls, + input_transform = op.config.input_transform + } + + local node_id, err = graph:create_node("userspace.dataflow.node.agent:node", config, op.config.metadata) + if err then + return nil, err + end + elseif op.type == compiler.OP_TYPES.CYCLE then + local config = { + func_id = op.config.func_id, + args = op.config.args, + continue_condition = op.config.continue_condition, + max_iterations = op.config.max_iterations, + initial_state = op.config.initial_state, + inputs = op.config.inputs, + context = op.config.context, + input_transform = op.config.input_transform + } + + local node_id, err = graph:create_node("userspace.dataflow.node.cycle:cycle", config, op.config.metadata) + if err then + return nil, err + end + + if op.config.template then + graph:create_template_nodes(op.config.template, node_id) + end + elseif op.type == compiler.OP_TYPES.PARALLEL then + local config = { + source_array_key = op.config.source_array_key, + iteration_input_key = op.config.iteration_input_key, + batch_size = op.config.batch_size, + on_error = op.config.on_error, + filter = op.config.filter, + unwrap = op.config.unwrap, + passthrough_keys = op.config.passthrough_keys, + inputs = op.config.inputs, + input_transform = op.config.input_transform + } + + local node_id, err = graph:create_node("userspace.dataflow.node.parallel:parallel", config, + op.config.metadata) + if err then + return nil, err + end + + if op.config.template then + graph:create_template_nodes(op.config.template, node_id) + end + elseif op.type == compiler.OP_TYPES.STATE then + local config = { + inputs = op.config.inputs, + input_transform = op.config.input_transform, + output_mode = op.config.output_mode, + ignored_keys = op.config.ignored_keys + } + + local node_id, err = graph:create_node("userspace.dataflow.node.state:state", config, op.config.metadata) + if err then + return nil, err + end + elseif op.type == compiler.OP_TYPES.USE then + if op.config.operations then + for _, template_op in ipairs(op.config.operations) do + table.insert(graph.operations, template_op) + end + end + elseif op.type == compiler.OP_TYPES.AS then + if graph.last_static_id then + graph:add_reference(op.config.name, graph.last_static_id) + elseif graph.input_data and not graph.input_name and not graph.last_node_id then + graph.input_name = op.config.name + graph:add_reference(op.config.name, "INPUT") + elseif graph.last_node_id then + local success, err = graph:add_reference(op.config.name, graph.last_node_id) + if err then + return nil, err + end + else + return nil, "Cannot name node: no previous node, input, or static data to name: " .. op.config.name + end + elseif op.type == compiler.OP_TYPES.TO then + if op.config.target == "@success" or op.config.target == "@fail" or op.config.target == "@end" then + if not graph.last_node_id then + return nil, "Cannot route to terminal: no source node" + end + + local is_success = op.config.target == "@success" or (op.config.target == "@end") + + if is_success then + graph.forced_success_nodes[graph.last_node_id] = true + else + graph.forced_failure_nodes[graph.last_node_id] = true + end + + graph.has_explicit_routing = true + graph.last_route_from_static = false + table.insert(graph.pending_routes, { + from_node_id = graph.last_node_id, + is_workflow_terminal = true, + is_success = is_success, + transform = op.config.transform, + condition = nil, + is_error = false + }) + elseif graph.input_data and not graph.last_node_id and not graph.last_static_id then + table.insert(graph.input_routes, { + target_name = op.config.target, + input_key = op.config.input_key or graph.input_name or "default", + transform = op.config.transform + }) + graph.has_explicit_routing = true + graph.last_route_from_static = false + elseif graph.last_static_id then + for _, static_source in ipairs(graph.static_data_sources) do + if (static_source :: any).static_id == graph.last_static_id then + table.insert((static_source :: any).routes, { + target_name = op.config.target, + input_key = op.config.input_key or graph.last_node_name or "default", + transform = op.config.transform + }) + break + end + end + graph.has_explicit_routing = true + graph.last_route_from_static = true + elseif graph.last_node_id then + graph.has_explicit_routing = true + graph.last_route_from_static = false + local discriminator = op.config.input_key + if not discriminator and graph.last_node_name then + discriminator = graph.last_node_name + end + table.insert(graph.pending_routes, { + from_node_id = graph.last_node_id, + target_name = op.config.target, + input_key = discriminator, + transform = op.config.transform, + is_error = false, + condition = nil + }) + else + return nil, "Cannot add route: no source node, input, or static data" + end + elseif op.type == compiler.OP_TYPES.ERROR_TO then + if op.config.target == "@success" or op.config.target == "@fail" or op.config.target == "@end" then + if not graph.last_node_id then + return nil, "Cannot route to terminal: no source node" + end + + local is_success = op.config.target == "@success" + + if is_success then + graph.forced_success_nodes[graph.last_node_id] = true + else + graph.forced_failure_nodes[graph.last_node_id] = true + end + + graph.has_explicit_routing = true + graph.last_route_from_static = false + table.insert(graph.pending_routes, { + from_node_id = graph.last_node_id, + is_workflow_terminal = true, + is_success = is_success, + transform = op.config.transform, + is_error = true, + condition = nil + }) + else + if not graph.last_node_id then + return nil, "Cannot add error route: no source node" + end + graph.has_explicit_routing = true + graph.last_route_from_static = false + local discriminator = op.config.input_key + if not discriminator and graph.last_node_name then + discriminator = graph.last_node_name + end + table.insert(graph.pending_routes, { + from_node_id = graph.last_node_id, + target_name = op.config.target, + input_key = discriminator, + transform = op.config.transform, + is_error = true, + condition = nil + }) + end + elseif op.type == compiler.OP_TYPES.WHEN then + if graph.last_route_from_static then + return nil, + "Cannot use :when() with static data routes. Static data is constant and conditions would always evaluate the same way." + end + if #graph.pending_routes == 0 then + return nil, "Cannot add condition: no preceding route from a node" + end + (graph.pending_routes[#graph.pending_routes] :: any).condition = op.config.condition + end + + local success, err = graph:add_operation(op.type, op.config) + if err then + return nil, err + end + end + + for _, route in ipairs(graph.pending_routes) do + local route_entry = (route :: any) + if route_entry.is_workflow_terminal then + local edges = (graph.edges[route_entry.from_node_id] :: any) + local edge_list = route_entry.is_error and edges.error_targets or edges.targets + table.insert(edge_list, { + target_node_id = nil, + is_workflow_terminal = true, + is_success = route_entry.is_success, + transform = route_entry.transform, + condition = route_entry.condition + }) + else + local target_node_id, resolve_err = graph:resolve_reference(route_entry.target_name) + if resolve_err then + return nil, resolve_err + end + local edges = (graph.edges[route_entry.from_node_id] :: any) + local edge_list = route_entry.is_error and edges.error_targets or edges.targets + table.insert(edge_list, { + target_node_id = target_node_id, + transform = route_entry.transform, + condition = route_entry.condition, + input_key = route_entry.input_key + }) + end + end + + graph:compute_auto_chain() + + local has_cycles, cycle_desc = graph:detect_cycles() + if has_cycles then + return nil, "Flow contains cycles: " .. cycle_desc + end + + return graph, nil +end + +function compiler.find_root_nodes(graph) + local nodes_with_incoming = table.create(0, 32) + + for _, edges in pairs((graph :: any).edges) do + local edge_set = (edges :: any) + for _, edge in ipairs(edge_set.targets) do + if edge.target_node_id and not edge.is_auto_chain then + nodes_with_incoming[edge.target_node_id] = true + end + end + for _, edge in ipairs(edge_set.error_targets) do + if edge.target_node_id then + nodes_with_incoming[edge.target_node_id] = true + end + end + end + + local roots = table.create(8, 0) + for node_id, node_def in pairs((graph :: any).nodes) do + if not nodes_with_incoming[node_id] and not (node_def :: any).parent_node_id then + table.insert(roots, node_id) + end + end + + return roots, nil +end + +function compiler.find_leaf_nodes(graph) + local leaves = table.create(8, 0) + + for node_id, edges in pairs((graph :: any).edges) do + local edge_set = (edges :: any) + local has_node_targets = false + for _, edge in ipairs(edge_set.targets) do + if edge.target_node_id then + has_node_targets = true + break + end + end + for _, edge in ipairs(edge_set.error_targets) do + if edge.target_node_id then + has_node_targets = true + break + end + end + if not has_node_targets then + table.insert(leaves, node_id) + end + end + + return leaves, nil +end + +function compiler.validate_graph(graph) + local nodes_with_incoming = table.create(0, 32) + + for _, edges in pairs((graph :: any).edges) do + local edge_set = (edges :: any) + for _, edge in ipairs(edge_set.targets) do + if edge.target_node_id then + nodes_with_incoming[edge.target_node_id] = true + end + end + for _, edge in ipairs(edge_set.error_targets) do + if edge.target_node_id then + nodes_with_incoming[edge.target_node_id] = true + end + end + end + + local has_workflow_input = (graph :: any).input_data ~= nil + local input_target_nodes = table.create(0, 8) + + if has_workflow_input then + if #(graph :: any).input_routes > 0 then + for _, route in ipairs((graph :: any).input_routes) do + local route_entry = (route :: any) + local target_id, err = (graph :: any):resolve_reference(route_entry.target_name) + if not err and target_id then + input_target_nodes[target_id] = true + end + end + else + local root_nodes, _ = compiler.find_root_nodes(graph) + if root_nodes then + for _, node_id in ipairs(root_nodes) do + input_target_nodes[node_id] = true + end + end + end + end + + local static_target_nodes = table.create(0, 8) + for _, static_source in ipairs((graph :: any).static_data_sources) do + local src = (static_source :: any) + for _, route in ipairs(src.routes) do + local route_entry = (route :: any) + local target_id, err = (graph :: any):resolve_reference(route_entry.target_name) + if not err and target_id then + static_target_nodes[target_id] = true + end + end + end + + local dead_nodes = table.create(8, 0) + for node_id, node_def in pairs((graph :: any).nodes) do + local nd = (node_def :: any) + if nd.status ~= consts.STATUS.TEMPLATE and not nd.parent_node_id then + local has_incoming = nodes_with_incoming[node_id] + local has_input = input_target_nodes[node_id] + local has_static = static_target_nodes[node_id] + + if not has_incoming and not has_input and not has_static then + local title = nd.metadata and nd.metadata.title or "unnamed" + table.insert(dead_nodes, string.format("%s (%s)", title, node_id:sub(1, 12))) + end + end + end + + if #dead_nodes > 0 then + return false, string.format( + "Dead nodes detected (no incoming routes): %s. All nodes must either receive data from another node, workflow input, or static data.", + table.concat(dead_nodes, ", ") + ) + end + + local nodes_with_default_inputs = table.create(0, 32) + + for source_node_id, edges in pairs((graph :: any).edges) do + local edge_set = (edges :: any) + for _, edge in ipairs(edge_set.targets) do + if edge.target_node_id then + local discriminator = edge.input_key or "default" + if discriminator == "default" or discriminator == "" then + nodes_with_default_inputs[edge.target_node_id] = true + end + end + end + end + + if (graph :: any).input_data then + if #(graph :: any).input_routes > 0 then + for _, route in ipairs((graph :: any).input_routes) do + local route_entry = (route :: any) + local target_id, err = (graph :: any):resolve_reference(route_entry.target_name) + if not err and target_id then + local discriminator = route_entry.input_key or "default" + if discriminator == "default" or discriminator == "" then + nodes_with_default_inputs[target_id] = true + end + end + end + else + local root_nodes, _ = compiler.find_root_nodes(graph) + if root_nodes then + for _, node_id in ipairs(root_nodes) do + nodes_with_default_inputs[node_id] = true + end + end + end + end + + for _, static_source in ipairs((graph :: any).static_data_sources) do + local src = (static_source :: any) + for _, route in ipairs(src.routes) do + local route_entry = (route :: any) + local target_id, err = (graph :: any):resolve_reference(route_entry.target_name) + if not err and target_id then + local discriminator = route_entry.input_key or "default" + if discriminator == "default" or discriminator == "" then + nodes_with_default_inputs[target_id] = true + end + end + end + end + + local conflicts = table.create(8, 0) + for node_id, node_def in pairs((graph :: any).nodes) do + local nd = (node_def :: any) + if nodes_with_default_inputs[node_id] then + local has_args = nd.config.args ~= nil + local has_string_transform = type(nd.config.input_transform) == "string" + + if has_args or has_string_transform then + local title = nd.metadata and nd.metadata.title or "unnamed" + local reason = "" + if has_args then + reason = " (has args)" + elseif has_string_transform then + reason = " (has string input_transform)" + end + table.insert(conflicts, string.format("%s (%s)%s", title, node_id:sub(1, 12), reason)) + end + end + end + + if #conflicts > 0 then + return false, string.format( + "Nodes with base arguments (args) cannot receive inputs with 'default' discriminator: %s. " .. + "Use named discriminators with :to(target, 'input_key') or input_transform table form {key = 'expr'}.", + table.concat(conflicts, ", ") + ) + end + + local has_success_terminal = false + local has_auto_output = false + local leaf_nodes = table.create(8, 0) + + for node_id, edges in pairs((graph :: any).edges) do + local edge_set = (edges :: any) + local node_def = (graph :: any).nodes[node_id] + if node_def and (node_def :: any).status ~= consts.STATUS.TEMPLATE then + local nd = (node_def :: any) + local has_node_targets = false + for _, edge in ipairs(edge_set.targets) do + if edge.target_node_id then + has_node_targets = true + break + end + if edge.is_workflow_terminal and edge.is_success then + has_success_terminal = true + end + end + + if not has_node_targets and not nd.parent_node_id then + table.insert(leaf_nodes, { + node_id = node_id, + has_success_route = #edge_set.targets > 0, + has_error_route = #edge_set.error_targets > 0, + metadata = nd.metadata + }) + end + end + end + + for _, leaf_info in ipairs(leaf_nodes) do + local li = (leaf_info :: any) + if not li.has_success_route and not li.has_error_route then + has_auto_output = true + break + end + end + + if not has_success_terminal and not has_auto_output then + local problematic_nodes = table.create(#leaf_nodes, 0) + for _, leaf_info in ipairs(leaf_nodes) do + local li = (leaf_info :: any) + if li.has_error_route and not li.has_success_route then + local title = li.metadata and li.metadata.title or "unnamed" + table.insert(problematic_nodes, string.format("%s (%s)", title, li.node_id:sub(1, 12))) + end + end + + if #problematic_nodes > 0 then + return false, string.format( + "Workflow has no success termination path. Node(s) with :error_to() but no :to() route: %s. " .. + "Add :to(\"@success\") to at least one node to complete the workflow on success.", + table.concat(problematic_nodes, ", ") + ) + else + return false, "Workflow has no completion path. Add :to(\"@success\") to at least one leaf node." + end + end + + return true, nil +end + +function compiler.compile_to_commands(graph, session_context) + if not graph then + return nil, "Graph is required" + end + + local commands = table.create(#(graph :: any).node_order * 2, 0) + local input_data_id = nil + local is_nested = session_context and session_context.dataflow_id + + -- Step 1: Create workflow input data object (only for non-nested workflows) + if (graph :: any).input_data and not is_nested then + input_data_id = uuid.v7() + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = input_data_id, + data_type = consts.DATA_TYPE.WORKFLOW_INPUT, + content = (graph :: any).input_data, + content_type = type((graph :: any).input_data) == "table" and consts.CONTENT_TYPE.JSON or consts.CONTENT_TYPE.TEXT + } + }) + end + + -- Step 2: Create all nodes first + local leaf_nodes, leaf_err = compiler.find_leaf_nodes(graph) + if leaf_err then + return nil, leaf_err + end + + for _, node_id in ipairs((graph :: any).node_order) do + local node_def = ((graph :: any).nodes[node_id] :: any) + local config = {} + + for k, v in pairs(node_def.config) do + config[k] = v + end + + local edges = ((graph :: any).edges[node_id] :: any) + local has_explicit_edges = false + + for _, edge in ipairs(edges.targets) do + if edge.target_node_id or edge.is_workflow_terminal then + has_explicit_edges = true + break + end + end + for _, edge in ipairs(edges.error_targets) do + if edge.target_node_id or edge.is_workflow_terminal then + has_explicit_edges = true + break + end + end + + if has_explicit_edges then + config.data_targets = table.create(#edges.targets, 0) + config.error_targets = table.create(#edges.error_targets, 0) + + for _, edge in ipairs(edges.targets) do + if edge.is_workflow_terminal then + local has_parent = node_def.parent_node_id or (graph :: any).session_parent_id + local output_type = has_parent and consts.DATA_TYPE.NODE_OUTPUT or consts.DATA_TYPE.WORKFLOW_OUTPUT + + local target = { + data_type = output_type, + discriminator = edge.is_success and "result" or "error", + condition = edge.condition, + transform = edge.transform, + metadata = { + source_node_id = node_id + } + } + + if output_type == consts.DATA_TYPE.NODE_OUTPUT then + (target :: any).node_id = node_id + end + + table.insert(config.data_targets, target) + else + table.insert(config.data_targets, { + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = edge.target_node_id, + discriminator = edge.input_key or "default", + condition = edge.condition, + transform = edge.transform, + metadata = { + source_node_id = node_id + } + }) + end + end + + for _, edge in ipairs(edges.error_targets) do + if edge.is_workflow_terminal then + local has_parent = node_def.parent_node_id or (graph :: any).session_parent_id + local output_type = has_parent and consts.DATA_TYPE.NODE_OUTPUT or consts.DATA_TYPE.WORKFLOW_OUTPUT + + local target = { + data_type = output_type, + discriminator = edge.is_success and "result" or "error", + condition = edge.condition, + transform = edge.transform, + metadata = { + source_node_id = node_id + } + } + + if output_type == consts.DATA_TYPE.NODE_OUTPUT then + (target :: any).node_id = node_id + end + + table.insert(config.error_targets, target) + else + table.insert(config.error_targets, { + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = edge.target_node_id, + discriminator = edge.input_key or "default", + condition = edge.condition, + transform = edge.transform, + metadata = { + source_node_id = node_id + } + }) + end + end + else + local is_leaf = false + local is_template = node_def.status == consts.STATUS.TEMPLATE + + for _, leaf_id in ipairs(leaf_nodes) do + if leaf_id == node_id then + is_leaf = true + break + end + end + + if is_leaf and not is_template then + local has_parent = node_def.parent_node_id or (graph :: any).session_parent_id + local output_data_type = has_parent and consts.DATA_TYPE.NODE_OUTPUT or consts.DATA_TYPE.WORKFLOW_OUTPUT + + config.data_targets = table.create(1, 0) + + local target = { + data_type = output_data_type, + discriminator = "result", + content_type = consts.CONTENT_TYPE.JSON, + metadata = { + source_node_id = node_id + } + } + + if output_data_type == consts.DATA_TYPE.NODE_OUTPUT then + (target :: any).node_id = node_id + end + + table.insert(config.data_targets, target) + end + end + + local node_payload = { + node_id = node_id, + node_type = node_def.node_type, + status = node_def.status, + config = config, + metadata = node_def.metadata + } :: any + + if node_def.parent_node_id then + node_payload.parent_node_id = node_def.parent_node_id + elseif (graph :: any).session_parent_id then + node_payload.parent_node_id = (graph :: any).session_parent_id + end + + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_NODE, + payload = node_payload + }) + end + + -- Step 3: Create static data sources (nodes now exist) + local static_data_ids = table.create(0, #(graph :: any).static_data_sources) + + for _, static_source in ipairs((graph :: any).static_data_sources) do + local src = (static_source :: any) + if #src.routes > 0 then + local first_route = (src.routes[1] :: any) + local target_node_id, err = (graph :: any):resolve_reference(first_route.target_name) + if err then + return nil, err + end + + local content = src.data + if first_route.transform then + local transform_env = { + output = src.data + } + local transformed, eval_err = expr.eval(first_route.transform :: string, transform_env) + if eval_err then + return nil, "Static data route transform failed: " .. (eval_err :: string) + end + content = transformed + end + + local data_id = uuid.v7() + static_data_ids[src.static_id] = data_id + + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = data_id, + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = target_node_id, + discriminator = first_route.input_key, + content = content, + content_type = type(content) == "table" and consts.CONTENT_TYPE.JSON or consts.CONTENT_TYPE.TEXT + } + }) + + for i = 2, #src.routes do + local route = (src.routes[i] :: any) + local route_target_id, route_err = (graph :: any):resolve_reference(route.target_name) + if route_err then + return nil, route_err + end + + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = uuid.v7(), + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = route_target_id, + discriminator = route.input_key, + key = data_id, + content = "", + content_type = consts.CONTENT_TYPE.REFERENCE + } + }) + end + end + end + + -- Step 4: Create nested workflow input routing (nodes now exist) + if (graph :: any).input_data and is_nested then + if #(graph :: any).input_routes > 0 then + for _, route in ipairs((graph :: any).input_routes) do + local route_entry = (route :: any) + local target_node_id, err = (graph :: any):resolve_reference(route_entry.target_name) + if err then + return nil, err + end + + local content = (graph :: any).input_data + if route_entry.transform then + local transform_env = { + input = (graph :: any).input_data, + output = (graph :: any).input_data + } + local transformed, eval_err = expr.eval(route_entry.transform :: string, transform_env) + if eval_err then + return nil, "Input route transform failed: " .. (eval_err :: string) + end + content = transformed + end + + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = uuid.v7(), + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = target_node_id, + discriminator = route_entry.input_key, + content = content, + content_type = type(content) == "table" and consts.CONTENT_TYPE.JSON or consts.CONTENT_TYPE.TEXT + } + }) + end + else + local root_nodes, roots_err = compiler.find_root_nodes(graph) + if roots_err then + return nil, roots_err + end + + for _, node_id in ipairs(root_nodes) do + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = uuid.v7(), + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = node_id, + discriminator = "default", + content = (graph :: any).input_data, + content_type = type((graph :: any).input_data) == "table" and consts.CONTENT_TYPE.JSON or + consts.CONTENT_TYPE.TEXT + } + }) + end + end + end + + -- Step 5: Create workflow input data references (non-nested, nodes now exist) + if input_data_id and not is_nested then + local root_nodes, roots_err = compiler.find_root_nodes(graph) + if roots_err then + return nil, roots_err + end + + if #(graph :: any).input_routes > 0 then + for _, route in ipairs((graph :: any).input_routes) do + local route_entry = (route :: any) + local target_node_id, err = (graph :: any):resolve_reference(route_entry.target_name) + if err then + return nil, err + end + + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = uuid.v7(), + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = target_node_id, + key = input_data_id, + discriminator = route_entry.input_key, + content = "", + content_type = "dataflow/reference" + } + }) + end + else + for _, node_id in ipairs(root_nodes) do + table.insert(commands, { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = uuid.v7(), + data_type = consts.DATA_TYPE.NODE_INPUT, + node_id = node_id, + key = input_data_id, + discriminator = "default", + content = "", + content_type = "dataflow/reference" + } + }) + end + end + end + + return commands, nil +end + +function compiler.compile(operations, session_context) + if not operations or #operations == 0 then + return nil, "No operations to compile" + end + + local graph, graph_err = compiler.build_graph(operations, session_context) + if graph_err then + return nil, graph_err + end + + local valid, validation_err = compiler.validate_graph(graph) + if not valid then + return nil, validation_err + end + + local commands, commands_err = compiler.compile_to_commands(graph, session_context) + if commands_err then + return nil, commands_err + end + + return { + commands = commands, + graph = graph + }, nil +end + +return compiler diff --git a/testdata/fixtures/regression/deadlock-dataflow-node/main.lua b/testdata/fixtures/regression/deadlock-dataflow-node/main.lua new file mode 100644 index 00000000..bdaf90be --- /dev/null +++ b/testdata/fixtures/regression/deadlock-dataflow-node/main.lua @@ -0,0 +1,718 @@ +local json = require("json") +local uuid = require("uuid") +local expr = require("expr") +local consts = require("consts") + +local default_deps = { + commit = require("commit"), + data_reader = require("data_reader"), + process = process +} + +local node = {} +local methods = {} +local mt = { __index = methods } + +local function merge_metadata(existing, new_fields) + local existing_count = 0 + local new_count = 0 + + if type(existing) == "table" then + for _ in pairs(existing) do + existing_count = existing_count + 1 + end + end + + if type(new_fields) == "table" then + for _ in pairs(new_fields) do + new_count = new_count + 1 + end + end + + local result = table.create(0, existing_count + new_count) + + if type(existing) == "table" then + for k, v in pairs(existing) do + result[k] = v + end + end + if type(new_fields) == "table" then + for k, v in pairs(new_fields) do + result[k] = v + end + end + return result +end + +local function create_transform_env(raw_inputs) + local input_count = 0 + for _ in pairs(raw_inputs) do + input_count = input_count + 1 + end + + local inputs_by_key = table.create(0, input_count) + local default_content = nil + + for key, input_data in pairs(raw_inputs) do + inputs_by_key[key] = input_data.content + + if key == "default" or key == "" then + default_content = input_data.content + end + end + + return { + input = raw_inputs, + inputs = inputs_by_key, + default = default_content + } +end + +local function resolve_dataflow_references(self, value) + if type(value) ~= "table" then + return value + end + + if value._dataflow_ref then + local reader = (self._deps.data_reader.with_dataflow(self.dataflow_id) :: any) + :with_data(value._dataflow_ref) + :fetch_options({ replace_references = true }) + + local data = reader:one() + if data and data.content then + if data.content_type == consts.CONTENT_TYPE.JSON and type(data.content) == "string" then + local parsed, _ = json.decode(data.content) + return parsed or data.content + end + return data.content + end + return value + end + + if #value > 0 and value[1] and value[1]._dataflow_ref then + local resolved = {} + for i, item in ipairs(value) do + resolved[i] = resolve_dataflow_references(self, item) + end + return resolved + end + + return value +end + +function node.new(args, deps) + if not args then + return nil, "Node args required" + end + if not args.node_id or not args.dataflow_id then + return nil, "Node args must contain node_id and dataflow_id" + end + + deps = deps or default_deps + + local yield_reply_topic = consts.MESSAGE_TOPIC.YIELD_REPLY_PREFIX .. args.node_id + local yield_channel = (deps.process :: any).listen(yield_reply_topic) + + local instance = { + node_id = args.node_id, + dataflow_id = args.dataflow_id, + node = args.node or {}, + path = args.path or table.create(1, 0), + + _config = (args.node and args.node.config) or {}, + data_targets = (args.node and args.node.config and args.node.config.data_targets) or table.create(0, 0), + error_targets = (args.node and args.node.config and args.node.config.error_targets) or table.create(0, 0), + + _metadata = (args.node and args.node.metadata) or {}, + _queued_commands = table.create(10, 0), + _created_data_ids = table.create(5, 0), + _cached_inputs = nil, + + _yield_channel = yield_channel, + _yield_reply_topic = yield_reply_topic, + _last_yield_id = nil, + + _deps = deps + } + + if not instance.path[1] or instance.path[1] ~= args.node_id then + table.insert(instance.path, args.node_id) + end + + return setmetatable(instance, mt) :: any, nil +end + +function methods:config() + return self._config +end + +function methods:_transform_inputs_with_expr(raw_inputs, transform_config) + local env = create_transform_env(raw_inputs) + + if type(transform_config) == "string" then + local content, err = expr.eval(transform_config, env) + if err then + return nil, "Input transformation failed: " .. tostring(err) + end + return { + ["default"] = { + content = content, + metadata = {}, + key = "default", + discriminator = nil + } + }, nil + end + + if type(transform_config) ~= "table" then + return nil, "Node [" .. self.node_id .. "] input_transform must be string or table" + end + + local field_count = 0 + for _ in pairs(transform_config :: {[string]: string}) do + field_count = field_count + 1 + end + + local result = table.create(0, field_count) + for field_name, expression in pairs(transform_config :: {[string]: string}) do + if type(expression) ~= "string" then + return nil, "Transform failed for " .. field_name .. ": expression must be a string" + end + + local content, err = expr.eval(expression, env) + if err then + return nil, "Transform failed for " .. field_name .. ": " .. tostring(err) + end + result[field_name] = { + content = content, + metadata = {}, + key = field_name, + discriminator = nil + } + end + return result, nil +end + +function methods:_load_raw_inputs() + local input_data = (self._deps.data_reader.with_dataflow(self.dataflow_id) :: any) + :with_nodes(self.node_id) + :with_data_types(consts.DATA_TYPE.NODE_INPUT) + :fetch_options({ replace_references = true }) + :all() + + local inputs_map = table.create(0, #input_data) + + for _, input in ipairs(input_data) do + local parsed_content = input.content + + if input.content_type == consts.CONTENT_TYPE.JSON and type(input.content) == "string" then + local parsed, err = json.decode(input.content) + if not err then + parsed_content = parsed + end + end + + local map_key = input.discriminator or input.key or "default" + inputs_map[map_key] = { + content = parsed_content, + metadata = input.metadata or {}, + key = input.key, + discriminator = input.discriminator, + data_id = input.data_id, + content_type = input.content_type + } + end + + return inputs_map +end + +function methods:inputs() + if self._cached_inputs then + return self._cached_inputs + end + + local raw_inputs = self:_load_raw_inputs() + + local transform_config = self._config.input_transform + if transform_config then + local transformed, err = self:_transform_inputs_with_expr(raw_inputs, transform_config) + if err then + error(err) + end + self._cached_inputs = transformed + return transformed, nil + end + + self._cached_inputs = raw_inputs + return raw_inputs, nil +end + +function methods:input(key) + if not key then + error("Input key is required") + end + + local inputs_map, err = self:inputs() + if err then + error(err) + end + return inputs_map[key], nil +end + +function methods:with_child_nodes(definitions) + if not definitions or type(definitions) ~= "table" then + return nil, "Child definitions required" + end + + local child_ids = table.create(#definitions, 0) + + for i, definition in ipairs(definitions) do + if type(definition) ~= "table" then + return nil, "Invalid child definition at index " .. i + end + + local node_type = definition.node_type + if type(node_type) ~= "string" or node_type == "" then + return nil, "Child definition at index " .. i .. " missing node_type" + end + + local child_id = definition.node_id or uuid.v7() + child_ids[i] = child_id + + table.insert(self._queued_commands, { + type = consts.COMMAND_TYPES.CREATE_NODE, + payload = { + node_id = child_id, + node_type = node_type, + parent_node_id = self.node_id, + status = definition.status or consts.STATUS.PENDING, + config = definition.config, + metadata = definition.metadata + } + }) + end + + return child_ids, nil +end + +function methods:data(data_type, content, options) + if not data_type or data_type == "" then + return nil, "Node [" .. self.node_id .. "] data type is required" + end + + if content == nil then + return nil, "Node [" .. self.node_id .. "] content is required" + end + + options = options or {} + + local content_type = options.content_type + if not content_type then + if type(content) == "table" then + content_type = consts.CONTENT_TYPE.JSON + else + content_type = consts.CONTENT_TYPE.TEXT + end + end + + local data_id = options.data_id or uuid.v7() + table.insert(self._created_data_ids, data_id) + + local command = { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = data_id, + data_type = data_type, + key = options.key, + content = content, + content_type = content_type, + discriminator = options.discriminator, + node_id = options.node_id, + metadata = options.metadata + } + } + + table.insert(self._queued_commands, command) + return self, nil +end + +function methods:update_metadata(updates) + if not updates or type(updates) ~= "table" then + return self, nil + end + + self._metadata = merge_metadata(self._metadata, updates) + + local command = { + type = consts.COMMAND_TYPES.UPDATE_NODE, + payload = { + node_id = self.node_id, + metadata = self._metadata + } + } + + table.insert(self._queued_commands, command) + return self, nil +end + +function methods:update_config(updates) + if not updates or type(updates) ~= "table" then + return self, nil + end + + self._config = merge_metadata(self._config, updates) + + local command = { + type = consts.COMMAND_TYPES.UPDATE_NODE, + payload = { + node_id = self.node_id, + config = self._config + } + } + + table.insert(self._queued_commands, command) + return self, nil +end + +function methods:submit() + if #self._queued_commands == 0 then + return true, nil + end + + local op_id = uuid.v7() + local success, err = self._deps.commit.submit(self.dataflow_id, op_id, self._queued_commands) + + if success then + self._queued_commands = table.create(10, 0) + return true, nil + else + return false, err or "unknown" + end +end + +function methods.yield(self: table, options) + options = options or {} + + local yield_id = uuid.v7() + local op_id = uuid.v7() + + local yield_command: table = { + type = consts.COMMAND_TYPES.CREATE_DATA, + payload = { + data_id = uuid.v7(), + data_type = consts.DATA_TYPE.NODE_YIELD, + content = { + node_id = self.node_id, + yield_id = yield_id, + reply_to = self._yield_reply_topic, + yield_context = { + run_nodes = options.run_nodes or table.create(0, 0) + } + }, + content_type = consts.CONTENT_TYPE.JSON, + key = yield_id, + node_id = self.node_id + } + } + table.insert(self._queued_commands, yield_command) + + local submitted, err = self._deps.commit.submit(self.dataflow_id, op_id, self._queued_commands) + if not submitted then + return nil, "Failed to submit yield: " .. (err or "unknown") + end + self._queued_commands = table.create(10, 0) + + local yield_signal = { + request_context = { + yield_id = yield_id, + node_id = self.node_id, + reply_to = self._yield_reply_topic + }, + yield_context = { + run_nodes = options.run_nodes or table.create(0, 0) + } + } + + local success = (self._deps.process :: any).send( + "dataflow." .. self.dataflow_id, + consts.MESSAGE_TOPIC.YIELD_REQUEST, + yield_signal + ) + + if not success then + return nil, "Failed to send yield signal" + end + + local received, ok = self._yield_channel:receive() + if not ok then + return nil, "Yield channel closed" + end + + self._last_yield_id = yield_id + + if received and received.response_data then + return received.response_data.run_node_results or table.create(0, 0), nil + end + + return table.create(0, 0), nil +end + +function methods:query() + return self._deps.data_reader.with_dataflow(self.dataflow_id) +end + +function methods:_route_outputs(content) + local routed_data_ids = table.create(#self.data_targets, 0) + local data_id_count = 0 + + local resolved_content = resolve_dataflow_references(self, content) + + local ok, inputs_or_err = pcall(function() + local values = self:inputs() + return values + end) + if not ok then + return nil, "Node [" .. self.node_id .. "] failed to load inputs for output routing: " .. tostring(inputs_or_err) + end + + local inputs_map = inputs_or_err + local env = { + output = resolved_content, + input = inputs_map or {}, + inputs = inputs_map or {}, + node = self + } + + for target_idx, target in ipairs(self.data_targets) do + local target_desc = "target[" .. target_idx .. "]" + if target.discriminator then + target_desc = target_desc .. " (discriminator=" .. target.discriminator .. ")" + end + if target.node_id then + target_desc = target_desc .. " -> node[" .. target.node_id .. "]" + end + + if target.condition then + local should_create, condition_err = expr.eval(target.condition :: string, env) + if condition_err then + return nil, "Output condition evaluation failed for " .. target_desc .. ": " .. tostring(condition_err) + end + if not should_create then + goto continue + end + end + + local output_content = resolved_content + local has_transform = target.transform ~= nil + + if has_transform then + local transformed, transform_err = expr.eval(target.transform :: string, env) + if transform_err then + return nil, "Output transform failed for " .. target_desc .. ": " .. tostring(transform_err) + end + output_content = transformed + end + + if output_content == nil then + return nil, "Node [" .. self.node_id .. "] output content is nil for " .. target_desc .. + " (transform: " .. tostring(target.transform or "none") .. ")" + end + + local data_id = uuid.v7() + data_id_count = data_id_count + 1 + routed_data_ids[data_id_count] = data_id + + local _, data_err = self:data(target.data_type, output_content, { + data_id = data_id, + key = target.key, + discriminator = target.discriminator, + node_id = target.node_id or self.node_id, + content_type = target.content_type, + metadata = target.metadata + }) + if data_err then + return nil, "Node [" .. self.node_id .. "] failed to create data for " .. target_desc .. ": " .. (data_err :: string) + end + + ::continue:: + end + + return routed_data_ids, nil +end + +function methods:_route_errors(error_content) + local routed_data_ids = table.create(#self.error_targets, 0) + local data_id_count = 0 + + local env = { + error = error_content, + node = self + } + + for target_idx, target in ipairs(self.error_targets) do + local target_desc = "error_target[" .. target_idx .. "]" + if target.discriminator then + target_desc = target_desc .. " (discriminator=" .. target.discriminator .. ")" + end + if target.node_id then + target_desc = target_desc .. " -> node[" .. target.node_id .. "]" + end + + if target.condition then + local should_create, condition_err = expr.eval(target.condition :: string, env) + if condition_err then + goto continue + end + if not should_create then + goto continue + end + end + + local error_output = error_content + if target.transform then + local transformed, transform_err = expr.eval(target.transform :: string, env) + if not transform_err then + error_output = transformed + end + end + + local data_id = uuid.v7() + data_id_count = data_id_count + 1 + routed_data_ids[data_id_count] = data_id + + self:data(target.data_type, error_output, { + data_id = data_id, + key = target.key, + discriminator = target.discriminator, + node_id = target.node_id, + content_type = target.content_type, + metadata = target.metadata + }) + + ::continue:: + end + + return routed_data_ids, nil +end + +function methods:_submit_final() + if #self._queued_commands == 0 then + return true, nil + end + + local result, err = self._deps.commit.submit( + self.dataflow_id, + uuid.v7(), + self._queued_commands + ) + + self._queued_commands = table.create(10, 0) + return result ~= nil, err +end + +function methods:complete(output_content, message, extra_metadata) + if extra_metadata then + local _, meta_err = self:update_metadata(extra_metadata) + if meta_err then + return { + success = false, + message = "Node [" .. self.node_id .. "] failed to update metadata: " .. (meta_err :: string), + error = meta_err, + data_ids = table.create(0, 0) + } + end + end + + if message then + local _, msg_err = self:update_metadata({ status_message = message }) + if msg_err then + return { + success = false, + message = "Node [" .. self.node_id .. "] failed to set status message: " .. (msg_err :: string), + error = msg_err, + data_ids = table.create(0, 0) + } + end + end + + local data_ids = table.create(#self.data_targets, 0) + if output_content ~= nil then + local routed_ids, route_err = self:_route_outputs(output_content) + if route_err then + error(route_err) + end + data_ids = routed_ids + end + + local success, err = self:_submit_final() + if not success then + return { + success = false, + message = "Node [" .. self.node_id .. "] failed to submit final commands: " .. (err or "unknown"), + error = err, + data_ids = table.create(0, 0) + } + end + + return { + success = true, + message = message or "Node execution completed successfully", + data_ids = data_ids + } +end + +function methods:fail(error_details, message, extra_metadata) + local error_msg = error_details or "Unknown error" + local status_msg = message or error_msg + + local error_metadata = { + status_message = status_msg, + error = error_msg + } + + if extra_metadata then + error_metadata = merge_metadata(error_metadata, extra_metadata) + end + + self:update_metadata(error_metadata) + + local data_ids = table.create(#self.error_targets, 0) + if error_details ~= nil then + local routed_ids, route_err = self:_route_errors(error_details) + if not route_err then + data_ids = routed_ids + end + end + + local success, err = self:_submit_final() + if not success then + return { + success = false, + message = "Node [" .. self.node_id .. "] failed to submit final commands: " .. (err or "unknown"), + error = err, + data_ids = table.create(0, 0) + } + end + + return { + success = false, + message = status_msg, + error = error_msg, + data_ids = data_ids + } +end + +function methods:command(cmd) + if not cmd or not cmd.type then + return nil, "Command must have a type" + end + + table.insert(self._queued_commands, cmd) + return self, nil +end + +function methods:created_data_ids() + return self._created_data_ids +end + +return node diff --git a/testdata/fixtures/regression/error-return-union-errorlike/main.lua b/testdata/fixtures/regression/error-return-union-errorlike/main.lua new file mode 100644 index 00000000..7659fa13 --- /dev/null +++ b/testdata/fixtures/regression/error-return-union-errorlike/main.lua @@ -0,0 +1,50 @@ +type GenError = { + message: string, +} + +local id_source = {} + +function id_source.v7(): (string, GenError?) + return "id", nil +end + +type ActiveSession = { + pid: any, +} + +local active_sessions = {} :: {[string]: ActiveSession} + +local function create_session(payload_data) + if not payload_data then + return nil, "missing payload" + end + + local session_id = payload_data.session_id + if not session_id then + local id, err = id_source.v7() + if err then + return nil, err + end + session_id = id + end + + if not payload_data.start_token then + return nil, "missing token" + end + + return session_id, nil +end + +local function use_session(payload_data) + local created_session_id, err = create_session(payload_data) + if err then + return + end + + local recovered_session_info = active_sessions[created_session_id] + if recovered_session_info then + return recovered_session_info.pid + end +end + +return use_session diff --git a/testdata/fixtures/regression/error-return-union-errorlike/manifest.json b/testdata/fixtures/regression/error-return-union-errorlike/manifest.json new file mode 100644 index 00000000..142ae27a --- /dev/null +++ b/testdata/fixtures/regression/error-return-union-errorlike/manifest.json @@ -0,0 +1,4 @@ +{ + "files": ["main.lua"], + "check": {"errors": 0} +} diff --git a/testdata/fixtures/regression/field-defined-wrapper-return-local-alias-reassigned/main.lua b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias-reassigned/main.lua new file mode 100644 index 00000000..1cdd0bd5 --- /dev/null +++ b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias-reassigned/main.lua @@ -0,0 +1,20 @@ +type Res = { answer: string } + +local M = { + dep = { + get = function() + return nil + end, + }, +} + +function M.run() + return M.dep.get() +end + +M.run = function() + return nil +end + +local f: fun(): Res = M.run +return f diff --git a/testdata/fixtures/regression/field-defined-wrapper-return-local-alias-reassigned/manifest.json b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias-reassigned/manifest.json new file mode 100644 index 00000000..78a7ad01 --- /dev/null +++ b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias-reassigned/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Aliased wrapper value typing must respect the current reassigned function value instead of reviving the original definition", + "files": [ + "main.lua" + ], + "check": { + "errors": 1 + } +} diff --git a/testdata/fixtures/regression/field-defined-wrapper-return-local-alias/main.lua b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias/main.lua new file mode 100644 index 00000000..c9459005 --- /dev/null +++ b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias/main.lua @@ -0,0 +1,24 @@ +type Res = { answer: string } + +local M = { + dep = { + get = function() + return nil + end, + }, +} + +function M.run() + return M.dep.get() +end + +M.dep = { + get = function() + return { answer = "ok" } + end, +} + +local f: fun(): Res = M.run +local res = f() +local answer: string = res.answer +return answer diff --git a/testdata/fixtures/regression/field-defined-wrapper-return-local-alias/manifest.json b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias/manifest.json new file mode 100644 index 00000000..7424ad4a --- /dev/null +++ b/testdata/fixtures/regression/field-defined-wrapper-return-local-alias/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Field-defined wrapper return should survive local aliasing after the dominating visible write", + "files": [ + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/regression/field-defined-wrapper-return/main.lua b/testdata/fixtures/regression/field-defined-wrapper-return/main.lua new file mode 100644 index 00000000..9c7b56c1 --- /dev/null +++ b/testdata/fixtures/regression/field-defined-wrapper-return/main.lua @@ -0,0 +1,21 @@ +local M = { + dep = { + get = function() + return nil + end, + }, +} + +function M.run() + return M.dep.get() +end + +M.dep = { + get = function() + return { answer = "ok" } + end, +} + +local res = M.run() +local answer: string = res.answer +return answer diff --git a/testdata/fixtures/regression/field-defined-wrapper-return/manifest.json b/testdata/fixtures/regression/field-defined-wrapper-return/manifest.json new file mode 100644 index 00000000..16402dff --- /dev/null +++ b/testdata/fixtures/regression/field-defined-wrapper-return/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Field-defined wrapper functions should track dominating visible captured field writes", + "files": [ + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/regression/local-function-fact-authority/main.lua b/testdata/fixtures/regression/local-function-fact-authority/main.lua new file mode 100644 index 00000000..9ac72075 --- /dev/null +++ b/testdata/fixtures/regression/local-function-fact-authority/main.lua @@ -0,0 +1,64 @@ +type Entry = {id: string, meta: {type: string, suite: string?, order: number?}?} + +local function make() + local obj = { x = 1 } + local function init() + obj.get_x = function(self): number + return self.x + end + end + init() + return obj +end + +local built = make() +local n: number = built:get_x() + +local function make_async() + local obj = {} + coroutine.spawn(function() + obj.get_value = function(self): number + return 42 + end + end) + return obj +end + +local async_obj = make_async() +local v: number = async_obj:get_value() + +local function sorted_keys(t) + local keys = {} + for k in pairs(t) do + table.insert(keys, k) + end + table.sort(keys) + return keys +end + +local function group_by_suite(entries: {Entry}) + local suites = {} + local no_suite = {} + + for _, entry in ipairs(entries) do + local suite = entry.meta and entry.meta.suite + if suite then + suites[suite] = suites[suite] or {} + table.insert(suites[suite], entry) + else + table.insert(no_suite, entry) + end + end + + return suites, no_suite +end + +local entries: {Entry} = {} +local suites, no_suite = group_by_suite(entries) +local suite_names = sorted_keys(suites) + +for _, name in ipairs(suite_names) do + local tests: {Entry} = suites[name] +end + +local uncategorized: {Entry} = no_suite diff --git a/testdata/fixtures/regression/local-function-fact-authority/manifest.json b/testdata/fixtures/regression/local-function-fact-authority/manifest.json new file mode 100644 index 00000000..37b79947 --- /dev/null +++ b/testdata/fixtures/regression/local-function-fact-authority/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Stable local-function facts must dominate weaker bound-expression re-synthesis", + "files": [ + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/regression/local-function-narrow-return-repair/main.lua b/testdata/fixtures/regression/local-function-narrow-return-repair/main.lua new file mode 100644 index 00000000..41732a0e --- /dev/null +++ b/testdata/fixtures/regression/local-function-narrow-return-repair/main.lua @@ -0,0 +1,29 @@ +local function f(blocks) + local tool_use_block = nil + for _, block in ipairs(blocks) do + if block.type == "tool_use" and block.name == "structured_output" then + tool_use_block = block + break + end + end + if not tool_use_block then + return { success = false, error = "missing" } + end + return { success = true, result = { data = tool_use_block.input } } +end + +local export = { f = f } + +local out = export.f({ + { + type = "tool_use", + name = "structured_output", + input = { "ok" }, + }, +}) + +if out.success and type(out.result.data) == "table" then + local n: integer = #out.result.data +end + +return export diff --git a/testdata/fixtures/regression/local-function-narrow-return-repair/manifest.json b/testdata/fixtures/regression/local-function-narrow-return-repair/manifest.json new file mode 100644 index 00000000..aeffcd10 --- /dev/null +++ b/testdata/fixtures/regression/local-function-narrow-return-repair/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Local function literal summary should repair pre-flow nested never artifacts after narrowing", + "files": [ + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/testdata/fixtures/regression/non-dominating-field-defined-wrapper-return/main.lua b/testdata/fixtures/regression/non-dominating-field-defined-wrapper-return/main.lua new file mode 100644 index 00000000..906bc5cf --- /dev/null +++ b/testdata/fixtures/regression/non-dominating-field-defined-wrapper-return/main.lua @@ -0,0 +1,27 @@ +local function run(flag: boolean) + local M = { + dep = { + get = function() + return nil + end, + }, + } + + function M.run() + return M.dep.get() + end + + if flag then + M.dep = { + get = function() + return { answer = "ok" } + end, + } + end + + local res = M.run() + local answer: string = res.answer + return answer +end + +return run diff --git a/testdata/fixtures/regression/non-dominating-field-defined-wrapper-return/manifest.json b/testdata/fixtures/regression/non-dominating-field-defined-wrapper-return/manifest.json new file mode 100644 index 00000000..deb2c53a --- /dev/null +++ b/testdata/fixtures/regression/non-dominating-field-defined-wrapper-return/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Field-defined wrappers must not treat non-dominating captured field writes as definite", + "files": [ + "main.lua" + ], + "check": { + "errors": 2 + } +} diff --git a/testdata/fixtures/regression/non-dominating-field-write-call-assignment/main.lua b/testdata/fixtures/regression/non-dominating-field-write-call-assignment/main.lua new file mode 100644 index 00000000..f0e704d6 --- /dev/null +++ b/testdata/fixtures/regression/non-dominating-field-write-call-assignment/main.lua @@ -0,0 +1,23 @@ +local function run(flag: boolean) + local M = { + dep = { + get = function() + return nil + end, + }, + } + + if flag then + M.dep = { + get = function() + return { answer = "ok" } + end, + } + end + + local res = M.dep.get() + local answer: string = res.answer + return answer +end + +return run diff --git a/testdata/fixtures/regression/non-dominating-field-write-call-assignment/manifest.json b/testdata/fixtures/regression/non-dominating-field-write-call-assignment/manifest.json new file mode 100644 index 00000000..8b3a937b --- /dev/null +++ b/testdata/fixtures/regression/non-dominating-field-write-call-assignment/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Non-dominating branch field writes must not make later call results definite", + "files": [ + "main.lua" + ], + "check": { + "errors": 2 + } +} diff --git a/testdata/fixtures/regression/reassigned-field-call-assignment/main.lua b/testdata/fixtures/regression/reassigned-field-call-assignment/main.lua new file mode 100644 index 00000000..f2adae49 --- /dev/null +++ b/testdata/fixtures/regression/reassigned-field-call-assignment/main.lua @@ -0,0 +1,17 @@ +local M = { + dep = { + get = function() + return nil + end, + }, +} + +M.dep = { + get = function() + return { answer = "ok" } + end, +} + +local res = M.dep.get() +local answer: string = res.answer +return answer diff --git a/testdata/fixtures/regression/reassigned-field-call-assignment/manifest.json b/testdata/fixtures/regression/reassigned-field-call-assignment/manifest.json new file mode 100644 index 00000000..11a9921c --- /dev/null +++ b/testdata/fixtures/regression/reassigned-field-call-assignment/manifest.json @@ -0,0 +1,9 @@ +{ + "description": "Single-target assignment should preserve direct call result type after field reassignment", + "files": [ + "main.lua" + ], + "check": { + "errors": 0 + } +} diff --git a/types/constraint/constraint.go b/types/constraint/constraint.go index fa1c3098..c3726b4a 100644 --- a/types/constraint/constraint.go +++ b/types/constraint/constraint.go @@ -93,7 +93,7 @@ const ( // enabling efficient incremental narrowing. // // Substitute replaces placeholder paths ($0, $1) with concrete argument paths, -// used when applying function effect constraints at call sites. +// used when applying function refinement constraints at call sites. type Constraint interface { Kind() Kind Paths() []Path diff --git a/types/constraint/doc.go b/types/constraint/doc.go index cf7b8f3a..5f12ebae 100644 --- a/types/constraint/doc.go +++ b/types/constraint/doc.go @@ -19,7 +19,7 @@ // // [Solver]: Applies constraints to type environments to produce narrowed types. // -// [FunctionEffect]: Describes type refinements a function produces on its parameters. +// [FunctionRefinement]: Describes type refinements a function produces on its parameters. // // [Interner]: Provides constraint interning to reduce allocations for common patterns. // diff --git a/types/constraint/effect.go b/types/constraint/effect.go index 662661c3..60212911 100644 --- a/types/constraint/effect.go +++ b/types/constraint/effect.go @@ -5,15 +5,15 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -// EffectLookupBySym retrieves a function's inferred effect by symbol ID. +// RefinementLookupBySym retrieves a function's inferred refinement by symbol ID. // // Used during call site analysis to determine what type refinements a -// function call produces. Returns nil if the function has no recorded effect. -type EffectLookupBySym func(sym cfg.SymbolID) *FunctionEffect +// function call produces. Returns nil if the function has no recorded refinement. +type RefinementLookupBySym func(sym cfg.SymbolID) *FunctionRefinement -// FunctionEffect describes the type refinements a function produces. +// FunctionRefinement describes the type refinements a function produces. // -// Effects encode how a function call narrows types based on its return value. +// Refinements encode how a function call narrows types based on its return value. // Three categories are supported: // - OnReturn: constraints that hold when the function returns normally // (used for assert-style functions that error() on failure) @@ -23,7 +23,7 @@ type EffectLookupBySym func(sym cfg.SymbolID) *FunctionEffect // // Placeholder roots ($0, $1, ...) reference parameters by position. // At call sites, placeholders are substituted with actual argument paths. -type FunctionEffect struct { +type FunctionRefinement struct { // Row is the effect label set (IO, Mutate, Throw, etc.). // Stored as typ.EffectInfo to avoid circular import with effect package. // The concrete type is effect.Row. @@ -45,9 +45,9 @@ type FunctionEffect struct { Terminates bool } -// NewEffect creates a FunctionEffect from constraint slices. -func NewEffect(onReturn, onTrue, onFalse []Constraint) *FunctionEffect { - return &FunctionEffect{ +// NewRefinement creates a FunctionRefinement from constraint slices. +func NewRefinement(onReturn, onTrue, onFalse []Constraint) *FunctionRefinement { + return &FunctionRefinement{ OnReturn: FromConstraints(onReturn...), OnTrue: FromConstraints(onTrue...), OnFalse: FromConstraints(onFalse...), @@ -55,7 +55,7 @@ func NewEffect(onReturn, onTrue, onFalse []Constraint) *FunctionEffect { } // IsEmpty returns true if the effect has no constraints, no row, and doesn't terminate. -func (e *FunctionEffect) IsEmpty() bool { +func (e *FunctionRefinement) IsEmpty() bool { if e == nil { return true } @@ -64,23 +64,23 @@ func (e *FunctionEffect) IsEmpty() bool { } // HasAssertSemantics returns true if function has assert-style semantics. -func (e *FunctionEffect) HasAssertSemantics() bool { +func (e *FunctionRefinement) HasAssertSemantics() bool { return e != nil && e.OnReturn.HasConstraints() } // HasPredicateSemantics returns true if function has predicate semantics. -func (e *FunctionEffect) HasPredicateSemantics() bool { +func (e *FunctionRefinement) HasPredicateSemantics() bool { return e != nil && (e.OnTrue.HasConstraints() || e.OnFalse.HasConstraints()) } -// Equals returns true if two function effects are structurally equal. +// Equals returns true if two function refinements are structurally equal. // Implements internal.Equaler interface for use in typ.Function. -func (e *FunctionEffect) Equals(other any) bool { +func (e *FunctionRefinement) Equals(other any) bool { if other == nil { return e == nil } - o, ok := other.(*FunctionEffect) + o, ok := other.(*FunctionRefinement) if !ok { return false @@ -116,19 +116,19 @@ func effectRowEquals(a, b typ.EffectInfo) bool { } // IsRefinementInfo implements typ.RefinementInfo. -func (e *FunctionEffect) IsRefinementInfo() {} +func (e *FunctionRefinement) IsRefinementInfo() {} -// Substitute returns a new FunctionEffect with placeholder paths replaced. +// Substitute returns a new FunctionRefinement with placeholder paths replaced. // // At a call site, parameter placeholders ($0, $1, ...) are replaced with // the actual argument paths, producing concrete constraints that can be // applied to narrow types at the call location. -func (e *FunctionEffect) Substitute(args []Path) *FunctionEffect { +func (e *FunctionRefinement) Substitute(args []Path) *FunctionRefinement { if e == nil || e.IsEmpty() { return nil } - result := &FunctionEffect{ + result := &FunctionRefinement{ OnReturn: e.OnReturn.Substitute(args), OnTrue: e.OnTrue.Substitute(args), OnFalse: e.OnFalse.Substitute(args), @@ -147,7 +147,7 @@ func (e *FunctionEffect) Substitute(args []Path) *FunctionEffect { // // Detection looks for KeyOf constraints in OnReturn where the table is a // parameter placeholder ($N) and the key is a return path. -func (e *FunctionEffect) KeysCollectorInfo() (paramIndex int, returnIndex int, ok bool) { +func (e *FunctionRefinement) KeysCollectorInfo() (paramIndex int, returnIndex int, ok bool) { if e == nil || !e.OnReturn.HasConstraints() { return 0, 0, false } @@ -187,7 +187,7 @@ func (e *FunctionEffect) KeysCollectorInfo() (paramIndex int, returnIndex int, o // KeysCollectorParamIndex checks if the function returns keys of a parameter. // // Returns the parameter index (0-based) if found, or -1 otherwise. -func (e *FunctionEffect) KeysCollectorParamIndex() int { +func (e *FunctionRefinement) KeysCollectorParamIndex() int { paramIdx, _, ok := e.KeysCollectorInfo() if !ok { return -1 diff --git a/types/constraint/effect_test.go b/types/constraint/effect_test.go index fff46a70..cadf5206 100644 --- a/types/constraint/effect_test.go +++ b/types/constraint/effect_test.go @@ -6,14 +6,14 @@ import ( "github.com/wippyai/go-lua/types/narrow" ) -func TestNewEffect(t *testing.T) { +func TestNewRefinement(t *testing.T) { onReturn := []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}} onTrue := []Constraint{Truthy{Path: Path{Root: "$0"}}} onFalse := []Constraint{Falsy{Path: Path{Root: "$0"}}} - e := NewEffect(onReturn, onTrue, onFalse) + e := NewRefinement(onReturn, onTrue, onFalse) if e == nil { - t.Fatal("NewEffect returned nil") + t.Fatal("NewRefinement returned nil") } if len(e.OnReturn.MustConstraints()) != 1 { @@ -29,18 +29,18 @@ func TestNewEffect(t *testing.T) { } } -func TestFunctionEffect_IsEmpty(t *testing.T) { - var nilEffect *FunctionEffect +func TestFunctionRefinement_IsEmpty(t *testing.T) { + var nilEffect *FunctionRefinement if !nilEffect.IsEmpty() { t.Error("nil effect should be empty") } - empty := &FunctionEffect{} + empty := &FunctionRefinement{} if !empty.IsEmpty() { t.Error("empty effect should be empty") } - nonEmpty := NewEffect( + nonEmpty := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, nil, nil, ) @@ -49,13 +49,13 @@ func TestFunctionEffect_IsEmpty(t *testing.T) { } } -func TestFunctionEffect_HasAssertSemantics(t *testing.T) { - var nilEffect *FunctionEffect +func TestFunctionRefinement_HasAssertSemantics(t *testing.T) { + var nilEffect *FunctionRefinement if nilEffect.HasAssertSemantics() { t.Error("nil effect should not have assert semantics") } - assert := NewEffect( + assert := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, nil, nil, ) @@ -63,7 +63,7 @@ func TestFunctionEffect_HasAssertSemantics(t *testing.T) { t.Error("effect with OnReturn should have assert semantics") } - predicate := NewEffect(nil, + predicate := NewRefinement(nil, []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, nil, ) @@ -72,13 +72,13 @@ func TestFunctionEffect_HasAssertSemantics(t *testing.T) { } } -func TestFunctionEffect_HasPredicateSemantics(t *testing.T) { - var nilEffect *FunctionEffect +func TestFunctionRefinement_HasPredicateSemantics(t *testing.T) { + var nilEffect *FunctionRefinement if nilEffect.HasPredicateSemantics() { t.Error("nil effect should not have predicate semantics") } - withOnTrue := NewEffect(nil, + withOnTrue := NewRefinement(nil, []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, nil, ) @@ -86,14 +86,14 @@ func TestFunctionEffect_HasPredicateSemantics(t *testing.T) { t.Error("effect with OnTrue should have predicate semantics") } - withOnFalse := NewEffect(nil, nil, + withOnFalse := NewRefinement(nil, nil, []Constraint{NotHasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, ) if !withOnFalse.HasPredicateSemantics() { t.Error("effect with OnFalse should have predicate semantics") } - assertOnly := NewEffect( + assertOnly := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, nil, nil, ) @@ -102,13 +102,13 @@ func TestFunctionEffect_HasPredicateSemantics(t *testing.T) { } } -func TestFunctionEffect_Equals(t *testing.T) { - e1 := NewEffect( +func TestFunctionRefinement_Equals(t *testing.T) { + e1 := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, []Constraint{Truthy{Path: Path{Root: "$0"}}}, nil, ) - e2 := NewEffect( + e2 := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, []Constraint{Truthy{Path: Path{Root: "$0"}}}, nil, @@ -118,7 +118,7 @@ func TestFunctionEffect_Equals(t *testing.T) { t.Error("identical effects should be equal") } - e3 := NewEffect( + e3 := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("number")}}, nil, nil, ) @@ -126,7 +126,7 @@ func TestFunctionEffect_Equals(t *testing.T) { t.Error("different effects should not be equal") } - var nilEffect *FunctionEffect + var nilEffect *FunctionRefinement if !nilEffect.Equals(nil) { t.Error("nil effect should equal nil") } @@ -144,8 +144,8 @@ func TestFunctionEffect_Equals(t *testing.T) { } } -func TestFunctionEffect_Substitute(t *testing.T) { - e := NewEffect( +func TestFunctionRefinement_Substitute(t *testing.T) { + e := NewRefinement( []Constraint{HasType{Path: Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, []Constraint{Truthy{Path: Path{Root: "$1"}}}, nil, @@ -187,25 +187,25 @@ func TestFunctionEffect_Substitute(t *testing.T) { } } -func TestFunctionEffect_SubstituteNil(t *testing.T) { - var nilEffect *FunctionEffect +func TestFunctionRefinement_SubstituteNil(t *testing.T) { + var nilEffect *FunctionRefinement if nilEffect.Substitute([]Path{{Root: "x"}}) != nil { t.Error("substituting nil effect should return nil") } - empty := &FunctionEffect{} + empty := &FunctionRefinement{} if empty.Substitute([]Path{{Root: "x"}}) != nil { t.Error("substituting empty effect should return nil") } } -func TestFunctionEffect_IsRefinementInfo(t *testing.T) { - e := &FunctionEffect{} +func TestFunctionRefinement_IsRefinementInfo(t *testing.T) { + e := &FunctionRefinement{} e.IsRefinementInfo() // Should not panic } -func TestFunctionEffect_KeysCollectorInfo(t *testing.T) { - eff := &FunctionEffect{ +func TestFunctionRefinement_KeysCollectorInfo(t *testing.T) { + eff := &FunctionRefinement{ OnReturn: FromConstraints(KeyOf{ Table: ParamPath(0), Key: RetPath(1), @@ -227,8 +227,8 @@ func TestFunctionEffect_KeysCollectorInfo(t *testing.T) { } } -func TestFunctionEffect_KeysCollectorInfo_InvalidKeyPath(t *testing.T) { - eff := &FunctionEffect{ +func TestFunctionRefinement_KeysCollectorInfo_InvalidKeyPath(t *testing.T) { + eff := &FunctionRefinement{ OnReturn: FromConstraints(KeyOf{ Table: ParamPath(0), Key: Path{Root: "ret[abc]"}, @@ -243,8 +243,8 @@ func TestFunctionEffect_KeysCollectorInfo_InvalidKeyPath(t *testing.T) { } } -func TestFunctionEffect_KeysCollectorInfo_AmbiguousDisjuncts(t *testing.T) { - eff := &FunctionEffect{ +func TestFunctionRefinement_KeysCollectorInfo_AmbiguousDisjuncts(t *testing.T) { + eff := &FunctionRefinement{ OnReturn: Condition{ Disjuncts: [][]Constraint{ { diff --git a/types/constraint/infer.go b/types/constraint/infer.go index 6190eef6..a43dbf09 100644 --- a/types/constraint/infer.go +++ b/types/constraint/infer.go @@ -375,6 +375,17 @@ func (c *InferSet) unifySCC(scc []int) { // walkType traverses a type tree, calling pred on each node. // Returns true if pred returns true for any node. +func canContainTypeVar(k kind.Kind) bool { + switch k { + case kind.Optional, kind.Union, kind.Intersection, kind.Array, + kind.Map, kind.Tuple, kind.Function, kind.Record, kind.Alias, + kind.TypeVar, kind.Instantiated: + return true + default: + return false + } +} + func walkType(t typ.Type, depth int, pred func(typ.Type) bool) bool { if stopDepth(t, depth) { return false @@ -384,6 +395,10 @@ func walkType(t typ.Type, depth int, pred func(typ.Type) bool) bool { return true } + if !canContainTypeVar(t.Kind()) { + return false + } + return typ.Visit(t, typ.Visitor[bool]{ Optional: func(o *typ.Optional) bool { return walkType(o.Inner, depth+1, pred) @@ -451,12 +466,90 @@ func walkType(t typ.Type, depth int, pred func(typ.Type) bool) bool { } func occursIn(varID int, t typ.Type) bool { - return walkType(t, 0, func(inner typ.Type) bool { + seen := make(map[typ.Type]bool) + return walkTypeMemo(t, 0, seen, func(inner typ.Type) bool { if tv, ok := inner.(*typ.TypeVar); ok { return tv.ID == varID } + return false + }) +} +func walkTypeMemo(t typ.Type, depth int, seen map[typ.Type]bool, pred func(typ.Type) bool) bool { + if stopDepth(t, depth) { + return false + } + if pred(t) { + return true + } + if !canContainTypeVar(t.Kind()) { return false + } + if seen[t] { + return false + } + seen[t] = true + return typ.Visit(t, typ.Visitor[bool]{ + Optional: func(o *typ.Optional) bool { + return walkTypeMemo(o.Inner, depth+1, seen, pred) + }, + Union: func(u *typ.Union) bool { + for _, m := range u.Members { + if walkTypeMemo(m, depth+1, seen, pred) { + return true + } + } + return false + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := range in.Members { + if walkTypeMemo(m, depth+1, seen, pred) { + return true + } + } + return false + }, + Tuple: func(tup *typ.Tuple) bool { + for _, e := range tup.Elements { + if walkTypeMemo(e, depth+1, seen, pred) { + return true + } + } + return false + }, + Array: func(a *typ.Array) bool { + return walkTypeMemo(a.Element, depth+1, seen, pred) + }, + Map: func(m *typ.Map) bool { + return walkTypeMemo(m.Key, depth+1, seen, pred) || walkTypeMemo(m.Value, depth+1, seen, pred) + }, + Function: func(fn *typ.Function) bool { + for _, p := range fn.Params { + if walkTypeMemo(p.Type, depth+1, seen, pred) { + return true + } + } + for _, r := range fn.Returns { + if walkTypeMemo(r, depth+1, seen, pred) { + return true + } + } + return walkTypeMemo(fn.Variadic, depth+1, seen, pred) + }, + Record: func(r *typ.Record) bool { + for _, f := range r.Fields { + if walkTypeMemo(f.Type, depth+1, seen, pred) { + return true + } + } + return false + }, + Alias: func(a *typ.Alias) bool { + return walkTypeMemo(a.Target, depth+1, seen, pred) + }, + Default: func(t typ.Type) bool { + return false + }, }) } @@ -565,7 +658,8 @@ func (e *UnsatisfiableError) Error() string { } func containsTypeVar(t typ.Type) bool { - return walkType(t, 0, func(inner typ.Type) bool { + seen := make(map[typ.Type]bool) + return walkTypeMemo(t, 0, seen, func(inner typ.Type) bool { return inner.Kind() == kind.TypeVar }) } @@ -576,24 +670,20 @@ type InferSubstitution map[int]typ.Type // Apply applies this substitution to a type. func (s InferSubstitution) Apply(t typ.Type) typ.Type { visited := make(map[int]bool) - typeVisited := make(map[typ.Type]bool) + memo := make(map[typ.Type]typ.Type) - return applyInferSubst(t, s, visited, typeVisited, 0) + return applyInferSubst(t, s, visited, memo, 0) } -func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, typeVisited map[typ.Type]bool, depth int) typ.Type { +func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, memo map[typ.Type]typ.Type, depth int) typ.Type { if stopDepth(t, depth) { return t } - if typeVisited[t] { - return t + if result, ok := memo[t]; ok { + return result } - typeVisited[t] = true - - defer delete(typeVisited, t) - if _, ok := t.(*typ.TypeVar); ok { current := t @@ -619,7 +709,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type continue } - result := applyInferSubst(solved, s, visited, typeVisited, depth+1) + result := applyInferSubst(solved, s, visited, memo, depth+1) curr := t for { @@ -648,9 +738,11 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type return current } - return typ.Visit(t, typ.Visitor[typ.Type]{ + // Sentinel: mark t as in-progress so recursive references return t as-is. + memo[t] = t + result := typ.Visit(t, typ.Visitor[typ.Type]{ Optional: func(o *typ.Optional) typ.Type { - inner := applyInferSubst(o.Inner, s, visited, typeVisited, depth+1) + inner := applyInferSubst(o.Inner, s, visited, memo, depth+1) if inner == o.Inner { return t } @@ -662,7 +754,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type members := make([]typ.Type, len(u.Members)) for i, m := range u.Members { - members[i] = applyInferSubst(m, s, visited, typeVisited, depth+1) + members[i] = applyInferSubst(m, s, visited, memo, depth+1) if members[i] != m { changed = true } @@ -679,7 +771,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type members := make([]typ.Type, len(in.Members)) for i, m := range in.Members { - members[i] = applyInferSubst(m, s, visited, typeVisited, depth+1) + members[i] = applyInferSubst(m, s, visited, memo, depth+1) if members[i] != m { changed = true } @@ -696,7 +788,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type elems := make([]typ.Type, len(tup.Elements)) for i, e := range tup.Elements { - elems[i] = applyInferSubst(e, s, visited, typeVisited, depth+1) + elems[i] = applyInferSubst(e, s, visited, memo, depth+1) if elems[i] != e { changed = true } @@ -713,7 +805,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type params := make([]typ.Param, len(fn.Params)) for i, p := range fn.Params { - pType := applyInferSubst(p.Type, s, visited, typeVisited, depth+1) + pType := applyInferSubst(p.Type, s, visited, memo, depth+1) params[i] = typ.Param{Name: p.Name, Type: pType, Optional: p.Optional} if pType != p.Type { @@ -724,7 +816,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type returns := make([]typ.Type, len(fn.Returns)) for i, r := range fn.Returns { - returns[i] = applyInferSubst(r, s, visited, typeVisited, depth+1) + returns[i] = applyInferSubst(r, s, visited, memo, depth+1) if returns[i] != r { changed = true } @@ -732,7 +824,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type var variadic typ.Type if fn.Variadic != nil { - variadic = applyInferSubst(fn.Variadic, s, visited, typeVisited, depth+1) + variadic = applyInferSubst(fn.Variadic, s, visited, memo, depth+1) if variadic != fn.Variadic { changed = true } @@ -761,7 +853,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type return fb.Build() }, Array: func(a *typ.Array) typ.Type { - elem := applyInferSubst(a.Element, s, visited, typeVisited, depth+1) + elem := applyInferSubst(a.Element, s, visited, memo, depth+1) if elem == a.Element { return t } @@ -769,8 +861,8 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type return typ.NewArray(elem) }, Map: func(m *typ.Map) typ.Type { - key := applyInferSubst(m.Key, s, visited, typeVisited, depth+1) - value := applyInferSubst(m.Value, s, visited, typeVisited, depth+1) + key := applyInferSubst(m.Key, s, visited, memo, depth+1) + value := applyInferSubst(m.Value, s, visited, memo, depth+1) if key == m.Key && value == m.Value { return t @@ -783,7 +875,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type fields := make([]typ.Field, len(r.Fields)) for i, f := range r.Fields { - fType := applyInferSubst(f.Type, s, visited, typeVisited, depth+1) + fType := applyInferSubst(f.Type, s, visited, memo, depth+1) fields[i] = typ.Field{Name: f.Name, Type: fType, Optional: f.Optional, Readonly: f.Readonly} if fType != f.Type { @@ -813,7 +905,7 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type return rec }, Alias: func(a *typ.Alias) typ.Type { - target := applyInferSubst(a.Target, s, visited, typeVisited, depth+1) + target := applyInferSubst(a.Target, s, visited, memo, depth+1) if target == a.Target { return t } @@ -824,6 +916,8 @@ func applyInferSubst(t typ.Type, s InferSubstitution, visited map[int]bool, type return t }, }) + memo[t] = result + return result } // Match walks pattern and concrete types in parallel, collecting constraints. @@ -883,6 +977,14 @@ func matchDepth(pattern, concrete typ.Type, cs *InferSet, variance subtype.Varia } return struct{}{} }, + Union: func(p *typ.Union) struct{} { + concreteMembers := []typ.Type{concrete} + if c, ok := concrete.(*typ.Union); ok { + concreteMembers = c.Members + } + matchUnionMembers(p.Members, concreteMembers, cs, variance, depth+1) + return struct{}{} + }, Array: func(p *typ.Array) struct{} { if c, ok := concrete.(*typ.Array); ok { v := subtype.CombineVariance(variance, subtype.Invariant) @@ -1010,6 +1112,92 @@ func matchDepth(pattern, concrete typ.Type, cs *InferSet, variance subtype.Varia }) } +func matchUnionMembers(patternMembers, concreteMembers []typ.Type, cs *InferSet, variance subtype.Variance, depth int) { + if len(patternMembers) == 0 || len(concreteMembers) == 0 { + return + } + + order := make([]int, len(patternMembers)) + for i := range patternMembers { + order[i] = i + } + sort.SliceStable(order, func(i, j int) bool { + left := unionMemberSpecificity(patternMembers[order[i]]) + right := unionMemberSpecificity(patternMembers[order[j]]) + if left != right { + return left > right + } + return order[i] < order[j] + }) + + used := make([]bool, len(concreteMembers)) + for _, patternIdx := range order { + patternMember := patternMembers[patternIdx] + + bestConcrete := -1 + bestScore := -1 + for concreteIdx, concreteMember := range concreteMembers { + if used[concreteIdx] || !typesOverlapForInference(patternMember, concreteMember) { + continue + } + score := unionPairScore(patternMember, concreteMember) + if score > bestScore { + bestScore = score + bestConcrete = concreteIdx + } + } + + if bestConcrete < 0 { + continue + } + + used[bestConcrete] = true + matchDepth(patternMember, concreteMembers[bestConcrete], cs, variance, depth) + } +} + +func unionMemberSpecificity(t typ.Type) int { + if t == nil { + return 0 + } + + score := 0 + if !containsTypeVar(t) { + score += 4 + } + + switch unwrap.Alias(t).Kind() { + case kind.TypeVar: + // Leave fully-generic members as the last resort. + case kind.Literal: + score += 4 + default: + score += 2 + } + + return score +} + +func unionPairScore(pattern, concrete typ.Type) int { + score := unionMemberSpecificity(pattern) + + patternKind := unwrap.Alias(pattern).Kind() + concreteKind := unwrap.Alias(concrete).Kind() + if patternKind == concreteKind { + score += 4 + } + if subtype.IsSubtype(concrete, pattern) || subtype.IsSubtype(pattern, concrete) { + score += 2 + } + + return score +} + +func typesOverlapForInference(a, b typ.Type) bool { + intersection := subtype.NormalizeIntersection(a, b) + return intersection != nil && !typ.IsNever(intersection) +} + func isTopOrBottom(t typ.Type) bool { if t == nil { return true diff --git a/types/constraint/path.go b/types/constraint/path.go index a1375cb9..5c390001 100644 --- a/types/constraint/path.go +++ b/types/constraint/path.go @@ -36,7 +36,7 @@ type Segment struct { // Symbol provides SSA identity; Root is optional for display when Symbol is set. // When Symbol is non-zero, it is the primary identity for the path root. // Placeholder paths ($0, $1, etc.) use Root only with Symbol=0 and are -// substituted with concrete paths when applying function effects at call sites. +// substituted with concrete paths when applying function refinements at call sites. // // Examples: // - {Root: "x", Symbol: 5}: Variable x with symbol ID 5 @@ -63,7 +63,7 @@ func NewPath(sym cfg.SymbolID, name string) Path { return Path{Root: name, Symbol: sym} } -// NewPlaceholder creates a placeholder path for function effect parameters. +// NewPlaceholder creates a placeholder path for function refinement parameters. // Index 0 creates $0, index 1 creates $1, etc. // // Example: @@ -389,7 +389,7 @@ func (p Path) Less(other Path) bool { return false } -// IsPlaceholder returns true if this path is a placeholder (used in function effects). +// IsPlaceholder returns true if this path is a placeholder (used in function refinements). // Placeholders have Symbol == 0 and Root matching $0, $1, etc. func (p Path) IsPlaceholder() bool { return p.Symbol == 0 && p.PlaceholderIndex() >= 0 diff --git a/types/constraint/path_visit.go b/types/constraint/path_visit.go new file mode 100644 index 00000000..10523bc9 --- /dev/null +++ b/types/constraint/path_visit.go @@ -0,0 +1,86 @@ +package constraint + +// VisitPaths calls fn for each path referenced by c. +// +// It is the allocation-free counterpart to Constraint.Paths(). Callers must +// treat visited paths as read-only. +func VisitPaths(c Constraint, fn func(Path) bool) bool { + switch v := c.(type) { + case Truthy: + if fn(v.Path) { + return true + } + return visitParentFieldPath(v.Path, fn) + case Falsy: + if fn(v.Path) { + return true + } + return visitParentFieldPath(v.Path, fn) + case IsNil: + return fn(v.Path) + case NotNil: + return fn(v.Path) + case HasType: + return fn(v.Path) + case NotHasType: + return fn(v.Path) + case HasField: + return fn(v.Path) + case FieldEquals: + if fn(v.Target) { + return true + } + return visitParentFieldPath(v.Target, fn) + case FieldNotEquals: + if fn(v.Target) { + return true + } + return visitParentFieldPath(v.Target, fn) + case IndexEquals: + return fn(v.Target) + case IndexNotEquals: + return fn(v.Target) + case EqPath: + return fn(v.Left) || fn(v.Right) + case NotEqPath: + return fn(v.Left) || fn(v.Right) + case FieldEqualsPath: + return fn(v.Target) || fn(v.Value) + case FieldNotEqualsPath: + return fn(v.Target) || fn(v.Value) + case IndexEqualsPath: + return fn(v.Target) || fn(v.Value) + case IndexNotEqualsPath: + return fn(v.Target) || fn(v.Value) + case KeyOf: + return fn(v.Table) || fn(v.Key) + default: + return false + } +} + +// FirstPath returns the first path referenced by c, if any. +func FirstPath(c Constraint) (Path, bool) { + var first Path + ok := false + VisitPaths(c, func(p Path) bool { + first = p + ok = true + return true + }) + return first, ok +} + +func visitParentFieldPath(path Path, fn func(Path) bool) bool { + if len(path.Segments) == 0 { + return false + } + if path.Segments[len(path.Segments)-1].Kind != SegmentField { + return false + } + parent := Path{Root: path.Root, Symbol: path.Symbol} + if len(path.Segments) > 1 { + parent.Segments = path.Segments[:len(path.Segments)-1] + } + return fn(parent) +} diff --git a/types/constraint/solver.go b/types/constraint/solver.go index 9a4ad8dd..1f02f89d 100644 --- a/types/constraint/solver.go +++ b/types/constraint/solver.go @@ -176,13 +176,14 @@ func (s Solver) applyWithWorkSkipping(constraints []Constraint, out map[PathKey] func buildPathConstraintIndex(constraints []Constraint) map[PathKey][]Constraint { index := make(map[PathKey][]Constraint) for _, c := range constraints { - for _, p := range c.Paths() { + VisitPaths(c, func(p Path) bool { if p.IsEmpty() { - continue + return false } key := p.Key() index[key] = append(index[key], c) - } + return false + }) } return index } @@ -207,7 +208,6 @@ func collectAffectedConstraints(changedPaths map[PathKey]struct{}, pathConstrain // applyConstraintTrackChanges applies a constraint and tracks which paths changed. func applyConstraintTrackChanges(out *map[PathKey]typ.Type, env Env, c Constraint, changedPaths map[PathKey]struct{}) bool { // Snapshot types before applying - paths := c.Paths() var keys [4]PathKey var before [4]typ.Type count := 0 @@ -232,17 +232,18 @@ func applyConstraintTrackChanges(out *map[PathKey]typ.Type, env Env, c Constrain return } if overflow == nil { - overflow = make(map[PathKey]typ.Type, len(paths)-count) + overflow = make(map[PathKey]typ.Type, 4) } overflow[key] = t } - for _, p := range paths { + VisitPaths(c, func(p Path) bool { if p.IsEmpty() { - continue + return false } key := p.Key() addSnapshot(key) - } + return false + }) // Apply the constraint changed := applyConstraint(out, env, c) diff --git a/types/effect/codecs.go b/types/effect/codecs.go index e573eb45..bb78bd37 100644 --- a/types/effect/codecs.go +++ b/types/effect/codecs.go @@ -372,8 +372,9 @@ const ( returnTypeArrayOfCallback = 4 returnTypeSameAs = 5 returnTypeDeepElementOf = 6 - returnTypeSelectCaseOfParam = 7 - returnTypeSelectResultCases = 8 + returnTypeStringUnpackValue = 7 + returnTypeSelectCaseOfParam = 8 + returnTypeSelectResultCases = 9 ) func writeReturnType(w Writer, rt ReturnType) error { @@ -418,6 +419,12 @@ func writeReturnType(w Writer, rt ReturnType) error { } return w.WriteInt32(int32(v.Source.Index)) }, + StringUnpackValue: func(v StringUnpackValue) error { + if err := w.WriteByte(returnTypeStringUnpackValue); err != nil { + return err + } + return w.WriteInt32(int32(v.Format.Index)) + }, SelectCaseOfParam: func(v SelectCaseOfParam) error { if err := w.WriteByte(returnTypeSelectCaseOfParam); err != nil { return err @@ -490,6 +497,13 @@ func readReturnType(r Reader) (ReturnType, error) { } return DeepElementOf{Source: ParamRef{Index: int(idx)}}, nil + case returnTypeStringUnpackValue: + idx, err := r.ReadInt32() + if err != nil { + return nil, err + } + + return StringUnpackValue{Format: ParamRef{Index: int(idx)}}, nil case returnTypeSelectCaseOfParam: idx, err := r.ReadInt32() if err != nil { diff --git a/types/effect/codecs_test.go b/types/effect/codecs_test.go index 34a5e939..0e21d57b 100644 --- a/types/effect/codecs_test.go +++ b/types/effect/codecs_test.go @@ -951,6 +951,10 @@ func TestWriteReadReturnType(t *testing.T) { name: "deep element of", rt: DeepElementOf{Source: ParamRef{Index: 1}}, }, + { + name: "string unpack value", + rt: StringUnpackValue{Format: ParamRef{Index: 0}}, + }, { name: "select case of param", rt: SelectCaseOfParam{Source: ParamRef{Index: 1}}, @@ -1849,6 +1853,15 @@ func TestWriteReturnTypeErrorPaths(t *testing.T) { } }) + t.Run("string unpack value tag error", func(t *testing.T) { + w := &errorAfterNWriter{n: 0} + + err := writeReturnType(w, StringUnpackValue{Format: ParamRef{Index: 0}}) + if err == nil { + t.Error("expected error") + } + }) + t.Run("select case of param tag error", func(t *testing.T) { w := &errorAfterNWriter{n: 0} diff --git a/types/effect/label.go b/types/effect/label.go index c8975e40..089c6eb4 100644 --- a/types/effect/label.go +++ b/types/effect/label.go @@ -290,6 +290,9 @@ func (r ReturnLength) Equals(other Label) bool { // - DeepElementOf: Recursively extracts non-array leaf types. // For nested arrays, returns the innermost element type. // +// - StringUnpackValue: Returns the first unpacked value type derived from a +// string.pack/string.unpack format parameter when that format is known. +// // - SelectCaseOfParam: Builds select case from parameter type. // // - SelectResultOfCases: Builds select result from cases and default. @@ -379,6 +382,19 @@ func (d DeepElementOf) String() string { return fmt.Sprintf("deep_elem(%s)", d.Source) } +// StringUnpackValue derives the first unpacked value type from a format parameter. +// +// This models builtins like string.unpack where the first returned value depends +// on the literal format string supplied at the call site. +type StringUnpackValue struct { + Format ParamRef +} + +func (StringUnpackValue) returnType() {} +func (s StringUnpackValue) String() string { + return fmt.Sprintf("string_unpack(%s)", s.Format) +} + // Throw indicates the function may throw an error. type Throw struct{} @@ -776,6 +792,12 @@ func returnTypeEquals(a, b ReturnType) bool { } return false }, + StringUnpackValue: func(av StringUnpackValue) bool { + if bv, ok := b.(StringUnpackValue); ok { + return av.Format.Index == bv.Format.Index + } + return false + }, SelectCaseOfParam: func(av SelectCaseOfParam) bool { if bv, ok := b.(SelectCaseOfParam); ok { return av.Source.Index == bv.Source.Index diff --git a/types/effect/label_test.go b/types/effect/label_test.go index 05f080b4..14823692 100644 --- a/types/effect/label_test.go +++ b/types/effect/label_test.go @@ -252,6 +252,13 @@ func TestDeepElementOf_String(t *testing.T) { } } +func TestStringUnpackValue_String(t *testing.T) { + u := StringUnpackValue{Format: ParamRef{Index: 0}} + if got := u.String(); got != "string_unpack(param[0])" { + t.Errorf("StringUnpackValue.String() = %q", got) + } +} + func TestThrow(t *testing.T) { th := Throw{} if got := th.String(); got != "throw" { @@ -372,6 +379,7 @@ func TestReturnTypeInterface(t *testing.T) { ArrayOfCallbackReturn{}, SameAs{}, DeepElementOf{}, + StringUnpackValue{}, } for _, rt := range returnTypes { @@ -553,6 +561,7 @@ func TestMarkerMethods(t *testing.T) { ArrayOfCallbackReturn{}.returnType() SameAs{}.returnType() DeepElementOf{}.returnType() + StringUnpackValue{}.returnType() } func TestReturnLengthEqualsNonMatch(t *testing.T) { diff --git a/types/effect/visit.go b/types/effect/visit.go index 96933d20..a7ee0c93 100644 --- a/types/effect/visit.go +++ b/types/effect/visit.go @@ -286,6 +286,7 @@ type ReturnTypeVisitor[R any] struct { ArrayOfCallbackReturn func(ArrayOfCallbackReturn) R SameAs func(SameAs) R DeepElementOf func(DeepElementOf) R + StringUnpackValue func(StringUnpackValue) R SelectCaseOfParam func(SelectCaseOfParam) R SelectResultOfCases func(SelectResultOfCases) R Default func(ReturnType) R @@ -342,6 +343,14 @@ func VisitReturnType[R any](t ReturnType, v ReturnTypeVisitor[R]) R { if v.DeepElementOf != nil { return v.DeepElementOf(*tt) } + case StringUnpackValue: + if v.StringUnpackValue != nil { + return v.StringUnpackValue(tt) + } + case *StringUnpackValue: + if v.StringUnpackValue != nil { + return v.StringUnpackValue(*tt) + } case SelectCaseOfParam: if v.SelectCaseOfParam != nil { return v.SelectCaseOfParam(tt) diff --git a/types/flow/domain/shape_domain.go b/types/flow/domain/shape_domain.go index daffe6ca..96fd6e63 100644 --- a/types/flow/domain/shape_domain.go +++ b/types/flow/domain/shape_domain.go @@ -116,14 +116,14 @@ func (d *ShapeDomain) ApplyConstraint(c constraint.Constraint, target constraint } // Propagate to ALL ancestor paths for nested field constraints - paths := c.Paths() - if len(paths) > 0 && len(paths[0].Segments) > 0 { + path, ok := constraint.FirstPath(c) + if ok && len(path.Segments) > 0 { // Walk up the path tree, propagating to each ancestor - for depth := len(paths[0].Segments) - 1; depth >= 0; depth-- { + for depth := len(path.Segments) - 1; depth >= 0; depth-- { ancestorPath := constraint.Path{ - Root: paths[0].Root, - Symbol: paths[0].Symbol, - Segments: paths[0].Segments[:depth], + Root: path.Root, + Symbol: path.Symbol, + Segments: path.Segments[:depth], } if d.Env.ResolvePath == nil { continue diff --git a/types/flow/inference.go b/types/flow/inference.go index 98c686f0..54287bc3 100644 --- a/types/flow/inference.go +++ b/types/flow/inference.go @@ -1,11 +1,11 @@ -// inference.go implements function effect inference from flow analysis results. +// inference.go implements function refinement inference from flow analysis results. // -// Function effects describe how a function's return value relates to its parameters. -// For predicate functions (returning boolean), effects capture when parameters are +// Function refinements describe how a function's return value relates to its parameters. +// For predicate functions (returning boolean), refinements capture when parameters are // narrowed based on the return value (OnTrue/OnFalse). For assert-style functions, -// effects capture constraints that hold after the function returns (OnReturn). +// refinements capture constraints that hold after the function returns (OnReturn). // -// Effects enable interprocedural narrowing: calling a predicate function in a +// Refinements enable interprocedural narrowing: calling a predicate function in a // conditional allows the checker to narrow argument types in the appropriate branch. package flow @@ -29,7 +29,7 @@ type ParamInfo struct { Type typ.Type } -// InferFunctionEffect computes a FunctionEffect from solved flow analysis. +// InferFunctionRefinement computes a FunctionRefinement from solved flow analysis. // // This is the post-flow variant that uses the complete flow solution to determine // which constraints hold at each return point. It examines return points to compute: @@ -51,26 +51,26 @@ type ParamInfo struct { // Example: For function `function is_string(x) return type(x) == "string" end`: // - OnTrue: HasType($0, string) // - OnFalse: NotHasType($0, string) -func InferFunctionEffect( +func InferFunctionRefinement( solution *Solution, g *cfg.CFG, params []ParamInfo, returnType typ.Type, -) *constraint.FunctionEffect { +) *constraint.FunctionRefinement { if solution == nil || g == nil { return nil } - src := effectSource{conditionAt: solution.ConditionAt} + src := refinementSource{conditionAt: solution.ConditionAt} if solution.inputs != nil { src.returnConstraints = solution.inputs.ReturnConstraints src.returnKinds = solution.inputs.ReturnKinds } - return inferFunctionEffectCore(src, g, params, returnType) + return inferFunctionRefinementCore(src, g, params, returnType) } -// InferFunctionEffectFromInputs computes a FunctionEffect without running full flow analysis. +// InferFunctionRefinementFromInputs computes a FunctionRefinement without running full flow analysis. // // This is the pre-flow variant that uses only the extracted return constraints without // propagating conditions through the CFG. It produces conservative effects based on: @@ -80,27 +80,27 @@ func InferFunctionEffect( // // The pre-flow variant is faster but less precise than post-flow inference because // it cannot account for path conditions from prior conditionals. It's suitable for -// bootstrapping effect extraction before the full type checking pass. +// bootstrapping refinement extraction before the full type checking pass. // // Example: For function `function assert_string(x) assert(type(x) == "string") end`: // - OnReturn: HasType($0, string) (from assert expression) // - Terminates: false (has implicit return via exit) -func InferFunctionEffectFromInputs( +func InferFunctionRefinementFromInputs( inputs *Inputs, g *cfg.CFG, params []ParamInfo, returnType typ.Type, -) *constraint.FunctionEffect { +) *constraint.FunctionRefinement { if inputs == nil || g == nil { return nil } - src := effectSource{returnConstraints: inputs.ReturnConstraints} + src := refinementSource{returnConstraints: inputs.ReturnConstraints} - return inferFunctionEffectCore(src, g, params, returnType) + return inferFunctionRefinementCore(src, g, params, returnType) } -// effectSource abstracts the data sources for effect inference. +// refinementSource abstracts the data sources for refinement inference. // // This struct allows the same inference algorithm to work with both pre-flow // and post-flow data by providing optional access to different data sources: @@ -116,13 +116,13 @@ func InferFunctionEffectFromInputs( // // Nil fields indicate the data is unavailable, causing the inference to skip // the corresponding analysis step. -type effectSource struct { +type refinementSource struct { returnConstraints map[cfg.Point]ReturnExprConstraints returnKinds map[cfg.Point]ReturnKind conditionAt func(cfg.Point) constraint.Condition } -// inferFunctionEffectCore is the shared implementation for effect inference. +// inferFunctionRefinementCore is the shared implementation for effect inference. // // This method walks the CFG to find all return and exit points, collecting // constraints that hold at each. The algorithm: @@ -135,14 +135,14 @@ type effectSource struct { // 6. Substitute parameter names/symbols with placeholders ($0, $1, ...) // 7. Detect terminating functions (no return nodes, exit unreachable) // -// The effectSource parameter routes queries to the appropriate data source +// The refinementSource parameter routes queries to the appropriate data source // (pre-flow inputs vs. post-flow solution). -func inferFunctionEffectCore( - src effectSource, +func inferFunctionRefinementCore( + src refinementSource, g *cfg.CFG, params []ParamInfo, returnType typ.Type, -) *constraint.FunctionEffect { +) *constraint.FunctionRefinement { // Build param Symbol -> index and name -> index maps paramIndex := make(map[cfg.SymbolID]int) paramNameIndex := make(map[string]int) @@ -253,7 +253,7 @@ func inferFunctionEffectCore( exitHasPredecessors := len(graphPredecessors(g, g.Exit())) > 0 terminates := !hasReturnNode && !exitHasPredecessors - eff := &constraint.FunctionEffect{ + eff := &constraint.FunctionRefinement{ OnTrue: onTrueCond, OnFalse: onFalseCond, OnReturn: onReturnCond, diff --git a/types/flow/inference_test.go b/types/flow/inference_test.go index da1cefbb..c487b9b7 100644 --- a/types/flow/inference_test.go +++ b/types/flow/inference_test.go @@ -279,7 +279,7 @@ func TestConstraintPropagation_Loop(_ *testing.T) { _ = solver.ConditionAt(ret) } -func TestInferFunctionEffect_Empty(t *testing.T) { +func TestInferFunctionRefinement_Empty(t *testing.T) { g := cfg.New() ret := g.AddNode(cfg.NodeReturn, cfg.SymbolID(0), "") g.AddEdge(g.Entry(), ret, false) @@ -289,13 +289,13 @@ func TestInferFunctionEffect_Empty(t *testing.T) { inputs := &Inputs{Graph: mock} solver := Solve(inputs, testResolver()) - effect := InferFunctionEffect(solver, g, nil, typ.Boolean) + effect := InferFunctionRefinement(solver, g, nil, typ.Boolean) if effect != nil { t.Error("expected nil effect for empty constraints") } } -func TestInferFunctionEffect_NonBoolean(t *testing.T) { +func TestInferFunctionRefinement_NonBoolean(t *testing.T) { g := cfg.New() ret := g.AddNode(cfg.NodeReturn, cfg.SymbolID(0), "") g.AddEdge(g.Entry(), ret, false) @@ -306,7 +306,7 @@ func TestInferFunctionEffect_NonBoolean(t *testing.T) { solver := Solve(inputs, testResolver()) // Non-boolean return type should not produce OnTrue/OnFalse - effect := InferFunctionEffect(solver, g, []ParamInfo{{Name: "x", Symbol: 100, Type: typ.Any}}, typ.String) + effect := InferFunctionRefinement(solver, g, []ParamInfo{{Name: "x", Symbol: 100, Type: typ.Any}}, typ.String) if effect != nil { t.Error("expected nil effect for non-boolean return type") } diff --git a/types/flow/join/join.go b/types/flow/join/join.go index bd877c1d..34bb10d6 100644 --- a/types/flow/join/join.go +++ b/types/flow/join/join.go @@ -28,6 +28,7 @@ package join import ( + "github.com/wippyai/go-lua/internal" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" @@ -273,42 +274,52 @@ func CoalesceRecordMapComponents(types []typ.Type) []typ.Type { // Group records by field signature type recGroup struct { - indices []int - records []*typ.Record + template *typ.Record + indices []int + records []*typ.Record } - groups := make(map[string]*recGroup) + groups := make(map[uint64][]*recGroup) for i, t := range types { rec, ok := t.(*typ.Record) if !ok || len(rec.Fields) == 0 { continue } - // Only group records that have a map component (at least one in the set) - // Build a signature from field names, types, and flags - sig := recordFieldSignature(rec) - g, exists := groups[sig] - if !exists { - g = &recGroup{} - groups[sig] = g + sigHash := recordFieldSignatureHash(rec) + var group *recGroup + for _, candidate := range groups[sigHash] { + if sameRecordFieldSignature(candidate.template, rec) { + group = candidate + break + } } - g.indices = append(g.indices, i) - g.records = append(g.records, rec) + if group == nil { + group = &recGroup{template: rec} + groups[sigHash] = append(groups[sigHash], group) + } + group.indices = append(group.indices, i) + group.records = append(group.records, rec) } // Check if any group has records with map components to merge needsMerge := false - for _, g := range groups { - if len(g.records) < 2 { - continue - } - hasMap := false - for _, r := range g.records { - if r.HasMapComponent() { - hasMap = true + for _, bucket := range groups { + for _, g := range bucket { + if len(g.records) < 2 { + continue + } + hasMap := false + for _, r := range g.records { + if r.HasMapComponent() { + hasMap = true + break + } + } + if hasMap { + needsMerge = true break } } - if hasMap { - needsMerge = true + if needsMerge { break } } @@ -319,67 +330,69 @@ func CoalesceRecordMapComponents(types []typ.Type) []typ.Type { // Build result replacing merged groups skip := make(map[int]bool) result := make([]typ.Type, 0, len(types)) - for sig, g := range groups { - _ = sig - if len(g.records) < 2 { - continue - } - hasMap := false - for _, r := range g.records { - if r.HasMapComponent() { - hasMap = true - break + for _, bucket := range groups { + for _, g := range bucket { + if len(g.records) < 2 { + continue } - } - if !hasMap { - continue - } - // Merge map components - var mapKey, mapValue typ.Type - for _, r := range g.records { - if !r.HasMapComponent() { + hasMap := false + for _, r := range g.records { + if r.HasMapComponent() { + hasMap = true + break + } + } + if !hasMap { continue } - if mapKey == nil { - mapKey = r.MapKey - mapValue = r.MapValue - } else { - mapKey = Types(mapKey, r.MapKey) - mapValue = Types(mapValue, r.MapValue) + + // Merge map components + var mapKey, mapValue typ.Type + for _, r := range g.records { + if !r.HasMapComponent() { + continue + } + if mapKey == nil { + mapKey = r.MapKey + mapValue = r.MapValue + } else { + mapKey = Types(mapKey, r.MapKey) + mapValue = Types(mapValue, r.MapValue) + } } - } - // Use the first record as the template - template := g.records[0] - builder := typ.NewRecord() - if template.Open { - builder.SetOpen(true) - } - for _, f := range template.Fields { - switch { - case f.Optional && f.Readonly: - builder.OptReadonlyField(f.Name, f.Type) - case f.Optional: - builder.OptField(f.Name, f.Type) - case f.Readonly: - builder.ReadonlyField(f.Name, f.Type) - default: - builder.Field(f.Name, f.Type) + // Use the first record as the template + template := g.template + builder := typ.NewRecord() + if template.Open { + builder.SetOpen(true) } - } - if template.Metatable != nil { - builder.Metatable(template.Metatable) - } - if mapKey != nil && mapValue != nil { - builder.MapComponent(mapKey, mapValue) - } - merged := builder.Build() + for _, f := range template.Fields { + switch { + case f.Optional && f.Readonly: + builder.OptReadonlyField(f.Name, f.Type) + case f.Optional: + builder.OptField(f.Name, f.Type) + case f.Readonly: + builder.ReadonlyField(f.Name, f.Type) + default: + builder.Field(f.Name, f.Type) + } + } + if template.Metatable != nil { + builder.Metatable(template.Metatable) + } + if mapKey != nil && mapValue != nil { + builder.MapComponent(mapKey, mapValue) + } + merged := builder.Build() - // Mark all indices in this group for replacement - for _, idx := range g.indices { - skip[idx] = true + // Mark all indices in this group for replacement + for _, idx := range g.indices { + skip[idx] = true + } + // Add merged record at the position of the first occurrence + result = append(result, merged) } - // Add merged record at the position of the first occurrence - result = append(result, merged) } if len(skip) == 0 { @@ -398,33 +411,44 @@ func CoalesceRecordMapComponents(types []typ.Type) []typ.Type { return final } -// recordFieldSignature computes a canonical string signature for a record's fields. -// -// The signature includes for each field: name, optional flag (?), readonly flag (!), -// and type string. Records with identical signatures have the same fields with the -// same types, and their map components can be merged. -// -// The signature also includes the open flag to distinguish {x: T} from {x: T, ...}. -func recordFieldSignature(r *typ.Record) string { - if len(r.Fields) == 0 { - return "" +func sameRecordFieldSignature(a, b *typ.Record) bool { + if a == nil || b == nil { + return a == b + } + if a.Open != b.Open || len(a.Fields) != len(b.Fields) { + return false + } + for i, af := range a.Fields { + bf := b.Fields[i] + if af.Name != bf.Name || af.Optional != bf.Optional || af.Readonly != bf.Readonly { + return false + } + if !typ.TypeEquals(af.Type, bf.Type) { + return false + } + } + return true +} + +func recordFieldSignatureHash(r *typ.Record) uint64 { + if r == nil { + return 0 + } + h := internal.HashCombine(uint64(kind.Record), uint64(len(r.Fields))) + if r.Open { + h = internal.HashCombine(h, 1) } - // Fields are already sorted by name in Record.Build() - sig := "" for _, f := range r.Fields { - sig += f.Name + ":" + h = internal.HashCombine(h, internal.FnvString(f.Name)) if f.Optional { - sig += "?" + h = internal.HashCombine(h, 2) } if f.Readonly { - sig += "!" + h = internal.HashCombine(h, 3) } - sig += f.Type.String() + ";" - } - if r.Open { - sig += "..." + h = internal.HashCombine(h, f.Type.Hash()) } - return sig + return h } // CoalesceEmptyRecordWithMap removes empty records when maps are present. diff --git a/types/flow/join/join_test.go b/types/flow/join/join_test.go index db384634..99562a9a 100644 --- a/types/flow/join/join_test.go +++ b/types/flow/join/join_test.go @@ -78,6 +78,52 @@ func TestCoalesceMaps_MultipleMaps(t *testing.T) { } } +func TestCoalesceRecordMapComponents_MergesMatchingFieldShapes(t *testing.T) { + left := typ.NewRecord(). + Field("kind", typ.String). + Field("handler", typ.Func().Returns(typ.String).Build()). + MapComponent(typ.String, typ.Number). + Build() + right := typ.NewRecord(). + Field("kind", typ.String). + Field("handler", typ.Func().Returns(typ.String).Build()). + MapComponent(typ.String, typ.Boolean). + Build() + + result := CoalesceRecordMapComponents([]typ.Type{left, right}) + if len(result) != 1 { + t.Fatalf("expected 1 merged record, got %d", len(result)) + } + rec, ok := result[0].(*typ.Record) + if !ok { + t.Fatalf("expected record, got %T", result[0]) + } + if !rec.HasMapComponent() { + t.Fatal("expected merged record to keep map component") + } + if _, ok := rec.MapValue.(*typ.Union); !ok { + t.Fatalf("expected merged map value union, got %T", rec.MapValue) + } +} + +func TestCoalesceRecordMapComponents_DoesNotMergeDifferentFieldTypes(t *testing.T) { + left := typ.NewRecord(). + Field("kind", typ.String). + Field("handler", typ.Func().Returns(typ.String).Build()). + MapComponent(typ.String, typ.Number). + Build() + right := typ.NewRecord(). + Field("kind", typ.String). + Field("handler", typ.Func().Returns(typ.Integer).Build()). + MapComponent(typ.String, typ.Boolean). + Build() + + result := CoalesceRecordMapComponents([]typ.Type{left, right}) + if len(result) != 2 { + t.Fatalf("expected distinct records to remain separate, got %d", len(result)) + } +} + func TestCoalesceEmptyRecordWithMap_NoEmptyRecord(t *testing.T) { m := typ.NewMap(typ.String, typ.Number) rec := typ.NewRecord().Field("x", typ.Number).Build() diff --git a/types/flow/pathkey/key.go b/types/flow/pathkey/key.go index 7b4f8e49..b31dcaa7 100644 --- a/types/flow/pathkey/key.go +++ b/types/flow/pathkey/key.go @@ -6,9 +6,21 @@ package pathkey import ( + "sync" + "github.com/wippyai/go-lua/types/constraint" ) +type eqPair struct { + left, right constraint.Path +} + +var parseSuffixCache sync.Map + +type pathSet struct { + paths map[uint64][]constraint.Path +} + // SegmentsSuffix converts path segments to a suffix string for key construction. // // The suffix format matches Lua syntax: @@ -39,6 +51,20 @@ func ParseSuffix(suffix string) []constraint.Segment { if suffix == "" { return nil } + if cached, ok := parseSuffixCache.Load(suffix); ok { + return cached.([]constraint.Segment) + } + segs := parseSuffixSlow(suffix) + if segs == nil { + return nil + } + if cached, loaded := parseSuffixCache.LoadOrStore(suffix, segs); loaded { + return cached.([]constraint.Segment) + } + return segs +} + +func parseSuffixSlow(suffix string) []constraint.Segment { var segs []constraint.Segment i := 0 for i < len(suffix) { @@ -202,26 +228,23 @@ func FilterConstraintsForPath(constraints []constraint.Constraint, target constr return constraints } - equivalentPaths := CollectEquivalentPaths(constraints, target) - // If target is used in Field/Index EqualsPath constraints, keep constraints on the value path too. - relatedValuePaths := make(map[constraint.PathKey]struct{}) - for _, c := range constraints { - constraint.VisitConstraint(c, constraint.ConstraintVisitor[struct{}]{ - FieldEqualsPath: func(v constraint.FieldEqualsPath) struct{} { - if PathRelated(target, v.Target) { - relatedValuePaths[v.Value.Key()] = struct{}{} - } - return struct{}{} - }, - IndexEqualsPath: func(v constraint.IndexEqualsPath) struct{} { - if PathRelated(target, v.Target) { - relatedValuePaths[v.Value.Key()] = struct{}{} - } - return struct{}{} - }, - }) + pairs, relatedValuePaths := collectPathFilterFacts(constraints, target) + if len(pairs) == 0 && relatedValuePaths.len() == 0 { + var filtered []constraint.Constraint + for _, c := range constraints { + if shouldDropAsymmetricNotEquals(c, target) { + continue + } + if constraintAnyPathMatches(c, func(p constraint.Path) bool { + return PathRelated(target, p) + }) { + filtered = append(filtered, c) + } + } + return filtered } + equivalentPaths := collectEquivalentPathsFromPairs(pairs, target) var filtered []constraint.Constraint for _, c := range constraints { // Field/Index NotEquals constraints are asymmetric: they narrow the target, @@ -233,12 +256,10 @@ func FilterConstraintsForPath(constraints []constraint.Constraint, target constr if PathRelated(target, p) { return true } - key := p.Key() - if equivalentPaths[key] { + if equivalentPaths.has(p) { return true } - _, ok := relatedValuePaths[key] - return ok + return relatedValuePaths.has(p) }) { filtered = append(filtered, c) } @@ -259,50 +280,54 @@ func FilterConstraintsForPath(constraints []constraint.Constraint, target constr // // Returns a set of PathKeys that are transitively equivalent to target. func CollectEquivalentPaths(constraints []constraint.Constraint, target constraint.Path) map[constraint.PathKey]bool { + pairs, _ := collectPathFilterFacts(constraints, constraint.Path{}) + set := collectEquivalentPathsFromPairs(pairs, target) result := make(map[constraint.PathKey]bool) - result[target.Key()] = true + set.visit(func(path constraint.Path) { + result[path.Key()] = true + }) + return result +} - type eqPair struct { - left, right constraint.Path - leftKey, rightKey constraint.PathKey - } +func collectPathFilterFacts(constraints []constraint.Constraint, target constraint.Path) ([]eqPair, *pathSet) { var pairs []eqPair + relatedValuePaths := newPathSet() for _, c := range constraints { - constraint.VisitConstraint(c, constraint.ConstraintVisitor[struct{}]{ - EqPath: func(v constraint.EqPath) struct{} { - pairs = append(pairs, eqPair{ - left: v.Left, - right: v.Right, - leftKey: v.Left.Key(), - rightKey: v.Right.Key(), - }) - return struct{}{} - }, - FieldEqualsPath: func(v constraint.FieldEqualsPath) struct{} { - fieldPath := v.Target.Append(constraint.Segment{Kind: constraint.SegmentField, Name: v.Field}) - if !fieldPath.IsEmpty() { - pairs = append(pairs, eqPair{ - left: fieldPath, - right: v.Value, - leftKey: fieldPath.Key(), - rightKey: v.Value.Key(), - }) - } - return struct{}{} - }, - IndexEqualsPath: func(v constraint.IndexEqualsPath) struct{} { + switch v := c.(type) { + case constraint.EqPath: + pairs = append(pairs, eqPair{ + left: v.Left, + right: v.Right, + }) + case constraint.FieldEqualsPath: + fieldPath := v.Target.Append(constraint.Segment{Kind: constraint.SegmentField, Name: v.Field}) + if !fieldPath.IsEmpty() { pairs = append(pairs, eqPair{ - left: v.Target, - right: v.Value, - leftKey: v.Target.Key(), - rightKey: v.Value.Key(), + left: fieldPath, + right: v.Value, }) - return struct{}{} - }, - }) + } + if !target.IsEmpty() && PathRelated(target, v.Target) { + relatedValuePaths.add(v.Value) + } + case constraint.IndexEqualsPath: + pairs = append(pairs, eqPair{ + left: v.Target, + right: v.Value, + }) + if !target.IsEmpty() && PathRelated(target, v.Target) { + relatedValuePaths.add(v.Value) + } + } } + return pairs, relatedValuePaths +} + +func collectEquivalentPathsFromPairs(pairs []eqPair, target constraint.Path) *pathSet { + result := newPathSet() + result.add(target) if len(pairs) == 0 { return result } @@ -311,21 +336,71 @@ func CollectEquivalentPaths(constraints []constraint.Constraint, target constrai for changed { changed = false for _, pair := range pairs { - leftIn := result[pair.leftKey] - rightIn := result[pair.rightKey] + leftIn := result.has(pair.left) + rightIn := result.has(pair.right) if leftIn && !rightIn { - result[pair.rightKey] = true - changed = true + changed = result.add(pair.right) || changed } else if rightIn && !leftIn { - result[pair.leftKey] = true - changed = true + changed = result.add(pair.left) || changed } } } - return result } +func newPathSet() *pathSet { + return &pathSet{paths: make(map[uint64][]constraint.Path)} +} + +func (s *pathSet) add(path constraint.Path) bool { + if s == nil || path.IsEmpty() { + return false + } + hash := path.Hash() + bucket := s.paths[hash] + for _, existing := range bucket { + if existing.Equal(path) { + return false + } + } + s.paths[hash] = append(bucket, path) + return true +} + +func (s *pathSet) has(path constraint.Path) bool { + if s == nil || path.IsEmpty() { + return false + } + for _, existing := range s.paths[path.Hash()] { + if existing.Equal(path) { + return true + } + } + return false +} + +func (s *pathSet) len() int { + if s == nil { + return 0 + } + total := 0 + for _, bucket := range s.paths { + total += len(bucket) + } + return total +} + +func (s *pathSet) visit(fn func(constraint.Path)) { + if s == nil { + return + } + for _, bucket := range s.paths { + for _, path := range bucket { + fn(path) + } + } +} + func shouldDropAsymmetricNotEquals(c constraint.Constraint, target constraint.Path) bool { return constraint.VisitConstraint(c, constraint.ConstraintVisitor[bool]{ FieldNotEqualsPath: func(v constraint.FieldNotEqualsPath) bool { diff --git a/types/flow/pathkey/resolver.go b/types/flow/pathkey/resolver.go index 4ffc8144..957bfb17 100644 --- a/types/flow/pathkey/resolver.go +++ b/types/flow/pathkey/resolver.go @@ -34,13 +34,25 @@ type VersionedGraph interface { // // - Querying the CFG for visible SSA versions // - Building canonical key strings in the "sym@" format -// - Handling placeholder paths (used in function effects) +// - Handling placeholder paths (used in function refinements) // - Validating that paths have resolvable versions // // Using the resolver ensures that the same path at the same point always // produces the same key, enabling correct value lookup and constraint matching. type Resolver struct { - graph VersionedGraph + graph VersionedGraph + root map[versionedRootKey]constraint.PathKey + single map[singleSegmentKey]constraint.PathKey +} + +type versionedRootKey struct { + sym cfg.SymbolID + version int +} + +type singleSegmentKey struct { + root versionedRootKey + seg constraint.Segment } // NewResolver creates a resolver bound to a versioned graph. @@ -48,7 +60,11 @@ type Resolver struct { // The graph must provide SSA version lookup for all symbols that will be // resolved. Typically this is a cfg.VersionedGraph from CFG construction. func NewResolver(g VersionedGraph) *Resolver { - return &Resolver{graph: g} + return &Resolver{ + graph: g, + root: make(map[versionedRootKey]constraint.PathKey), + single: make(map[singleSegmentKey]constraint.PathKey), + } } // KeyAt returns the canonical key for a path at a CFG point. @@ -103,13 +119,47 @@ func (r *Resolver) KeyAtVersion(sym cfg.SymbolID, versionID int, segments []cons // - Parseable back to components via ParseKey // - Sortable in a meaningful order (by symbol, then version) func (r *Resolver) buildKey(sym cfg.SymbolID, versionID int, segments []constraint.Segment) constraint.PathKey { + rootKey := r.rootKey(sym, versionID) + if len(segments) == 0 { + return rootKey + } + if len(segments) == 1 { + cacheKey := singleSegmentKey{ + root: versionedRootKey{sym: sym, version: versionID}, + seg: segments[0], + } + if cached, ok := r.single[cacheKey]; ok { + return cached + } + b := strings.Builder{} + b.Grow(len(rootKey) + segmentStringLen(segments[0])) + b.WriteString(string(rootKey)) + appendSegments(&b, segments) + key := constraint.PathKey(b.String()) + r.single[cacheKey] = key + return key + } + var b strings.Builder + b.Grow(len(rootKey) + segmentsStringLen(segments)) + b.WriteString(string(rootKey)) + appendSegments(&b, segments) + return constraint.PathKey(b.String()) +} + +func (r *Resolver) rootKey(sym cfg.SymbolID, versionID int) constraint.PathKey { + cacheKey := versionedRootKey{sym: sym, version: versionID} + if cached, ok := r.root[cacheKey]; ok { + return cached + } var b strings.Builder + b.Grow(16) b.WriteString("sym") writeUint(&b, uint64(sym)) b.WriteByte('@') writeInt(&b, versionID) - appendSegments(&b, segments) - return constraint.PathKey(b.String()) + key := constraint.PathKey(b.String()) + r.root[cacheKey] = key + return key } func buildPlaceholderKey(root string, segments []constraint.Segment) constraint.PathKey { @@ -138,6 +188,47 @@ func appendSegments(b *strings.Builder, segments []constraint.Segment) { } } +func segmentsStringLen(segments []constraint.Segment) int { + total := 0 + for _, seg := range segments { + total += segmentStringLen(seg) + } + return total +} + +func segmentStringLen(seg constraint.Segment) int { + switch seg.Kind { + case constraint.SegmentField: + return 1 + len(seg.Name) + case constraint.SegmentIndexString: + escaped := 0 + for i := 0; i < len(seg.Name); i++ { + switch seg.Name[i] { + case '\\', '"': + escaped++ + } + } + return 4 + len(seg.Name) + escaped + case constraint.SegmentIndexInt: + n := seg.Index + if n == 0 { + return 3 + } + digits := 0 + if n < 0 { + digits++ + n = -n + } + for n > 0 { + digits++ + n /= 10 + } + return digits + 2 + default: + return 0 + } +} + func writeQuotedSegmentIndex(b *strings.Builder, key string) { b.WriteString("[\"") for i := 0; i < len(key); i++ { diff --git a/types/flow/product_domain.go b/types/flow/product_domain.go index c214179e..0f653f16 100644 --- a/types/flow/product_domain.go +++ b/types/flow/product_domain.go @@ -146,14 +146,14 @@ func (d *ProductDomain) ApplyAtom(atom constraint.Atom) bool { // // Returns false if the constraint proves the domain unsatisfiable. func (d *ProductDomain) ApplyLeftoverConstraint(c constraint.Constraint) bool { - paths := c.Paths() - if len(paths) == 0 { + path, ok := constraint.FirstPath(c) + if !ok { return true } if d.env.ResolvePath == nil { return true } - target := d.env.ResolvePath(paths[0]) + target := d.env.ResolvePath(path) if target == "" { return true } @@ -211,7 +211,7 @@ func (d *ProductDomain) ApplyConjunction(constraints []constraint.Constraint) bo // Update PathTypeAt to include Type domain narrowings for leftover constraints originalPathTypeAt := d.Shape.Solver.Env.PathTypeAt - d.Shape.Solver.Env.PathTypeAt = func(key constraint.PathKey) typ.Type { + narrowedPathTypeAt := func(key constraint.PathKey) typ.Type { if narrowed := d.Type.NarrowedTypeAt(key); narrowed != nil { return narrowed } @@ -220,6 +220,8 @@ func (d *ProductDomain) ApplyConjunction(constraints []constraint.Constraint) bo } return nil } + d.Shape.Solver.Env.PathTypeAt = narrowedPathTypeAt + d.Shape.Env.PathTypeAt = narrowedPathTypeAt // Apply leftovers (Shape domain via wrapped Solver) for _, c := range result.Leftover { @@ -267,12 +269,13 @@ func (d *ProductDomain) buildCongruenceClosure(atoms []constraint.Atom, constrai resolve := d.env.ResolvePath for _, c := range constraints { - for _, path := range c.Paths() { + constraint.VisitPaths(c, func(path constraint.Path) bool { key := resolve(path) if key != "" { d.EGraph.RegisterKey(key) } - } + return false + }) } for _, c := range constraints { diff --git a/types/flow/product_domain_test.go b/types/flow/product_domain_test.go new file mode 100644 index 00000000..ed926c08 --- /dev/null +++ b/types/flow/product_domain_test.go @@ -0,0 +1,67 @@ +package flow + +import ( + "testing" + + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" +) + +func TestProductDomain_ApplyCondition_UsesTypeNarrowedBaseForLeftoverFieldConstraint(t *testing.T) { + allow := typ.NewAlias("Allow", typ.NewRecord(). + Field("kind", typ.LiteralString("allow")). + Field("reason", typ.String). + Build()) + deny := typ.NewAlias("Deny", typ.NewRecord(). + Field("kind", typ.LiteralString("deny")). + Field("reason", typ.String). + Build()) + deferType := typ.NewAlias("Defer", typ.NewRecord(). + Field("kind", typ.LiteralString("defer")). + Field("queue", typ.String). + Build()) + decision := typ.NewAlias("Decision", typ.NewUnion(allow, deny, deferType)) + baseType := typ.NewOptional(decision) + + path := constraint.Path{Root: "decision", Symbol: 1} + key := path.Key() + env := constraint.Env{ + PathTypeAt: func(k constraint.PathKey) typ.Type { + if k == key { + return baseType + } + return nil + }, + ResolvePath: func(p constraint.Path) constraint.PathKey { + return p.Key() + }, + Resolver: &core.FuncResolver{ + FieldFunc: core.Field, + }, + } + + dom := NewProductDomain(env) + ok := dom.ApplyCondition(constraint.FromConstraints( + constraint.Truthy{Path: path}, + constraint.FieldEquals{Target: path, Field: "kind", Value: typ.LiteralString("defer")}, + )) + if !ok { + t.Fatal("ApplyCondition returned false") + } + + got := dom.TypeAt(key) + if got == nil { + t.Fatal("TypeAt(decision) returned nil") + } + if subtype.IsSubtype(typ.Nil, got) { + t.Fatalf("TypeAt(decision) = %v, want non-nil Defer variant", got) + } + if queue, ok := core.Field(got, "queue"); !ok || !typ.TypeEquals(queue, typ.String) { + t.Fatalf("queue field = %v, want string on narrowed Defer variant", queue) + } + if _, ok := core.Field(got, "reason"); ok { + t.Fatalf("reason field should not remain on narrowed Defer variant: %v", got) + } +} diff --git a/types/flow/propagate/propagate.go b/types/flow/propagate/propagate.go index 00537649..db7ea994 100644 --- a/types/flow/propagate/propagate.go +++ b/types/flow/propagate/propagate.go @@ -337,15 +337,16 @@ func FilterConditionSymbols(cond constraint.Condition, syms []cfg.SymbolID) cons var kept []constraint.Constraint for _, c := range d { shouldKeep := true - for _, p := range c.Paths() { + constraint.VisitPaths(c, func(p constraint.Path) bool { if p.Symbol == 0 { - continue + return false } if _, ok := symSet[p.Symbol]; ok { shouldKeep = false - break + return true } - } + return false + }) if shouldKeep { kept = append(kept, c) } @@ -395,20 +396,18 @@ func KillRedefinedConditions(cond constraint.Condition, p cfg.Point, assignments var kept []constraint.Constraint for _, c := range d { shouldKeep := true - for _, cpath := range c.Paths() { + constraint.VisitPaths(c, func(cpath constraint.Path) bool { if cpath.Symbol == 0 { - continue + return false } for _, ap := range assignedPaths { if PathAffectedByAssignment(cpath, ap.TargetSym, ap.TargetSegs) { shouldKeep = false - break + return true } } - if !shouldKeep { - break - } - } + return false + }) if shouldKeep { kept = append(kept, c) } diff --git a/types/flow/query.go b/types/flow/query.go index 67f45b87..85030863 100644 --- a/types/flow/query.go +++ b/types/flow/query.go @@ -76,15 +76,19 @@ func (s *Solution) TypeAt(p cfg.Point, path constraint.Path) typ.Type { } if full != nil && full.Kind().IsPlaceholder() && derived != nil { - return derived + full = derived } if derived != nil && derived.Kind().IsPlaceholder() && full != nil { - return full + derived = full } + + var candidate typ.Type if full != nil { - return full + candidate = full + } else { + candidate = derived } - return derived + return candidate } // ConditionAt returns the full DNF condition at a CFG point. @@ -552,21 +556,25 @@ func (s *Solution) applyConstraints(p cfg.Point, baseType typ.Type, path constra return nil } - // Look up narrowed type using canonical key + current := baseType if narrowed := dom.TypeAt(canonicalKey); narrowed != nil { - return narrowed + current = narrowed } if narrowed, ok := s.deriveFromNarrowedAncestors(canonicalKey, dom); ok { - return narrowed + if current == nil { + current = narrowed + } else { + current = narrow.Intersect(current, narrowed) + } } childNarrowings := dom.NarrowedChildPaths(canonicalKey) if len(childNarrowings) > 0 { - return s.filterByChildNarrowings(baseType, path, childNarrowings) + current = s.filterByChildNarrowings(current, path, childNarrowings) } - return baseType + return current } // deriveFromNarrowedAncestors projects narrowed ancestor path types down to a target key. @@ -578,7 +586,7 @@ func (s *Solution) deriveFromNarrowedAncestors(targetKey constraint.PathKey, dom if !ok { return nil, false } - targetSegs := pathkey.ParseSuffix(targetSuffix) + targetSegs := s.parseSuffixCached(targetSuffix) seen := make(map[constraint.PathKey]bool) candidates := make([]constraint.PathKey, 0, len(dom.Type.Narrowed)+len(dom.Shape.Narrowed)) @@ -607,7 +615,7 @@ func (s *Solution) deriveFromNarrowedAncestors(targetKey constraint.PathKey, dom if !ok || sym != targetSym || version != targetVersion { continue } - ancestorSegs := pathkey.ParseSuffix(suffix) + ancestorSegs := s.parseSuffixCached(suffix) if len(ancestorSegs) >= len(targetSegs) { continue } @@ -664,7 +672,7 @@ func (s *Solution) filterByChildNarrowings(baseType typ.Type, parentPath constra if !ok || childSym != parentSym { continue } - segs := pathkey.ParseSuffix(suffix) + segs := s.parseSuffixCached(suffix) if len(segs) == 0 { continue } @@ -734,3 +742,20 @@ func (s *Solution) HasKeyOf(p cfg.Point, tablePath, keyPath constraint.Path) boo } return constraint.HasKeyOfConstraint(cond, tablePath, keyPath, resolve) } + +func (s *Solution) parseSuffixCached(suffix string) []constraint.Segment { + if suffix == "" { + return nil + } + cache := s.scratchParsedSuffixes + if cache == nil { + cache = make(map[string][]constraint.Segment, 32) + s.scratchParsedSuffixes = cache + } + if segs, ok := cache[suffix]; ok { + return segs + } + segs := pathkey.ParseSuffix(suffix) + cache[suffix] = segs + return segs +} diff --git a/types/flow/solver.go b/types/flow/solver.go index b729c3fb..18c86862 100644 --- a/types/flow/solver.go +++ b/types/flow/solver.go @@ -9,7 +9,6 @@ import ( "github.com/wippyai/go-lua/types/flow/numeric" "github.com/wippyai/go-lua/types/flow/pathkey" "github.com/wippyai/go-lua/types/flow/propagate" - "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/typ" ) @@ -45,6 +44,8 @@ type Solution struct { scratchUnresolvedPaths map[constraint.PathKey]struct{} scratchValueMap map[constraint.PathKey]typ.Type scratchResolvedPathMap map[constraint.PathKey]constraint.PathKey + scratchParsedSuffixes map[string][]constraint.Segment + fieldOverlayCache map[string][]mergedField pathAliases map[string]string // canonical target path key -> canonical source path key narrowedTypeCache map[narrowedTypeCacheKey]narrowedTypeCacheValue queryCacheEnabled bool @@ -72,6 +73,12 @@ type edgeKey struct { to cfg.Point } +type mergedField struct { + Name string + Type typ.Type + Optional bool +} + // Solve computes flow analysis and returns the solution. // // Solve is the main entry point for flow-sensitive type analysis. It takes @@ -170,31 +177,43 @@ func (s *Solution) runPropagation() { // // This environment is passed to the constraint solver for type narrowing. func (s *Solution) buildPointValueMap(p cfg.Point, targetPath constraint.Path, baseType typ.Type, constraints []constraint.Constraint) map[constraint.PathKey]typ.Type { + visibleVersions := map[cfg.SymbolID]cfg.Version(nil) + visibleCount := len(s.declaredSyms) + if s.inputs != nil && s.inputs.Graph != nil { + visibleVersions = s.inputs.Graph.AllVisibleVersions(p) + visibleCount = len(visibleVersions) + } + queryVisibleLookup := visibleVersions != nil + result := s.scratchValueMap if result == nil { - result = make(map[constraint.PathKey]typ.Type, 1+len(s.declaredSyms)) + result = make(map[constraint.PathKey]typ.Type, estimatePointValueMapCapacity(visibleCount, len(constraints))) s.scratchValueMap = result } clear(result) - versionIDs := s.scratchVersionIDs - if versionIDs == nil { - versionIDs = make(map[cfg.SymbolID]int, len(s.declaredSyms)+1) - s.scratchVersionIDs = versionIDs - } - clear(versionIDs) + var versionIDs map[cfg.SymbolID]int + var missingVersions map[cfg.SymbolID]struct{} + hasMissingVersions := false + if !queryVisibleLookup { + versionIDs = s.scratchVersionIDs + if versionIDs == nil { + versionIDs = make(map[cfg.SymbolID]int, estimateVersionCacheCapacity(visibleCount)) + s.scratchVersionIDs = versionIDs + } + clear(versionIDs) - missingVersions := s.scratchMissingVersions - if missingVersions == nil { - missingVersions = make(map[cfg.SymbolID]struct{}, 8) - s.scratchMissingVersions = missingVersions + missingVersions = s.scratchMissingVersions + if missingVersions == nil { + missingVersions = make(map[cfg.SymbolID]struct{}, 8) + s.scratchMissingVersions = missingVersions + } + clear(missingVersions) } - clear(missingVersions) - hasMissingVersions := false unresolved := s.scratchUnresolvedPaths if unresolved == nil { - unresolved = make(map[constraint.PathKey]struct{}, len(constraints)) + unresolved = make(map[constraint.PathKey]struct{}, estimateUnresolvedPathCapacity(len(constraints))) s.scratchUnresolvedPaths = unresolved } clear(unresolved) @@ -212,6 +231,13 @@ func (s *Solution) buildPointValueMap(p cfg.Point, targetPath constraint.Path, b if path.Version != 0 { return s.pkResolver.KeyAtVersion(path.Symbol, path.Version, path.Segments) } + if queryVisibleLookup { + ver, ok := visibleVersions[path.Symbol] + if !ok || ver.IsZero() { + return "" + } + return s.pkResolver.KeyAtVersion(path.Symbol, ver.ID, path.Segments) + } if hasMissingVersions { if _, missing := missingVersions[path.Symbol]; missing { return "" @@ -238,16 +264,40 @@ func (s *Solution) buildPointValueMap(p cfg.Point, targetPath constraint.Path, b // Add declared types for symbols visible at this point if s.inputs != nil && s.inputs.DeclaredTypes != nil && s.inputs.Graph != nil { - for _, sym := range s.declaredSyms { - declPath := constraint.Path{Symbol: sym} - canonicalKey := keyAtPoint(declPath) - if canonicalKey == "" { - continue - } - declType := s.inputs.DeclaredTypes[sym] - if _, exists := result[canonicalKey]; !exists { + if s.queryCacheEnabled { + for sym, ver := range visibleVersions { + if ver.IsZero() { + continue + } + declType := s.inputs.DeclaredTypes[sym] + if declType == nil { + continue + } + canonicalKey := s.pkResolver.KeyAtVersion(sym, ver.ID, nil) + if canonicalKey == "" { + continue + } + if _, exists := result[canonicalKey]; exists { + continue + } + if t := s.values[string(canonicalKey)]; t != nil { + result[canonicalKey] = t + continue + } result[canonicalKey] = declType } + } else { + for _, sym := range s.declaredSyms { + declPath := constraint.Path{Symbol: sym} + canonicalKey := keyAtPoint(declPath) + if canonicalKey == "" { + continue + } + declType := s.inputs.DeclaredTypes[sym] + if _, exists := result[canonicalKey]; !exists { + result[canonicalKey] = declType + } + } } } @@ -279,26 +329,26 @@ func (s *Solution) buildPointValueMap(p cfg.Point, targetPath constraint.Path, b // Tracks canonical paths we already attempted but could not resolve. // Successful resolutions live in result and are checked directly. for _, c := range constraints { - for _, cpath := range c.Paths() { + constraint.VisitPaths(c, func(cpath constraint.Path) bool { if cpath.IsEmpty() || cpath.Symbol == 0 { - continue + return false } cpath = normalizeConstraintPathForQuery(cpath) canonicalKey := keyAtPoint(cpath) if canonicalKey == "" { - continue + return false } if _, exists := result[canonicalKey]; exists { - continue + return false } if _, knownUnresolved := unresolved[canonicalKey]; knownUnresolved { - continue + return false } // Look up value using canonical key if t := s.values[string(canonicalKey)]; t != nil { result[canonicalKey] = t - continue + return false } // Derive child path type from parent's base type @@ -307,7 +357,7 @@ func (s *Solution) buildPointValueMap(p cfg.Point, targetPath constraint.Path, b if len(relativeSegs) > 0 { if derived, ok := s.deriveTypeFrom(baseType, relativeSegs); ok { result[canonicalKey] = derived - continue + return false } } } @@ -318,22 +368,48 @@ func (s *Solution) buildPointValueMap(p cfg.Point, targetPath constraint.Path, b if rootType, ok := resolveRootType(cpath.Symbol); ok { if len(cpath.Segments) == 0 { result[canonicalKey] = rootType - continue + return false } if derived, ok := s.deriveTypeFrom(rootType, cpath.Segments); ok { result[canonicalKey] = derived - continue + return false } } unresolved[canonicalKey] = struct{}{} - } + return false + }) } } return result } +func estimatePointValueMapCapacity(visibleCount, constraintCount int) int { + capacity := 8 + if visibleCount > 0 { + capacity += min(visibleCount, 32) + } + if constraintCount > 0 { + capacity += min(constraintCount*2, 32) + } + return capacity +} + +func estimateVersionCacheCapacity(visibleCount int) int { + if visibleCount <= 0 { + return 8 + } + return min(visibleCount+1, 32) +} + +func estimateUnresolvedPathCapacity(constraintCount int) int { + if constraintCount <= 0 { + return 8 + } + return min(constraintCount, 16) +} + // isDescendantOf returns true if child is a strict descendant of parent. // // A path is a descendant if it has the same symbol and extends the parent's @@ -713,49 +789,11 @@ func (s *Solution) resolveTypeKey(key narrow.TypeKey) typ.Type { // // This enables gradual type construction for tables built incrementally. func (s *Solution) mergeFieldAssignments(baseType typ.Type, baseKey string) typ.Type { - type mergedField struct { - Name string - Type typ.Type - Optional bool - } - - var fields []mergedField baseSym, baseVersion, _, ok := pathkey.ParseKeyUnchecked(constraint.PathKey(baseKey)) if !ok { return baseType } - baseRoot := pathkey.SymbolVersionRoot(baseSym, baseVersion) - prefixLen := len(baseRoot) - keys := make([]string, 0, len(s.values)) - for key := range s.values { - if len(key) <= prefixLen || key[:prefixLen] != baseRoot { - continue - } - // Match only strict child paths of this symbol/version root. - next := key[prefixLen] - if next != '.' && next != '[' { - continue - } - keys = append(keys, key) - } - if len(keys) == 0 { - return baseType - } - sort.Strings(keys) - for _, key := range keys { - suffix := key[prefixLen:] - segs := pathkey.ParseSuffix(suffix) - if len(segs) != 1 { - continue - } - seg := segs[0] - switch seg.Kind { - case constraint.SegmentField, constraint.SegmentIndexString: - fieldType, optional := splitOptionalAssignedFieldType(s.values[key]) - fields = append(fields, mergedField{Name: seg.Name, Type: fieldType, Optional: optional}) - } - } - + fields := s.fieldAssignmentsForRoot(pathkey.SymbolVersionRoot(baseSym, baseVersion)) if len(fields) == 0 { return baseType } @@ -766,6 +804,29 @@ func (s *Solution) mergeFieldAssignments(baseType typ.Type, baseKey string) typ. // Merge fields into base type return typ.Visit(baseType, typ.Visitor[typ.Type]{ + Alias: func(a *typ.Alias) typ.Type { + merged := s.mergeFieldAssignments(a.Target, baseKey) + if merged == nil || typ.TypeEquals(merged, a.Target) { + return baseType + } + return typ.NewAlias(a.Name, merged) + }, + Recursive: func(r *typ.Recursive) typ.Type { + mergedBody := s.mergeFieldAssignments(r.Body, baseKey) + if mergedBody == nil || typ.TypeEquals(mergedBody, r.Body) { + return baseType + } + + rebuilt := typ.NewRecursivePlaceholder(r.Name) + rebuiltBody := typ.Rewrite(mergedBody, func(n typ.Type) (typ.Type, bool) { + if typ.IsRecursiveRef(n, r) { + return rebuilt, true + } + return nil, false + }) + rebuilt.SetBody(rebuiltBody) + return rebuilt + }, Map: func(m *typ.Map) typ.Type { // Map base: create Record(open) with MapComponent + merged fields builder := typ.NewRecord().SetOpen(true) @@ -805,8 +866,13 @@ func (s *Solution) mergeFieldAssignments(baseType typ.Type, baseKey string) typ. fieldType := f.Type optional := f.Optional if assigned, ok := assignedByName[f.Name]; ok { - fieldType = typ.JoinReturnSlot(fieldType, assigned.t) - optional = optional || assigned.optional + // Child-path facts already represent the current value of the + // field at this program point. Rebuilding the root should + // project that current field value back into the record rather + // than re-join it with the declared/base slot as if it were a + // separate branch. + fieldType = assigned.t + optional = assigned.optional delete(assignedByName, f.Name) } switch { @@ -850,70 +916,123 @@ func (s *Solution) mergeFieldAssignments(baseType typ.Type, baseKey string) typ. }) } -// splitOptionalAssignedFieldType converts nil-capable field assignment types -// into (innerType, optional=true) so merged record shapes model absent fields -// as optional fields instead of required fields typed as T|nil. -func splitOptionalAssignedFieldType(t typ.Type) (typ.Type, bool) { - if t == nil { - return typ.Unknown, true - } - // Preserve alias identity for non-optional assignments. Losing alias names - // here breaks downstream receiver checks for self-recursive method types. - if a, ok := t.(*typ.Alias); ok { - if a == nil || a.Target == nil { - return t, false - } - if opt, ok := a.Target.(*typ.Optional); ok && opt != nil && opt.Inner != nil { - return opt.Inner, true - } - if u, ok := a.Target.(*typ.Union); ok && u != nil && len(u.Members) > 0 { - hasNil := false - nonNil := make([]typ.Type, 0, len(u.Members)) - for _, m := range u.Members { - if m != nil && m.Kind() == kind.Nil { - hasNil = true - continue - } - nonNil = append(nonNil, m) - } - if hasNil { - switch len(nonNil) { - case 0: - return typ.Nil, true - case 1: - return nonNil[0], true - default: - return typ.NewUnion(nonNil...), true - } - } - } - return t, false +func (s *Solution) fieldAssignmentsForRoot(baseRoot string) []mergedField { + if s == nil || len(s.values) == 0 || baseRoot == "" { + return nil } - if opt, ok := t.(*typ.Optional); ok && opt != nil && opt.Inner != nil { - return opt.Inner, true + if s.fieldOverlayCache == nil { + s.fieldOverlayCache = make(map[string][]mergedField) } - u, ok := t.(*typ.Union) - if !ok || u == nil || len(u.Members) == 0 { - return t, false + if fields, ok := s.fieldOverlayCache[baseRoot]; ok { + return fields } - hasNil := false - nonNil := make([]typ.Type, 0, len(u.Members)) - for _, m := range u.Members { - if m != nil && m.Kind() == kind.Nil { - hasNil = true + fields := s.collectFieldAssignmentsForRoot(baseRoot) + s.fieldOverlayCache[baseRoot] = fields + return fields +} + +func (s *Solution) collectFieldAssignmentsForRoot(baseRoot string) []mergedField { + prefixLen := len(baseRoot) + fields := make([]mergedField, 0, 8) + for key, value := range s.values { + if len(key) <= prefixLen || key[:prefixLen] != baseRoot { + continue + } + seg, ok := parseSingleOverlaySegment(key[prefixLen:]) + if !ok { continue } - nonNil = append(nonNil, m) + fieldType, optional := typ.SplitNilableFieldType(value) + fields = append(fields, mergedField{Name: seg.Name, Type: fieldType, Optional: optional}) } - if !hasNil { - return t, false + sortMergedFields(fields) + return fields +} + +func sortMergedFields(fields []mergedField) { + if len(fields) <= 1 { + return } - switch len(nonNil) { - case 0: - return typ.Nil, true - case 1: - return nonNil[0], true + sort.Slice(fields, func(i, j int) bool { + if fields[i].Name != fields[j].Name { + return fields[i].Name < fields[j].Name + } + if fields[i].Optional != fields[j].Optional { + return !fields[i].Optional && fields[j].Optional + } + return fields[i].Type.Hash() < fields[j].Type.Hash() + }) +} + +func parseSingleOverlaySegment(suffix string) (constraint.Segment, bool) { + if suffix == "" { + return constraint.Segment{}, false + } + switch suffix[0] { + case '.': + name := suffix[1:] + if name == "" || !pathkey.IsIdentName(name) { + return constraint.Segment{}, false + } + return constraint.Segment{Kind: constraint.SegmentField, Name: name}, true + case '[': + if len(suffix) < 3 || suffix[len(suffix)-1] != ']' { + return constraint.Segment{}, false + } + inner := suffix[1 : len(suffix)-1] + if inner == "" { + return constraint.Segment{}, false + } + if inner[0] == '"' { + if len(inner) < 2 || inner[len(inner)-1] != '"' { + return constraint.Segment{}, false + } + name, ok := parseQuotedOverlayIndex(inner[1 : len(inner)-1]) + if !ok { + return constraint.Segment{}, false + } + return constraint.Segment{Kind: constraint.SegmentIndexString, Name: name}, true + } + if _, ok := pathkey.ParseIntLiteral(inner); ok { + return constraint.Segment{}, false + } + return constraint.Segment{Kind: constraint.SegmentIndexString, Name: inner}, true default: - return typ.NewUnion(nonNil...), true + return constraint.Segment{}, false + } +} + +func parseQuotedOverlayIndex(inner string) (string, bool) { + if inner == "" { + return "", true + } + escaped := false + for i := 0; i < len(inner); i++ { + if inner[i] == '\\' { + escaped = true + break + } + } + if !escaped { + return inner, true + } + + out := make([]byte, 0, len(inner)) + for i := 0; i < len(inner); i++ { + ch := inner[i] + if ch != '\\' { + out = append(out, ch) + continue + } + if i+1 >= len(inner) { + return "", false + } + next := inner[i+1] + if next != '\\' && next != '"' { + return "", false + } + out = append(out, next) + i++ } + return string(out), true } diff --git a/types/flow/solver_helpers.go b/types/flow/solver_helpers.go index 4f83e563..30da3b73 100644 --- a/types/flow/solver_helpers.go +++ b/types/flow/solver_helpers.go @@ -93,10 +93,25 @@ func (s *Solution) initSymbolTypes(src symbolTypeSource) { } } - s.values[keyStr] = t + s.setValue(keyStr, t) } } +func (s *Solution) setValue(key string, t typ.Type) { + if s == nil || s.values == nil || key == "" { + return + } + s.values[key] = t + if s.fieldOverlayCache == nil { + return + } + _, _, suffix, ok := pathkey.ParseKeyUnchecked(constraint.PathKey(key)) + if !ok || suffix == "" { + return + } + delete(s.fieldOverlayCache, key[:len(key)-len(suffix)]) +} + // dependencyMap tracks which CFG points depend on a given canonical key. // // During worklist iteration, when a key's type value changes, all dependent diff --git a/types/flow/solver_test.go b/types/flow/solver_test.go index 388ec9b8..47a20710 100644 --- a/types/flow/solver_test.go +++ b/types/flow/solver_test.go @@ -379,6 +379,122 @@ func TestMergeFieldAssignments_PreservesAliasFieldType(t *testing.T) { } } +func TestMergeFieldAssignments_PreservesAliasRootType(t *testing.T) { + builderAlias := typ.NewAlias("Builder", typ.NewRecord().Field("_messages", typ.NewArray(typ.String)).Build()) + s := &Solution{ + values: map[string]typ.Type{ + `sym11@1._messages`: typ.NewArray(typ.NewUnion(typ.String, typ.Integer)), + }, + } + + got := s.mergeFieldAssignments(builderAlias, "sym11@1") + alias, ok := got.(*typ.Alias) + if !ok { + t.Fatalf("mergeFieldAssignments(alias root) = %T, want *typ.Alias", got) + } + if alias.Name != "Builder" { + t.Fatalf("alias name = %q, want Builder", alias.Name) + } + rec, ok := alias.Target.(*typ.Record) + if !ok { + t.Fatalf("alias target = %T, want *typ.Record", alias.Target) + } + field := rec.GetField("_messages") + if field == nil { + t.Fatalf("expected _messages field in merged alias target") + } + arr, ok := field.Type.(*typ.Array) + if !ok { + t.Fatalf("_messages field type = %T, want *typ.Array", field.Type) + } + if !typ.TypeEquals(arr.Element, typ.NewUnion(typ.String, typ.Integer)) { + t.Fatalf("_messages element = %v, want string|integer", arr.Element) + } +} + +func TestMergeFieldAssignments_PreservesRecursiveAliasRootType(t *testing.T) { + rec := typ.NewRecursivePlaceholder("Builder") + rec.SetBody( + typ.NewRecord(). + Field("_messages", typ.NewArray(typ.String)). + Field("clone", typ.Func(). + Param("self", rec). + Returns(rec). + Build()). + Build(), + ) + builderAlias := typ.NewAlias("Builder", rec) + + s := &Solution{ + values: map[string]typ.Type{ + `sym12@1._messages`: typ.NewArray(typ.LiteralString("x")), + }, + } + + got := s.mergeFieldAssignments(builderAlias, "sym12@1") + alias, ok := got.(*typ.Alias) + if !ok { + t.Fatalf("mergeFieldAssignments(recursive alias root) = %T, want *typ.Alias", got) + } + if alias.Name != "Builder" { + t.Fatalf("alias name = %q, want Builder", alias.Name) + } + mergedRec, ok := alias.Target.(*typ.Recursive) + if !ok { + t.Fatalf("alias target = %T, want *typ.Recursive", alias.Target) + } + body, ok := mergedRec.Body.(*typ.Record) + if !ok { + t.Fatalf("recursive body = %T, want *typ.Record", mergedRec.Body) + } + msgs := body.GetField("_messages") + if msgs == nil { + t.Fatalf("missing _messages field in merged recursive body") + } + arr, ok := msgs.Type.(*typ.Array) + if !ok { + t.Fatalf("_messages field type = %T, want *typ.Array", msgs.Type) + } + if !typ.TypeEquals(arr.Element, typ.LiteralString("x")) { + t.Fatalf("_messages element = %v, want literal x", arr.Element) + } + clone := body.GetField("clone") + if clone == nil { + t.Fatalf("missing clone field in merged recursive body") + } + fn, ok := clone.Type.(*typ.Function) + if !ok { + t.Fatalf("clone field type = %T, want *typ.Function", clone.Type) + } + if len(fn.Params) != 1 || !typ.IsRecursiveRef(fn.Params[0].Type, mergedRec) { + t.Fatalf("clone self param = %v, want rebuilt recursive self", fn.Params) + } + if len(fn.Returns) != 1 || !typ.IsRecursiveRef(fn.Returns[0], mergedRec) { + t.Fatalf("clone return = %v, want rebuilt recursive self", fn.Returns) + } +} + +func TestFieldAssignmentsForRoot_InvalidatesCachedRootOnFieldWrite(t *testing.T) { + s := &Solution{ + values: map[string]typ.Type{ + `sym21@1.name`: typ.String, + }, + fieldOverlayCache: make(map[string][]mergedField), + } + + first := s.fieldAssignmentsForRoot("sym21@1") + if len(first) != 1 || first[0].Name != "name" || !typ.TypeEquals(first[0].Type, typ.String) { + t.Fatalf("first field overlay = %v, want name:string", first) + } + + s.setValue(`sym21@1.name`, typ.Integer) + + second := s.fieldAssignmentsForRoot("sym21@1") + if len(second) != 1 || second[0].Name != "name" || !typ.TypeEquals(second[0].Type, typ.Integer) { + t.Fatalf("second field overlay = %v, want name:integer", second) + } +} + // setupSymbol registers a symbol and sets its visibility at all given points. // Returns the SymbolID for use in version creation. func setupSymbol(g *mockSSAGraph, name string, points []cfg.Point) cfg.SymbolID { @@ -447,6 +563,159 @@ func TestFlow_PhiKey_Join(t *testing.T) { } } +func TestFlow_PhiChildSuffix_UsesPredecessorNarrowingWhenSuffixMissing(t *testing.T) { + c, branch, thenNode, elseNode, join := buildBranchJoinCFG() + g := newMockSSAGraph(c) + + allPoints := []cfg.Point{c.Entry(), branch, thenNode, elseNode, join, c.Exit()} + symMessages := setupSymbol(g, "messages", allPoints) + + ver1 := cfg.Version{Root: "messages", Symbol: symMessages, ID: 1} + ver2 := cfg.Version{Root: "messages", Symbol: symMessages, ID: 2} + ver3 := cfg.Version{Root: "messages", Symbol: symMessages, ID: 3} + + setVersion(g, c.Entry(), symMessages, ver1) + setVersion(g, branch, symMessages, ver1) + setVersion(g, thenNode, symMessages, ver2) + setVersion(g, elseNode, symMessages, ver1) + setVersion(g, join, symMessages, ver3) + setVersion(g, c.Exit(), symMessages, ver3) + + g.addPhiNode(cfg.PhiNode{ + Point: join, + Target: ver3, + Operands: []cfg.PhiOperand{ + {From: thenNode, Version: ver2}, + {From: elseNode, Version: ver1}, + }, + }) + + messageType := typ.NewRecord(). + Field("topic", typ.Func().Returns(typ.String).Build()). + Build() + messagesType := typ.NewMap(typ.String, messageType) + childPath := constraint.Path{ + Root: "messages", + Symbol: symMessages, + Segments: []constraint.Segment{ + {Kind: constraint.SegmentField, Name: "root"}, + }, + } + + inputs := newInputs(g) + inputs.DeclaredTypes[symMessages] = messagesType + inputs.Assignments = []UnifiedAssignment{ + { + Point: c.Entry(), + TargetPath: constraint.Path{Root: "messages", Symbol: symMessages}, + Type: messagesType, + }, + { + Point: thenNode, + TargetPath: childPath, + Type: messageType, + }, + } + inputs.EdgeConditions = []EdgeCondition{ + { + From: branch, + To: thenNode, + Condition: constraint.FromConstraints(constraint.Falsy{Path: childPath}), + }, + { + From: branch, + To: elseNode, + Condition: constraint.FromConstraints(constraint.Truthy{Path: childPath}), + }, + } + + s := Solve(inputs, testResolver()) + + got := s.TypeAt(join, childPath) + if got == nil { + t.Fatal("TypeAt(join, messages.root) returned nil") + } + if core.ContainsNil(got) { + t.Fatalf("TypeAt(join, messages.root) should be definite after constructive join, got %v", got) + } + if !typ.TypeEquals(got, messageType) { + t.Fatalf("TypeAt(join, messages.root) = %v, want %v", got, messageType) + } + + fullKey := string(s.pkResolver.KeyAt(join, childPath)) + raw := s.DebugValueAt(fullKey, join) + if raw == nil { + t.Fatal("joined child suffix value missing from solver state") + } + if core.ContainsNil(raw) { + t.Fatalf("raw joined child suffix should not contain nil, got %v", raw) + } +} + +func TestFlow_PhiChildSuffix_OneBranchOnlyInstallRemainsOptional(t *testing.T) { + c, _, thenNode, elseNode, join := buildBranchJoinCFG() + g := newMockSSAGraph(c) + + allPoints := []cfg.Point{c.Entry(), thenNode, elseNode, join, c.Exit()} + symMessages := setupSymbol(g, "messages", allPoints) + + ver1 := cfg.Version{Root: "messages", Symbol: symMessages, ID: 1} + ver2 := cfg.Version{Root: "messages", Symbol: symMessages, ID: 2} + ver3 := cfg.Version{Root: "messages", Symbol: symMessages, ID: 3} + + setVersion(g, c.Entry(), symMessages, ver1) + setVersion(g, thenNode, symMessages, ver2) + setVersion(g, elseNode, symMessages, ver1) + setVersion(g, join, symMessages, ver3) + setVersion(g, c.Exit(), symMessages, ver3) + + g.addPhiNode(cfg.PhiNode{ + Point: join, + Target: ver3, + Operands: []cfg.PhiOperand{ + {From: thenNode, Version: ver2}, + {From: elseNode, Version: ver1}, + }, + }) + + messageType := typ.NewRecord(). + Field("topic", typ.Func().Returns(typ.String).Build()). + Build() + messagesType := typ.NewMap(typ.String, messageType) + childPath := constraint.Path{ + Root: "messages", + Symbol: symMessages, + Segments: []constraint.Segment{ + {Kind: constraint.SegmentField, Name: "root"}, + }, + } + + inputs := newInputs(g) + inputs.DeclaredTypes[symMessages] = messagesType + inputs.Assignments = []UnifiedAssignment{ + { + Point: c.Entry(), + TargetPath: constraint.Path{Root: "messages", Symbol: symMessages}, + Type: messagesType, + }, + { + Point: thenNode, + TargetPath: childPath, + Type: messageType, + }, + } + + s := Solve(inputs, testResolver()) + + got := s.TypeAt(join, childPath) + if got == nil { + t.Fatal("TypeAt(join, messages.root) returned nil") + } + if !core.ContainsNil(got) { + t.Fatalf("TypeAt(join, messages.root) should remain optional when only one branch installs it, got %v", got) + } +} + func TestConditionAt_JoinOr(t *testing.T) { c, branch, thenNode, elseNode, join := buildBranchJoinCFG() g := newMockSSAGraph(c) @@ -844,6 +1113,71 @@ func TestNarrowedTypeAt_NestedBranch_PropagatesNarrowing(t *testing.T) { } } +func TestNarrowedTypeAt_ComposesNotNilWithChildDiscriminant(t *testing.T) { + c, branch1, then1, branch2, then2 := buildNestedBranchCFG() + g := newMockSSAGraph(c) + + allPoints := []cfg.Point{c.Entry(), branch1, then1, branch2, then2, c.Exit()} + symX := setupSymbol(g, "x", allPoints) + verX := cfg.Version{Root: "x", Symbol: symX, ID: 1} + for _, p := range allPoints { + setVersion(g, p, symX, verX) + } + + allow := typ.NewRecord(). + Field("kind", typ.LiteralString("allow")). + Field("reason", typ.String). + Build() + deny := typ.NewRecord(). + Field("kind", typ.LiteralString("deny")). + Field("reason", typ.String). + Build() + deferType := typ.NewRecord(). + Field("kind", typ.LiteralString("defer")). + Field("queue", typ.String). + Build() + + inputs := newInputs(g) + inputs.DeclaredTypes[symX] = typ.NewOptional(typ.NewUnion(allow, deny, deferType)) + + pathX := constraint.Path{Root: "x", Symbol: symX} + inputs.EdgeConditions = []EdgeCondition{ + { + From: branch1, + To: then1, + Condition: constraint.FromConstraints(constraint.NotNil{Path: pathX}), + }, + { + From: branch2, + To: then2, + Condition: constraint.FromConstraints(constraint.FieldEquals{ + Target: pathX, + Field: "kind", + Value: typ.LiteralString("defer"), + }), + }, + } + + s := Solve(inputs, testResolver()) + + got := s.NarrowedTypeAt(then2, pathX) + if !typ.TypeEquals(got, deferType) { + t.Fatalf("NarrowedTypeAt(then2) = %v, want %v", got, deferType) + } + + pathQueue := constraint.Path{ + Root: "x", + Symbol: symX, + Segments: []constraint.Segment{ + {Kind: constraint.SegmentField, Name: "queue"}, + }, + } + gotQueue := s.NarrowedTypeAt(then2, pathQueue) + if !typ.TypeEquals(gotQueue, typ.String) { + t.Fatalf("NarrowedTypeAt(then2, x.queue) = %v, want string", gotQueue) + } +} + // buildReassignCFG creates: entry -> assign1(x) -> call1 -> assign2(x) -> call2 -> use -> exit // This models: x = f1(); assert_is_nil(x); x = f2(); assert_not_nil(x); use(x) func buildReassignCFG() (*cfg.CFG, cfg.Point, cfg.Point, cfg.Point, cfg.Point, cfg.Point) { @@ -7089,9 +7423,9 @@ func TestUnionWithMixedTableAndPrimitive(t *testing.T) { } } -// TestInferFunctionEffect_ReturnConstraint tests that return expression constraints -// are correctly converted to function effects with placeholder substitution. -func TestInferFunctionEffect_ReturnConstraint(t *testing.T) { +// TestInferFunctionRefinement_ReturnConstraint tests that return expression constraints +// are correctly converted to function refinements with placeholder substitution. +func TestInferFunctionRefinement_ReturnConstraint(t *testing.T) { // Build simple CFG: entry -> return -> exit c := cfg.New() ret := c.AddNode(cfg.NodeReturn, cfg.SymbolID(0), "") @@ -7151,11 +7485,11 @@ func TestInferFunctionEffect_ReturnConstraint(t *testing.T) { } t.Logf("Entry: %d, Exit: %d", c.Entry(), c.Exit()) - eff := InferFunctionEffect(s, c, params, typ.Any) + eff := InferFunctionRefinement(s, c, params, typ.Any) t.Logf("Effect: %+v", eff) if eff == nil { - t.Fatal("InferFunctionEffect returned nil, want effect with OnReturn") + t.Fatal("InferFunctionRefinement returned nil, want effect with OnReturn") } if !eff.OnReturn.HasConstraints() { diff --git a/types/flow/transfer.go b/types/flow/transfer.go index 44f35c40..85b5804c 100644 --- a/types/flow/transfer.go +++ b/types/flow/transfer.go @@ -130,7 +130,7 @@ func (s *Solution) processAssignmentReturnChangedKeys(p cfg.Point) []string { assignedType = s.normalizeNilFieldAssignmentType(p, assign.TargetPath, old) } if !typ.TypeEquals(old, assignedType) { - s.values[targetKeyStr] = assignedType + s.setValue(targetKeyStr, assignedType) changedKeys = append(changedKeys, targetKeyStr) } if len(assign.TargetPath.Segments) > 0 { @@ -274,7 +274,7 @@ func (s *Solution) mirrorAliasedFieldWrite(p cfg.Point, targetPath constraint.Pa if typ.TypeEquals(old, newType) { return nil } - s.values[keyStr] = newType + s.setValue(keyStr, newType) return []string{keyStr} } @@ -771,7 +771,7 @@ func (s *Solution) processIndexerAssignmentReturnKey(p cfg.Point, ia IndexerAssi return "" } - s.values[string(pathKey)] = newType + s.setValue(string(pathKey), newType) return string(pathKey) } @@ -911,7 +911,7 @@ func (s *Solution) processTableMutatorAssignmentReturnKey(p cfg.Point, tm TableM return "" } - s.values[string(pathKey)] = newType + s.setValue(string(pathKey), newType) return string(pathKey) } @@ -971,7 +971,7 @@ func (s *Solution) processContainerMutatorAssignmentReturnKey(p cfg.Point, cm Co return "" } - s.values[string(pathKey)] = newType + s.setValue(string(pathKey), newType) return string(pathKey) } @@ -1094,6 +1094,13 @@ func WidenArrayElementType(arrayType typ.Type, elementType typ.Type, joinFn func } return typ.Visit(arrayType, typ.Visitor[typ.Type]{ + Alias: func(a *typ.Alias) typ.Type { + widened := WidenArrayElementType(a.Target, elementType, joinFn) + if widened == nil || typ.TypeEquals(widened, a.Target) { + return arrayType + } + return typ.NewAlias(a.Name, widened) + }, Array: func(arr *typ.Array) typ.Type { return typ.NewArray(joinFn(arr.Element, elementType)) }, @@ -1147,6 +1154,13 @@ func WidenMapValueArray(mapType typ.Type, keyType, elementType typ.Type) typ.Typ } return typ.Visit(mapType, typ.Visitor[typ.Type]{ + Alias: func(a *typ.Alias) typ.Type { + widened := WidenMapValueArray(a.Target, keyType, elementType) + if widened == nil || typ.TypeEquals(widened, a.Target) { + return mapType + } + return typ.NewAlias(a.Name, widened) + }, Map: func(m *typ.Map) typ.Type { newKey := mergeMapKeyDomain(m.Key, keyType) newVal := WidenArrayElementType(m.Value, elementType, typ.JoinPreferNonSoft) @@ -1258,6 +1272,13 @@ func widenWithIndexer(t typ.Type, keyType, valType typ.Type) typ.Type { } return typ.Visit(t, typ.Visitor[typ.Type]{ + Alias: func(a *typ.Alias) typ.Type { + widened := widenWithIndexer(a.Target, keyType, valType) + if widened == nil || typ.TypeEquals(widened, a.Target) { + return t + } + return typ.NewAlias(a.Name, widened) + }, Tuple: func(tp *typ.Tuple) typ.Type { elemType := valType for _, elem := range tp.Elements { @@ -1365,31 +1386,13 @@ func (s *Solution) processJoinReturnChangedKeys(p cfg.Point) []string { continue } - // Construct path for this phi symbol - path := constraint.Path{ - Root: phi.Target.Root, - Symbol: phi.Target.Symbol, - } - // Collect types from operands, applying edge conditions types := s.scratchTypes[:0] for _, op := range phi.Operands { - opKey := s.pkResolver.KeyAtVersion(op.Version.Symbol, op.Version.ID, nil) - if opKey == "" { - continue - } - opType := s.values[string(opKey)] + opType := s.phiOperandTypeAt(p, op, nil) if opType == nil { continue } - - // Apply edge condition from predecessor to phi point - edgeK := edgeKey{from: op.From, to: p} - if cond, ok := s.edgeConditions[edgeK]; ok && cond.HasConstraints() { - if narrowed := s.applyCondition(op.From, opType, path, cond); narrowed != nil { - opType = narrowed - } - } types = append(types, opType) } @@ -1405,7 +1408,7 @@ func (s *Solution) processJoinReturnChangedKeys(p cfg.Point) []string { } old := s.values[string(targetKey)] if !typ.TypeEquals(old, joined) { - s.values[string(targetKey)] = joined + s.setValue(string(targetKey), joined) changedKeys = append(changedKeys, string(targetKey)) } @@ -1423,32 +1426,11 @@ func (s *Solution) processJoinReturnChangedKeys(p cfg.Point) []string { } types = types[:0] for _, op := range phi.Operands { - opBaseKey := s.pkResolver.KeyAtVersion(op.Version.Symbol, op.Version.ID, nil) - if opBaseKey == "" { - types = append(types, typ.Nil) - continue - } - opKey := string(opBaseKey) + suffix - if opType := s.values[opKey]; opType != nil { - types = append(types, opType) - } else { - // Fall back to deriving the suffix from the operand's base type. - // Structured assignments may update a field path without writing - // every sibling suffix key explicitly on each version. - opBaseType := s.values[string(opBaseKey)] - if opBaseType == nil { - types = append(types, typ.Nil) - continue - } - derived, ok := s.deriveTypeFrom(opBaseType, segments) - if !ok || derived == nil { - // Missing suffix on one phi operand means the merged - // field/index path is nil on that path. - types = append(types, typ.Nil) - continue - } - types = append(types, derived) + opType := s.phiOperandTypeAt(p, op, segments) + if opType == nil { + opType = typ.Nil } + types = append(types, opType) } if len(types) == 0 { continue @@ -1457,7 +1439,7 @@ func (s *Solution) processJoinReturnChangedKeys(p cfg.Point) []string { fullKey := string(targetKey) + suffix old = s.values[fullKey] if !typ.TypeEquals(old, joined) { - s.values[fullKey] = joined + s.setValue(fullKey, joined) changedKeys = append(changedKeys, fullKey) } } @@ -1466,6 +1448,36 @@ func (s *Solution) processJoinReturnChangedKeys(p cfg.Point) []string { return changedKeys } +func (s *Solution) phiOperandTypeAt(joinPoint cfg.Point, op cfg.PhiOperand, segments []constraint.Segment) typ.Type { + if s == nil { + return nil + } + path := constraint.Path{ + Root: op.Version.Root, + Symbol: op.Version.Symbol, + Version: op.Version.ID, + } + if len(segments) > 0 { + path.Segments = append(path.Segments, segments...) + } + + opType := s.NarrowedTypeAt(op.From, path) + if opType == nil { + opType = s.baseTypeAt(op.From, path) + } + if opType == nil { + return nil + } + + edgeK := edgeKey{from: op.From, to: joinPoint} + if cond, ok := s.edgeConditions[edgeK]; ok && cond.HasConstraints() { + if narrowed := s.applyCondition(op.From, opType, path, cond); narrowed != nil { + opType = narrowed + } + } + return opType +} + // collectPhiOperandSuffixes collects field suffixes from phi operand canonical keys. // // When a phi merges versions of a variable, field assignments made on different diff --git a/types/flow/transfer_test.go b/types/flow/transfer_test.go index 5b069d1d..8433c84b 100644 --- a/types/flow/transfer_test.go +++ b/types/flow/transfer_test.go @@ -237,6 +237,48 @@ func TestWidenMapValueArray_PrefersNonSoftElement(t *testing.T) { } } +func TestWidenArrayElementType_PreservesAlias(t *testing.T) { + base := typ.NewAlias("Items", typ.NewArray(typ.String)) + got := WidenArrayElementType(base, typ.Integer, typ.JoinPreferNonSoft) + alias, ok := got.(*typ.Alias) + if !ok { + t.Fatalf("WidenArrayElementType(alias) = %T, want *typ.Alias", got) + } + if alias.Name != "Items" { + t.Fatalf("alias name = %q, want Items", alias.Name) + } + arr, ok := alias.Target.(*typ.Array) + if !ok { + t.Fatalf("alias target = %T, want *typ.Array", alias.Target) + } + if !typ.TypeEquals(arr.Element, typ.NewUnion(typ.String, typ.Integer)) { + t.Fatalf("alias array element = %v, want string|integer", arr.Element) + } +} + +func TestWidenMapValueArray_PreservesAlias(t *testing.T) { + base := typ.NewAlias("Registry", typ.NewMap(typ.String, typ.NewArray(typ.String))) + got := WidenMapValueArray(base, typ.String, typ.Integer) + alias, ok := got.(*typ.Alias) + if !ok { + t.Fatalf("WidenMapValueArray(alias) = %T, want *typ.Alias", got) + } + if alias.Name != "Registry" { + t.Fatalf("alias name = %q, want Registry", alias.Name) + } + mp, ok := alias.Target.(*typ.Map) + if !ok { + t.Fatalf("alias target = %T, want *typ.Map", alias.Target) + } + arr, ok := mp.Value.(*typ.Array) + if !ok { + t.Fatalf("map value = %T, want *typ.Array", mp.Value) + } + if !typ.TypeEquals(arr.Element, typ.NewUnion(typ.String, typ.Integer)) { + t.Fatalf("map alias array element = %v, want string|integer", arr.Element) + } +} + func TestProcessJoinReturnChangedKeys_NoPhi(t *testing.T) { c := cfg.New() g := newMockSSAGraph(c) diff --git a/types/io/manifest.go b/types/io/manifest.go index ec4334fc..b29743ce 100644 --- a/types/io/manifest.go +++ b/types/io/manifest.go @@ -625,7 +625,7 @@ func ApplyFunctionSummary(fn *typ.Function, summary *FunctionSummary) *typ.Funct // Build refinement from summary ensures (for narrowing), fall back to fn's refinement if summary.Ensures.HasConstraints() { - eff := &constraint.FunctionEffect{ + eff := &constraint.FunctionRefinement{ OnReturn: summary.Ensures, } builder.WithRefinement(eff) diff --git a/types/io/manifest_test.go b/types/io/manifest_test.go index eff585d9..4370d4b2 100644 --- a/types/io/manifest_test.go +++ b/types/io/manifest_test.go @@ -412,7 +412,7 @@ func TestManifest_EnrichedExport_SummarySuffixFallback(t *testing.T) { if !ok { t.Fatalf("not_nil field is not function: %T", field.Type) } - refinement, ok := fn.Refinement.(*constraint.FunctionEffect) + refinement, ok := fn.Refinement.(*constraint.FunctionRefinement) if !ok || refinement == nil || !refinement.OnReturn.HasConstraints() { t.Fatalf("expected suffix-matched summary refinement on not_nil, got %#v", fn.Refinement) } diff --git a/types/io/predicates.go b/types/io/predicates.go index a5706b5c..10a9e563 100644 --- a/types/io/predicates.go +++ b/types/io/predicates.go @@ -176,7 +176,7 @@ func (w *typeWriter) writeCondition(cond constraint.Condition) { } } -func (w *typeWriter) writeFunctionEffect(eff *constraint.FunctionEffect) { +func (w *typeWriter) writeFunctionRefinement(eff *constraint.FunctionRefinement) { if eff == nil { w.writeBool(false) return @@ -412,12 +412,12 @@ func (r *typeReader) readCondition() constraint.Condition { return constraint.Condition{Disjuncts: disjuncts} } -func (r *typeReader) readFunctionEffect() *constraint.FunctionEffect { +func (r *typeReader) readFunctionRefinement() *constraint.FunctionRefinement { if !r.readBool() { return nil } - return &constraint.FunctionEffect{ + return &constraint.FunctionRefinement{ OnReturn: r.readCondition(), OnTrue: r.readCondition(), OnFalse: r.readCondition(), diff --git a/types/io/predicates_test.go b/types/io/predicates_test.go index 1ce06241..07a1a05e 100644 --- a/types/io/predicates_test.go +++ b/types/io/predicates_test.go @@ -296,8 +296,8 @@ func TestPredicateCodec_Condition_MultiDisjunct(t *testing.T) { } } -func TestPredicateCodec_FunctionEffect(t *testing.T) { - orig := &constraint.FunctionEffect{ +func TestPredicateCodec_FunctionRefinement(t *testing.T) { + orig := &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.NotNil{Path: path("$0")}), OnTrue: constraint.FromConstraints(constraint.HasType{Path: path("$0"), Type: narrow.BuiltinTypeKey("string")}), OnFalse: constraint.FromConstraints(constraint.NotHasType{Path: path("$0"), Type: narrow.BuiltinTypeKey("string")}), @@ -305,15 +305,15 @@ func TestPredicateCodec_FunctionEffect(t *testing.T) { var buf bytes.Buffer w := &typeWriter{w: &buf} - w.writeFunctionEffect(orig) + w.writeFunctionRefinement(orig) r := &typeReader{r: bytes.NewReader(buf.Bytes())} - got := r.readFunctionEffect() + got := r.readFunctionRefinement() if r.err != nil { - t.Fatalf("readFunctionEffect: %v", r.err) + t.Fatalf("readFunctionRefinement: %v", r.err) } if got == nil { - t.Fatal("expected non-nil FunctionEffect") + t.Fatal("expected non-nil FunctionRefinement") } if len(got.OnReturn.MustConstraints()) != 1 { t.Errorf("OnReturn: got %d, want 1", len(got.OnReturn.MustConstraints())) @@ -326,15 +326,15 @@ func TestPredicateCodec_FunctionEffect(t *testing.T) { } } -func TestPredicateCodec_FunctionEffect_Nil(t *testing.T) { +func TestPredicateCodec_FunctionRefinement_Nil(t *testing.T) { var buf bytes.Buffer w := &typeWriter{w: &buf} - w.writeFunctionEffect(nil) + w.writeFunctionRefinement(nil) r := &typeReader{r: bytes.NewReader(buf.Bytes())} - got := r.readFunctionEffect() + got := r.readFunctionRefinement() if r.err != nil { - t.Fatalf("readFunctionEffect: %v", r.err) + t.Fatalf("readFunctionRefinement: %v", r.err) } if got != nil { t.Errorf("expected nil, got %+v", got) @@ -507,29 +507,29 @@ func TestPredicateCodec_KeyOf(t *testing.T) { } } -func TestPredicateCodec_FunctionEffect_WithKeyOf(t *testing.T) { +func TestPredicateCodec_FunctionRefinement_WithKeyOf(t *testing.T) { keyOf := constraint.KeyOf{ Table: constraint.ParamPath(0), Key: constraint.RetPath(0), } - orig := &constraint.FunctionEffect{ + orig := &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(keyOf), } var buf bytes.Buffer w := &typeWriter{w: &buf} - w.writeFunctionEffect(orig) + w.writeFunctionRefinement(orig) if w.err != nil { - t.Fatalf("writeFunctionEffect: %v", w.err) + t.Fatalf("writeFunctionRefinement: %v", w.err) } r := &typeReader{r: bytes.NewReader(buf.Bytes())} - got := r.readFunctionEffect() + got := r.readFunctionRefinement() if r.err != nil { - t.Fatalf("readFunctionEffect: %v", r.err) + t.Fatalf("readFunctionRefinement: %v", r.err) } if got == nil { - t.Fatal("expected non-nil FunctionEffect") + t.Fatal("expected non-nil FunctionRefinement") } if !got.OnReturn.HasConstraints() { t.Fatal("OnReturn should have constraints") diff --git a/types/io/reader.go b/types/io/reader.go index 33a40759..9ce5b13e 100644 --- a/types/io/reader.go +++ b/types/io/reader.go @@ -254,7 +254,7 @@ func (r *typeReader) readType() typ.Type { } } - if refinement := r.readFunctionEffect(); refinement != nil { + if refinement := r.readFunctionRefinement(); refinement != nil { fb.WithRefinement(refinement) } diff --git a/types/io/serialize_test.go b/types/io/serialize_test.go index fbba12a2..5a125314 100644 --- a/types/io/serialize_test.go +++ b/types/io/serialize_test.go @@ -311,7 +311,7 @@ func TestEncodeDecode_Function(t *testing.T) { }) t.Run("with-refinement", func(t *testing.T) { - refinement := &constraint.FunctionEffect{ + refinement := &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.NotNil{Path: constraint.Path{Root: "$0"}}), OnTrue: constraint.FromConstraints(constraint.HasType{Path: constraint.Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}), OnFalse: constraint.FromConstraints(constraint.NotHasType{Path: constraint.Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}), @@ -337,7 +337,7 @@ func TestEncodeDecode_Function(t *testing.T) { t.Fatalf("expected Function, got %T", decoded) } - decodedRefinement, ok := fn.Refinement.(*constraint.FunctionEffect) + decodedRefinement, ok := fn.Refinement.(*constraint.FunctionRefinement) if !ok || decodedRefinement == nil { t.Fatal("expected refinement") } @@ -1103,9 +1103,9 @@ func TestDNFPreservation_SpecEnsures(t *testing.T) { } } -// TestDNFPreservation_FunctionEffect verifies multi-disjunct OnTrue/OnFalse -// conditions in FunctionEffect are preserved through encode/decode. -func TestDNFPreservation_FunctionEffect(t *testing.T) { +// TestDNFPreservation_FunctionRefinement verifies multi-disjunct OnTrue/OnFalse +// conditions in FunctionRefinement are preserved through encode/decode. +func TestDNFPreservation_FunctionRefinement(t *testing.T) { // OnTrue: (string) OR (number) onTrueDisjuncts := [][]constraint.Constraint{ {constraint.HasType{Path: constraint.Path{Root: "$0"}, Type: narrow.BuiltinTypeKey("string")}}, @@ -1120,7 +1120,7 @@ func TestDNFPreservation_FunctionEffect(t *testing.T) { }, } - refinement := &constraint.FunctionEffect{ + refinement := &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.NotNil{Path: constraint.Path{Root: "$0"}}), OnTrue: constraint.Condition{Disjuncts: onTrueDisjuncts}, OnFalse: constraint.Condition{Disjuncts: onFalseDisjuncts}, @@ -1143,7 +1143,7 @@ func TestDNFPreservation_FunctionEffect(t *testing.T) { } fn := decoded.(*typ.Function) - decodedRef := fn.Refinement.(*constraint.FunctionEffect) + decodedRef := fn.Refinement.(*constraint.FunctionRefinement) // Verify OnTrue DNF if decodedRef.OnTrue.NumDisjuncts() != 2 { diff --git a/types/io/writer.go b/types/io/writer.go index cea927b0..e7ad9a50 100644 --- a/types/io/writer.go +++ b/types/io/writer.go @@ -172,10 +172,10 @@ func (w *typeWriter) writeTypeData(t typ.Type) { w.writeBool(false) } - if eff, ok := v.Refinement.(*constraint.FunctionEffect); ok { - w.writeFunctionEffect(eff) + if eff, ok := v.Refinement.(*constraint.FunctionRefinement); ok { + w.writeFunctionRefinement(eff) } else { - w.writeFunctionEffect(nil) + w.writeFunctionRefinement(nil) } return struct{}{} }, diff --git a/types/narrow/narrow.go b/types/narrow/narrow.go index de421b6a..9d4d2dfb 100644 --- a/types/narrow/narrow.go +++ b/types/narrow/narrow.go @@ -650,8 +650,9 @@ func FilterByKind(t typ.Type, target kind.Kind) typ.Type { handleUnion: func(u *typ.Union, _ func(typ.Type) typ.Type) typ.Type { var kept []typ.Type for _, m := range u.Members { - if KindMatches(m, target) { - kept = append(kept, m) + narrowed := FilterByKind(m, target) + if narrowed != nil && !narrowed.Kind().IsNever() { + kept = append(kept, narrowed) } } if len(kept) == 0 { @@ -663,6 +664,9 @@ func FilterByKind(t typ.Type, target kind.Kind) typ.Type { return typ.NewUnion(kept...) }, handleLeaf: func(t typ.Type) typ.Type { + if t.Kind().IsPlaceholder() { + return TypeForKind(target) + } if KindMatches(t, target) { return t } diff --git a/types/narrow/narrow_test.go b/types/narrow/narrow_test.go index 5a00ad25..6c3697ec 100644 --- a/types/narrow/narrow_test.go +++ b/types/narrow/narrow_test.go @@ -269,6 +269,31 @@ func TestExcludeKind_Unknown_PreservesUnknown(t *testing.T) { } } +func TestFilterByKind_UnionPlaceholderAndNil_PreservesRuntimePossibility(t *testing.T) { + got := narrow.FilterByKind(typ.NewUnion(typ.Unknown, typ.Nil), kind.Record) + if got == nil || got.Kind().IsNever() { + t.Fatalf("FilterByKind(unknown|nil, table) = %v, want table-like type", got) + } + if !narrow.KindMatches(got, kind.Record) { + t.Fatalf("FilterByKind(unknown|nil, table) = %v, want table-like type", got) + } +} + +func TestFilterByKind_UnionAnyAndNil_Number(t *testing.T) { + got := narrow.FilterByKind(typ.NewUnion(typ.Any, typ.Nil), kind.Number) + if !typ.TypeEquals(got, typ.Number) { + t.Fatalf("FilterByKind(any|nil, number) = %v, want number", got) + } +} + +func TestExcludeKind_OptionalUnknown_PreservesUnknownOptional(t *testing.T) { + got := narrow.ExcludeKind(typ.NewOptional(typ.Unknown), kind.String) + want := typ.NewOptional(typ.Unknown) + if !typ.TypeEquals(got, want) { + t.Fatalf("ExcludeKind(unknown?, string) = %v, want %v", got, want) + } +} + func TestExcludeType_Optional(t *testing.T) { opt := typ.NewOptional(typ.NewUnion(typ.String, typ.Number)) got := narrow.ExcludeType(opt, typ.Number) diff --git a/types/query/core/engine_test.go b/types/query/core/engine_test.go index 2ba88e34..25dbd8c7 100644 --- a/types/query/core/engine_test.go +++ b/types/query/core/engine_test.go @@ -3,6 +3,7 @@ package core import ( "testing" + "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/typ" ) @@ -127,6 +128,48 @@ func TestEngine_Callable_Function(t *testing.T) { } } +func TestEngine_IsSubtype_RecursiveAliasRecordUnion(t *testing.T) { + e := NewEngine() + ctx := db.NewQueryContext(db.New()) + + rec := typ.NewRecursivePlaceholder("Message") + msgAlias := typ.NewAlias("Message", rec) + rec.SetBody(typ.NewRecord(). + Field("_topic", typ.String). + Field("topic", typ.Func().Param("self", rec).Returns(typ.String).Build()). + Build()) + + msgCh := typ.NewAlias("MsgCh", typ.NewRecord().Field("__tag", typ.LiteralString("msg")).Build()) + timerCh := typ.NewAlias("TimerCh", typ.NewRecord().Field("__tag", typ.LiteralString("timer")).Build()) + timer := typ.NewRecord().Field("elapsed", typ.Number).Build() + + result := typ.NewUnion( + typ.NewRecord(). + Field("channel", msgCh). + Field("value", msgAlias). + Field("ok", typ.Boolean). + Build(), + typ.NewRecord(). + Field("channel", timerCh). + Field("value", timer). + Field("ok", typ.Boolean). + Build(), + ) + + synthesized := typ.NewRecord(). + Field("channel", msgCh). + Field("value", typ.NewRecord(). + Field("_topic", typ.String). + Field("topic", typ.Func().Param("s", msgAlias).Returns(typ.String).Build()). + Build()). + Field("ok", typ.True). + Build() + + if !e.IsSubtype(ctx, synthesized, result) { + t.Fatal("engine subtype should accept recursive alias record inside union member") + } +} + func TestEngine_Callable_NonFunction(t *testing.T) { e := NewEngine() _, ok := e.Callable(nil, typ.String) diff --git a/types/query/core/field.go b/types/query/core/field.go index d12c5f43..2e114d09 100644 --- a/types/query/core/field.go +++ b/types/query/core/field.go @@ -79,6 +79,13 @@ func fieldDepth(t typ.Type, name string, depth int) (typ.Type, bool) { ft, ok := fieldInOptional(o, name, depth) return fieldResult{t: ft, ok: ok} }, + Recursive: func(rec *typ.Recursive) fieldResult { + if rec.Body == nil || rec.Body == rec { + return fieldResult{} + } + ft, ok := fieldDepth(rec.Body, name, depth+1) + return fieldResult{t: ft, ok: ok} + }, Alias: func(a *typ.Alias) fieldResult { ft, ok := fieldDepth(a.Target, name, depth+1) return fieldResult{t: ft, ok: ok} @@ -128,6 +135,9 @@ func fieldInRecordDepth(r *typ.Record, name string, depth int) (typ.Type, bool) // Direct field lookup if f := r.GetField(name); f != nil { + if f.Optional { + return typ.NewOptional(f.Type), true + } return f.Type, true } diff --git a/types/query/core/field_test.go b/types/query/core/field_test.go index 55e15b7c..60545e2c 100644 --- a/types/query/core/field_test.go +++ b/types/query/core/field_test.go @@ -11,6 +11,9 @@ func TestField(t *testing.T) { Field("name", typ.String). Field("age", typ.Integer). Build() + recWithOpt := typ.NewRecord(). + OptField("name", typ.String). + Build() iface := typ.NewInterface("Reader", []typ.Method{ {Name: "read", Type: typ.Func().Param("n", typ.Integer).Returns(typ.String).Build()}, @@ -26,6 +29,9 @@ func TestField(t *testing.T) { {"nil type", nil, "x", false, nil}, {"record existing field", rec, "name", true, func(t typ.Type) bool { return t == typ.String }}, {"record another field", rec, "age", true, func(t typ.Type) bool { return t == typ.Integer }}, + {"record optional field", recWithOpt, "name", true, func(t typ.Type) bool { + return typ.TypeEquals(t, typ.NewOptional(typ.String)) + }}, {"record missing field", rec, "missing", false, nil}, {"interface method", iface, "read", true, func(t typ.Type) bool { return t.Kind() == typ.String.Kind() || true }}, {"interface missing", iface, "write", false, nil}, @@ -125,6 +131,8 @@ func TestFieldIntersection(t *testing.T) { func TestFieldOptional(t *testing.T) { rec := typ.NewRecord().Field("x", typ.Number).Build() opt := typ.NewOptional(rec) + optFieldRec := typ.NewRecord().OptField("x", typ.Number).Build() + optFieldOptRec := typ.NewOptional(optFieldRec) t.Run("field on optional record", func(t *testing.T) { result, ok := Field(opt, "x") @@ -141,6 +149,26 @@ func TestFieldOptional(t *testing.T) { t.Error("expected optional wrapper on field type") } }) + + t.Run("optional field on record", func(t *testing.T) { + result, ok := Field(optFieldRec, "x") + if !ok { + t.Error("expected to find optional field") + } + if !typ.TypeEquals(result, typ.NewOptional(typ.Number)) { + t.Errorf("expected number?, got %v", result) + } + }) + + t.Run("optional field on optional record stays optional", func(t *testing.T) { + result, ok := Field(optFieldOptRec, "x") + if !ok { + t.Error("expected to find optional field on optional record") + } + if !typ.TypeEquals(result, typ.NewOptional(typ.Number)) { + t.Errorf("expected number?, got %v", result) + } + }) } func TestFieldAlias(t *testing.T) { diff --git a/types/query/core/index.go b/types/query/core/index.go index 81012020..0829f87e 100644 --- a/types/query/core/index.go +++ b/types/query/core/index.go @@ -1,6 +1,8 @@ package core import ( + "sort" + "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -90,29 +92,8 @@ func indexDepth(t, keyType typ.Type, depth int) (typ.Type, bool) { if len(r.Fields) == 0 && !r.HasMapComponent() { return indexResult{t: typ.Nil, ok: true} } - // String key for record field access - if lit, ok := keyType.(*typ.Literal); ok && lit.Base == kind.String { - fieldType, found := fieldInRecord(r, lit.Value.(string)) - return indexResult{t: fieldType, ok: found} - } - // Union of string literals: look up each field and union the results - if union, ok := keyType.(*typ.Union); ok { - var resultTypes []typ.Type - allLiterals := true - for _, m := range union.Members { - lit, isLit := m.(*typ.Literal) - if !isLit || lit.Base != kind.String { - allLiterals = false - break - } - fieldType, found := fieldInRecord(r, lit.Value.(string)) - if found && fieldType != nil { - resultTypes = append(resultTypes, fieldType) - } - } - if allLiterals && len(resultTypes) > 0 { - return indexResult{t: typ.NewUnion(resultTypes...), ok: true} - } + if keySet, ok := exactStringKeyDomain(keyType, depth+1); ok { + return indexRecordByExactStringKeyDomain(r, keySet, depth+1) } // Unknown string returns optional union of all field types if keyType.Kind() == kind.String { @@ -204,6 +185,13 @@ func indexDepth(t, keyType typ.Type, depth int) (typ.Type, bool) { return indexResult{t: typ.NewOptional(et), ok: true} }, + Recursive: func(rec *typ.Recursive) indexResult { + if rec.Body == nil || rec.Body == rec { + return indexResult{} + } + et, ok := indexDepth(rec.Body, keyType, depth+1) + return indexResult{t: et, ok: ok} + }, Alias: func(a *typ.Alias) indexResult { et, ok := indexDepth(a.Target, keyType, depth+1) return indexResult{t: et, ok: ok} @@ -258,3 +246,133 @@ func containsNilOrOptional(t typ.Type) bool { }, }) } + +// exactStringKeyDomain returns the finite set of string keys represented by t. +// +// It only succeeds when the key type is exactly a finite union of string literals +// after transparently traversing wrappers such as aliases and instantiated types. +// This lets record projection reason about the full key domain instead of relying +// on raw AST/type shape. +func exactStringKeyDomain(t typ.Type, depth int) ([]string, bool) { + if stopDepth(t, depth) { + return nil, false + } + + keys := typ.Visit(t, typ.Visitor[[]string]{ + Literal: func(lit *typ.Literal) []string { + if lit.Base != kind.String { + return nil + } + str, ok := lit.Value.(string) + if !ok { + return nil + } + return []string{str} + }, + Union: func(u *typ.Union) []string { + if len(u.Members) == 0 { + return nil + } + seen := make(map[string]struct{}, len(u.Members)) + var keys []string + for _, member := range u.Members { + memberKeys, ok := exactStringKeyDomain(member, depth+1) + if !ok { + return nil + } + for _, key := range memberKeys { + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + keys = append(keys, key) + } + } + sort.Strings(keys) + return keys + }, + Alias: func(a *typ.Alias) []string { + keys, ok := exactStringKeyDomain(a.Target, depth+1) + if !ok { + return nil + } + return keys + }, + Optional: func(o *typ.Optional) []string { + keys, ok := exactStringKeyDomain(o.Inner, depth+1) + if !ok { + return nil + } + return keys + }, + Recursive: func(rec *typ.Recursive) []string { + if rec.Body == nil || rec.Body == rec { + return nil + } + keys, ok := exactStringKeyDomain(rec.Body, depth+1) + if !ok { + return nil + } + return keys + }, + Instantiated: func(inst *typ.Instantiated) []string { + resolved, err := ResolveInstantiated(inst) + if err != nil { + return nil + } + keys, ok := exactStringKeyDomain(resolved, depth+1) + if !ok { + return nil + } + return keys + }, + TypeParam: func(tp *typ.TypeParam) []string { + if tp.Constraint == nil { + return nil + } + keys, ok := exactStringKeyDomain(tp.Constraint, depth+1) + if !ok { + return nil + } + return keys + }, + Default: func(t typ.Type) []string { + return nil + }, + }) + if keys == nil { + return nil, false + } + return keys, true +} + +func indexRecordByExactStringKeyDomain(r *typ.Record, keys []string, depth int) indexResult { + if len(keys) == 0 { + return indexResult{} + } + + var matched []typ.Type + missing := false + + for _, key := range keys { + fieldType, found := fieldInRecordDepth(r, key, depth+1) + if !found { + missing = true + continue + } + if fieldType != nil { + matched = append(matched, fieldType) + } + } + + if len(matched) == 0 { + return indexResult{} + } + + out := typ.NewUnion(matched...) + if missing && !containsNilOrOptional(out) { + out = typ.NewOptional(out) + } + + return indexResult{t: out, ok: true} +} diff --git a/types/query/core/index_test.go b/types/query/core/index_test.go index b99f673f..bbe95d91 100644 --- a/types/query/core/index_test.go +++ b/types/query/core/index_test.go @@ -130,6 +130,77 @@ func TestIndex_RecordWithMapComponent_GenericStringIncludesFieldAndMap(t *testin } } +func TestIndex_RecordWithAliasLiteralKeyUnion_KnownKeysStayDefinite(t *testing.T) { + rec := typ.NewRecord(). + Field("postgres", typ.String). + Field("sqlite", typ.Integer). + Field("mysql", typ.Boolean). + Build() + keyType := typ.NewAlias("DbType", typ.NewUnion( + typ.LiteralString("postgres"), + typ.LiteralString("sqlite"), + typ.LiteralString("mysql"), + )) + + got, ok := Index(rec, keyType) + if !ok { + t.Fatal("expected alias-wrapped literal key domain to resolve") + } + if ContainsNil(got) { + t.Fatalf("expected definite result, got %v", got) + } + if !subtype.IsSubtype(typ.String, got) || !subtype.IsSubtype(typ.Integer, got) || !subtype.IsSubtype(typ.Boolean, got) { + t.Fatalf("expected union of matching field types, got %v", got) + } +} + +func TestIndex_RecordWithAliasLiteralKey_PreservesSingleFieldPrecision(t *testing.T) { + rec := typ.NewRecord(). + Field("name", typ.String). + Field("count", typ.Integer). + Build() + keyType := typ.NewAlias("NameKey", typ.LiteralString("name")) + + got, ok := Index(rec, keyType) + if !ok { + t.Fatal("expected alias-wrapped literal key to resolve") + } + if !typ.TypeEquals(got, typ.String) { + t.Fatalf("expected string, got %v", got) + } +} + +func TestIndex_RecordWithLiteralKeyUnion_PartialMissBecomesOptional(t *testing.T) { + rec := typ.NewRecord(). + Field("present", typ.String). + Field("count", typ.Integer). + Build() + keyType := typ.NewUnion(typ.LiteralString("present"), typ.LiteralString("missing")) + + got, ok := Index(rec, keyType) + if !ok { + t.Fatal("expected partial literal key domain to resolve") + } + if !ContainsNil(got) { + t.Fatalf("expected optional result for partial miss, got %v", got) + } + if !subtype.IsSubtype(typ.String, got) { + t.Fatalf("expected present field type to survive, got %v", got) + } +} + +func TestIndex_RecordWithLiteralKeyUnion_AllMissingStillFails(t *testing.T) { + rec := typ.NewRecord(). + Field("present", typ.String). + Field("count", typ.Integer). + Build() + keyType := typ.NewUnion(typ.LiteralString("missing"), typ.LiteralString("also_missing")) + + if _, ok := Index(rec, keyType); ok { + t.Fatal("expected all-missing literal key domain to fail on closed record") + } +} + func TestIndexUnion(t *testing.T) { arr1 := typ.NewArray(typ.String) arr2 := typ.NewArray(typ.Integer) diff --git a/types/query/core/method.go b/types/query/core/method.go index d6dddee4..8a5ca432 100644 --- a/types/query/core/method.go +++ b/types/query/core/method.go @@ -187,6 +187,13 @@ func methodDepth(t typ.Type, name string, depth int) (typ.Type, bool) { mt, ok := methodDepth(o.Inner, name, depth+1) return fieldResult{t: mt, ok: ok} }, + Recursive: func(rec *typ.Recursive) fieldResult { + if rec.Body == nil || rec.Body == rec { + return fieldResult{} + } + mt, ok := methodDepth(rec.Body, name, depth+1) + return fieldResult{t: mt, ok: ok} + }, Alias: func(a *typ.Alias) fieldResult { mt, ok := methodDepth(a.Target, name, depth+1) return fieldResult{t: mt, ok: ok} @@ -261,6 +268,13 @@ func callableDepth(t typ.Type, depth int) (*typ.Function, bool) { fn, ok := callableDepth(o.Inner, depth+1) return callableResult{fn: fn, ok: ok} }, + Recursive: func(rec *typ.Recursive) callableResult { + if rec.Body == nil || rec.Body == rec { + return callableResult{} + } + fn, ok := callableDepth(rec.Body, depth+1) + return callableResult{fn: fn, ok: ok} + }, Alias: func(a *typ.Alias) callableResult { fn, ok := callableDepth(a.Target, depth+1) return callableResult{fn: fn, ok: ok} @@ -362,6 +376,13 @@ func getMetamethodDepth(t typ.Type, name string, depth int) (typ.Type, bool) { mt, ok := getMetamethodDepth(o.Inner, name, depth+1) return fieldResult{t: mt, ok: ok} }, + Recursive: func(rec *typ.Recursive) fieldResult { + if rec.Body == nil || rec.Body == rec { + return fieldResult{} + } + mt, ok := getMetamethodDepth(rec.Body, name, depth+1) + return fieldResult{t: mt, ok: ok} + }, Alias: func(a *typ.Alias) fieldResult { mt, ok := getMetamethodDepth(a.Target, name, depth+1) return fieldResult{t: mt, ok: ok} diff --git a/types/query/core/util.go b/types/query/core/util.go index aee1aa7e..85d250cf 100644 --- a/types/query/core/util.go +++ b/types/query/core/util.go @@ -79,6 +79,12 @@ func allFieldsDepth(t typ.Type, depth int) []string { return names }, + Recursive: func(rec *typ.Recursive) []string { + if rec.Body == nil || rec.Body == rec { + return nil + } + return allFieldsDepth(rec.Body, depth+1) + }, Alias: func(a *typ.Alias) []string { return allFieldsDepth(a.Target, depth+1) }, @@ -210,6 +216,12 @@ func allFieldTypesDepth(t typ.Type, depth int) map[string]typ.Type { } return fields }, + Recursive: func(rec *typ.Recursive) map[string]typ.Type { + if rec.Body == nil || rec.Body == rec { + return nil + } + return allFieldTypesDepth(rec.Body, depth+1) + }, Alias: func(a *typ.Alias) map[string]typ.Type { return allFieldTypesDepth(a.Target, depth+1) }, @@ -342,6 +354,12 @@ func allMethodsDepth(t typ.Type, depth int) []string { return allMethodsDepth(r.Metatable, depth+1) }, + Recursive: func(rec *typ.Recursive) []string { + if rec.Body == nil || rec.Body == rec { + return nil + } + return allMethodsDepth(rec.Body, depth+1) + }, Alias: func(a *typ.Alias) []string { return allMethodsDepth(a.Target, depth+1) }, @@ -382,6 +400,12 @@ func lengthDepth(t typ.Type, depth int) int { return -1 }, + Recursive: func(rec *typ.Recursive) int { + if rec.Body == nil || rec.Body == rec { + return -1 + } + return lengthDepth(rec.Body, depth+1) + }, Alias: func(a *typ.Alias) int { return lengthDepth(a.Target, depth+1) }, @@ -424,6 +448,12 @@ func iterableDepth(t typ.Type, depth int) bool { Record: func(r *typ.Record) bool { return true }, + Recursive: func(rec *typ.Recursive) bool { + if rec.Body == nil || rec.Body == rec { + return false + } + return iterableDepth(rec.Body, depth+1) + }, Union: func(u *typ.Union) bool { for _, member := range u.Members { if !iterableDepth(member, depth+1) { diff --git a/types/subtype/subtype.go b/types/subtype/subtype.go index 32e36745..993e878e 100644 --- a/types/subtype/subtype.go +++ b/types/subtype/subtype.go @@ -141,11 +141,19 @@ func (c *checker) check(sub, super typ.Type, depth int) bool { // Unwrap aliases if aa, ok := sub.(*typ.Alias); ok { - return c.check(aa.Target, super, depth+1) + return c.check(aa.UnaliasedTarget(), super, depth+1) } if aa, ok := super.(*typ.Alias); ok { - return c.check(sub, aa.Target, depth+1) + return c.check(sub, aa.UnaliasedTarget(), depth+1) + } + + if rr, ok := sub.(*typ.Recursive); ok && super.Kind() != kind.Recursive && rr.Body != nil && rr.Body != rr { + return c.check(rr.Body, super, depth+1) + } + + if rr, ok := super.(*typ.Recursive); ok && sub.Kind() != kind.Recursive && rr.Body != nil && rr.Body != rr { + return c.check(sub, rr.Body, depth+1) } // Handle instantiated generics @@ -992,7 +1000,9 @@ func (c *checker) checkInstantiated(sub, super *typ.Instantiated, depth int) boo func isTableLikeType(t typ.Type) bool { switch v := t.(type) { case *typ.Alias: - return isTableLikeType(v.Target) + return isTableLikeType(v.UnaliasedTarget()) + case *typ.Recursive: + return v.Body != nil && v.Body != v && isTableLikeType(v.Body) case *typ.Record, *typ.Map, *typ.Array, *typ.Tuple, *typ.Interface, *typ.Intersection: return true default: diff --git a/types/subtype/subtype_test.go b/types/subtype/subtype_test.go index 524d56e4..e3541a88 100644 --- a/types/subtype/subtype_test.go +++ b/types/subtype/subtype_test.go @@ -1057,6 +1057,63 @@ func TestTripleMutualRecursive(t *testing.T) { } } +func TestRecursiveAliasRecordSubtype_WithSelfMethodField(t *testing.T) { + rec := typ.NewRecursivePlaceholder("Message") + msgAlias := typ.NewAlias("Message", rec) + rec.SetBody(typ.NewRecord(). + Field("_topic", typ.String). + Field("topic", typ.Func().Param("self", rec).Returns(typ.String).Build()). + Build()) + + synthesized := typ.NewRecord(). + Field("_topic", typ.String). + Field("topic", typ.Func().Param("s", msgAlias).Returns(typ.String).Build()). + Build() + + if !IsSubtype(synthesized, msgAlias) { + t.Fatal("record literal with Message-annotated self method should subtype recursive Message alias") + } +} + +func TestRecursiveAliasRecordSubtype_InsideUnionMember(t *testing.T) { + rec := typ.NewRecursivePlaceholder("Message") + msgAlias := typ.NewAlias("Message", rec) + rec.SetBody(typ.NewRecord(). + Field("_topic", typ.String). + Field("topic", typ.Func().Param("self", rec).Returns(typ.String).Build()). + Build()) + + msgCh := typ.NewAlias("MsgCh", typ.NewRecord().Field("__tag", typ.LiteralString("msg")).Build()) + timerCh := typ.NewAlias("TimerCh", typ.NewRecord().Field("__tag", typ.LiteralString("timer")).Build()) + timer := typ.NewRecord().Field("elapsed", typ.Number).Build() + + result := typ.NewUnion( + typ.NewRecord(). + Field("channel", msgCh). + Field("value", msgAlias). + Field("ok", typ.Boolean). + Build(), + typ.NewRecord(). + Field("channel", timerCh). + Field("value", timer). + Field("ok", typ.Boolean). + Build(), + ) + + synthesized := typ.NewRecord(). + Field("channel", msgCh). + Field("value", typ.NewRecord(). + Field("_topic", typ.String). + Field("topic", typ.Func().Param("s", msgAlias).Returns(typ.String).Build()). + Build()). + Field("ok", typ.True). + Build() + + if !IsSubtype(synthesized, result) { + t.Fatal("record literal should subtype union member carrying recursive Message alias") + } +} + // Edge cases for empty unions and intersections func TestEmptyUnionIsNever(t *testing.T) { @@ -1338,6 +1395,17 @@ func TestRecordToMap(t *testing.T) { } } +func TestRecursiveRecordToMap(t *testing.T) { + rec := typ.NewRecursive("Node", func(self typ.Type) typ.Type { + return typ.NewRecord().Field("child", self).Build() + }) + mapType := typ.NewMap(typ.String, rec) + + if !IsSubtype(rec, mapType) { + t.Error("recursive record should be subtype of compatible recursive map") + } +} + func TestRecordToMapIncompatibleKey(t *testing.T) { rec := typ.NewRecord().Field("name", typ.String).Build() mapType := typ.NewMap(typ.Number, typ.String) diff --git a/types/typ/annotated.go b/types/typ/annotated.go index 5320d428..f20eeb28 100644 --- a/types/typ/annotated.go +++ b/types/typ/annotated.go @@ -21,6 +21,7 @@ type Annotated struct { Inner Type Annotations []Annotation hash uint64 + strCache stringCache } // NewAnnotated creates an annotated type wrapper. @@ -51,35 +52,37 @@ func (a *Annotated) Kind() kind.Kind { } func (a *Annotated) String() string { - var sb strings.Builder - if a.Inner != nil { - sb.WriteString(a.Inner.String()) - } else { - sb.WriteString("unknown") - } - for _, ann := range a.Annotations { - sb.WriteString(" @") - sb.WriteString(ann.Name) - if ann.Arg != nil { - sb.WriteString("(") - switch v := ann.Arg.(type) { - case string: - sb.WriteString("\"") - sb.WriteString(v) - sb.WriteString("\"") - case float64: - sb.WriteString(formatFloat(v)) - case int64: - sb.WriteString(formatInt(v)) - case int: - sb.WriteString(formatInt(int64(v))) - default: - sb.WriteString("...") + return a.strCache.get(func() string { + var sb strings.Builder + if a.Inner != nil { + sb.WriteString(a.Inner.String()) + } else { + sb.WriteString("unknown") + } + for _, ann := range a.Annotations { + sb.WriteString(" @") + sb.WriteString(ann.Name) + if ann.Arg != nil { + sb.WriteString("(") + switch v := ann.Arg.(type) { + case string: + sb.WriteString("\"") + sb.WriteString(v) + sb.WriteString("\"") + case float64: + sb.WriteString(formatFloat(v)) + case int64: + sb.WriteString(formatInt(v)) + case int: + sb.WriteString(formatInt(int64(v))) + default: + sb.WriteString("...") + } + sb.WriteString(")") } - sb.WriteString(")") } - } - return sb.String() + return sb.String() + }) } func (a *Annotated) Hash() uint64 { diff --git a/types/typ/container.go b/types/typ/container.go index e1354e35..2cf09df2 100644 --- a/types/typ/container.go +++ b/types/typ/container.go @@ -13,8 +13,10 @@ import ( // describes what each element contains. Arrays support ipairs iteration // and length operator (#). type Array struct { - Element Type - hash uint64 + Element Type + hash uint64 + softPrunable bool + strCache stringCache } // NewArray creates an array type. @@ -23,17 +25,19 @@ func NewArray(elem Type) *Array { elem = Unknown } h := internal.HashCombine(uint64(kind.Array), elem.Hash()) - return &Array{Element: elem, hash: h} + return &Array{Element: elem, hash: h, softPrunable: softPruneMayRewrite(elem)} } func (a *Array) Kind() kind.Kind { return kind.Array } func (a *Array) String() string { - if a.Element == nil { - return "unknown[]" - } - return a.Element.String() + "[]" + return a.strCache.get(func() string { + if a.Element == nil { + return "unknown[]" + } + return a.Element.String() + "[]" + }) } -func (a *Array) Hash() uint64 { return a.hash } +func (a *Array) Hash() uint64 { return a.hash } func (a *Array) Equals(o Type) bool { return TypeEquals(a, o) } @@ -44,9 +48,11 @@ func (a *Array) Equals(o Type) bool { // Unlike Records, Maps have uniform types for all entries rather than // named fields with potentially different types. type Map struct { - Key Type - Value Type - hash uint64 + Key Type + Value Type + hash uint64 + softPrunable bool + strCache stringCache } // NewMap creates a map type. @@ -60,21 +66,23 @@ func NewMap(key, value Type) *Map { h := internal.HashCombine(uint64(kind.Map), key.Hash()) h = internal.HashCombine(h, value.Hash()) - return &Map{Key: key, Value: value, hash: h} + return &Map{Key: key, Value: value, hash: h, softPrunable: softPruneAny(key, value)} } func (m *Map) Kind() kind.Kind { return kind.Map } func (m *Map) String() string { - ks, vs := "unknown", "unknown" - if m.Key != nil { - ks = m.Key.String() - } - if m.Value != nil { - vs = m.Value.String() - } - return "{[" + ks + "]: " + vs + "}" + return m.strCache.get(func() string { + ks, vs := "unknown", "unknown" + if m.Key != nil { + ks = m.Key.String() + } + if m.Value != nil { + vs = m.Value.String() + } + return "{[" + ks + "]: " + vs + "}" + }) } -func (m *Map) Hash() uint64 { return m.hash } +func (m *Map) Hash() uint64 { return m.hash } func (m *Map) Equals(o Type) bool { return TypeEquals(m, o) } @@ -85,38 +93,45 @@ func (m *Map) Equals(o Type) bool { // Unlike Arrays, each position can have a different type and the length // is fixed at compile time. type Tuple struct { - Elements []Type - hash uint64 + Elements []Type + hash uint64 + softPrunable bool + strCache stringCache } // NewTuple creates a tuple type. func NewTuple(elems ...Type) *Tuple { h := uint64(kind.Tuple) cleaned := make([]Type, len(elems)) + softPrunable := false for i, e := range elems { if e == nil { e = Unknown } cleaned[i] = e h = internal.HashCombine(h, e.Hash()) + if !softPrunable && softPruneMayRewrite(e) { + softPrunable = true + } } - return &Tuple{Elements: cleaned, hash: h} + return &Tuple{Elements: cleaned, hash: h, softPrunable: softPrunable} } func (t *Tuple) Kind() kind.Kind { return kind.Tuple } func (t *Tuple) String() string { - parts := make([]string, len(t.Elements)) - for i, e := range t.Elements { - if e == nil { - parts[i] = "unknown" - } else { - parts[i] = e.String() + return t.strCache.get(func() string { + parts := make([]string, len(t.Elements)) + for i, e := range t.Elements { + if e == nil { + parts[i] = "unknown" + } else { + parts[i] = e.String() + } } - } - - return "(" + strings.Join(parts, ", ") + ")" + return "(" + strings.Join(parts, ", ") + ")" + }) } func (t *Tuple) Hash() uint64 { return t.hash } diff --git a/types/typ/equals.go b/types/typ/equals.go index 4fcecede..aecc9589 100644 --- a/types/typ/equals.go +++ b/types/typ/equals.go @@ -240,7 +240,7 @@ func unwrapAliasForEquals(t Type, guard internal.RecursionGuard) Type { if !ok { return t } - t = alias.Target + t = alias.UnaliasedTarget() } return nil } diff --git a/types/typ/field_shape.go b/types/typ/field_shape.go new file mode 100644 index 00000000..d3762a71 --- /dev/null +++ b/types/typ/field_shape.go @@ -0,0 +1,56 @@ +package typ + +import "github.com/wippyai/go-lua/types/kind" + +// SplitNilableFieldType converts a nil-capable field value type into the shape +// used by table and record fields. In Lua, assigning nil to a table field +// removes the key, so a value expression of type T|nil in a table literal or +// field merge is represented as an optional field of type T rather than a +// required field of type T|nil. +func SplitNilableFieldType(t Type) (inner Type, optional bool) { + if t == nil { + return Unknown, true + } + if a, ok := t.(*Alias); ok { + if a == nil || a.Target == nil { + return t, false + } + if opt, ok := a.Target.(*Optional); ok && opt != nil && opt.Inner != nil { + return opt.Inner, true + } + if u, ok := a.Target.(*Union); ok && u != nil && len(u.Members) > 0 { + return splitNilableUnionMembers(u.Members, t) + } + return t, false + } + if opt, ok := t.(*Optional); ok && opt != nil && opt.Inner != nil { + return opt.Inner, true + } + if u, ok := t.(*Union); ok && u != nil && len(u.Members) > 0 { + return splitNilableUnionMembers(u.Members, t) + } + return t, false +} + +func splitNilableUnionMembers(members []Type, original Type) (Type, bool) { + hasNil := false + nonNil := make([]Type, 0, len(members)) + for _, m := range members { + if m != nil && m.Kind() == kind.Nil { + hasNil = true + continue + } + nonNil = append(nonNil, m) + } + if !hasNil { + return original, false + } + switch len(nonNil) { + case 0: + return Nil, true + case 1: + return nonNil[0], true + default: + return NewUnion(nonNil...), true + } +} diff --git a/types/typ/field_shape_test.go b/types/typ/field_shape_test.go new file mode 100644 index 00000000..b052a0b5 --- /dev/null +++ b/types/typ/field_shape_test.go @@ -0,0 +1,48 @@ +package typ + +import "testing" + +func TestSplitNilableFieldType(t *testing.T) { + t.Run("optional", func(t *testing.T) { + inner, optional := SplitNilableFieldType(NewOptional(String)) + if !optional { + t.Fatal("expected optional") + } + if !TypeEquals(inner, String) { + t.Fatalf("inner = %v, want string", inner) + } + }) + + t.Run("union with nil", func(t *testing.T) { + inner, optional := SplitNilableFieldType(NewUnion(String, Boolean, Nil)) + if !optional { + t.Fatal("expected optional") + } + want := NewUnion(String, Boolean) + if !TypeEquals(inner, want) { + t.Fatalf("inner = %v, want %v", inner, want) + } + }) + + t.Run("alias to optional", func(t *testing.T) { + maybeString := NewAlias("MaybeString", NewOptional(String)) + inner, optional := SplitNilableFieldType(maybeString) + if !optional { + t.Fatal("expected optional") + } + if !TypeEquals(inner, String) { + t.Fatalf("inner = %v, want string", inner) + } + }) + + t.Run("non optional alias preserved", func(t *testing.T) { + name := NewAlias("Name", String) + inner, optional := SplitNilableFieldType(name) + if optional { + t.Fatal("did not expect optional") + } + if inner != name { + t.Fatalf("inner = %v, want original alias", inner) + } + }) +} diff --git a/types/typ/function.go b/types/typ/function.go index 16f5c59f..e07b0954 100644 --- a/types/typ/function.go +++ b/types/typ/function.go @@ -21,14 +21,16 @@ type Param struct { // The Spec field holds Hoare-style contracts (pre/post conditions). // The Refinement field holds type narrowing constraints for predicate functions. type Function struct { - TypeParams []*TypeParam // Generic type parameters (empty for non-generic) - Params []Param // Positional parameters - Variadic Type // Variadic element type (nil if not variadic) - Returns []Type // Return types (empty for void functions) - Effects EffectInfo // Effect row (effect.Row) for mutation/throw/io tracking - Spec SpecInfo // Contract specification (*contract.Spec) - Refinement RefinementInfo // Type refinement effect (*constraint.FunctionEffect) - hash uint64 + TypeParams []*TypeParam // Generic type parameters (empty for non-generic) + Params []Param // Positional parameters + Variadic Type // Variadic element type (nil if not variadic) + Returns []Type // Return types (empty for void functions) + Effects EffectInfo // Effect row (effect.Row) for mutation/throw/io tracking + Spec SpecInfo // Contract specification (*contract.Spec) + Refinement RefinementInfo // Type refinement effect (*constraint.FunctionRefinement) + hash uint64 + softPrunable bool + strCache stringCache } // FunctionBuilder provides a fluent API for constructing function types. @@ -119,75 +121,77 @@ func (b *FunctionBuilder) Build() *Function { func (f *Function) Kind() kind.Kind { return kind.Function } func (f *Function) String() string { - var sb strings.Builder + return f.strCache.get(func() string { + var sb strings.Builder - sb.WriteString("fun") + sb.WriteString("fun") - if len(f.TypeParams) > 0 { - sb.WriteString("<") + if len(f.TypeParams) > 0 { + sb.WriteString("<") - for i, tp := range f.TypeParams { - if i > 0 { - sb.WriteString(", ") + for i, tp := range f.TypeParams { + if i > 0 { + sb.WriteString(", ") + } + + sb.WriteString(tp.String()) } - sb.WriteString(tp.String()) + sb.WriteString(">") } - sb.WriteString(">") - } - - sb.WriteString("(") + sb.WriteString("(") - for i, p := range f.Params { - if i > 0 { - sb.WriteString(", ") - } + for i, p := range f.Params { + if i > 0 { + sb.WriteString(", ") + } - if p.Name != "" { - sb.WriteString(p.Name) - sb.WriteString(": ") - } + if p.Name != "" { + sb.WriteString(p.Name) + sb.WriteString(": ") + } - sb.WriteString(p.Type.String()) + sb.WriteString(p.Type.String()) - if p.Optional { - sb.WriteString("?") + if p.Optional { + sb.WriteString("?") + } } - } - if f.Variadic != nil { - if len(f.Params) > 0 { - sb.WriteString(", ") + if f.Variadic != nil { + if len(f.Params) > 0 { + sb.WriteString(", ") + } + + sb.WriteString("...") + sb.WriteString(f.Variadic.String()) } - sb.WriteString("...") - sb.WriteString(f.Variadic.String()) - } + sb.WriteString(")") - sb.WriteString(")") + if len(f.Returns) > 0 { + sb.WriteString(" -> ") - if len(f.Returns) > 0 { - sb.WriteString(" -> ") + if len(f.Returns) == 1 { + sb.WriteString(f.Returns[0].String()) + } else { + sb.WriteString("(") - if len(f.Returns) == 1 { - sb.WriteString(f.Returns[0].String()) - } else { - sb.WriteString("(") + for i, r := range f.Returns { + if i > 0 { + sb.WriteString(", ") + } - for i, r := range f.Returns { - if i > 0 { - sb.WriteString(", ") + sb.WriteString(r.String()) } - sb.WriteString(r.String()) + sb.WriteString(")") } - - sb.WriteString(")") } - } - return sb.String() + return sb.String() + }) } func (f *Function) Hash() uint64 { return f.hash } diff --git a/types/typ/generic.go b/types/typ/generic.go index d98e010c..1298b903 100644 --- a/types/typ/generic.go +++ b/types/typ/generic.go @@ -73,6 +73,7 @@ type Generic struct { TypeParams []*TypeParam // Type parameters to be substituted Body Type // Template type with TypeParam references hash uint64 + strCache stringCache } // NewGeneric creates a generic type definition. @@ -97,22 +98,24 @@ func NewGeneric(name string, params []*TypeParam, body Type) *Generic { func (g *Generic) Kind() kind.Kind { return kind.Generic } func (g *Generic) String() string { - var sb strings.Builder + return g.strCache.get(func() string { + var sb strings.Builder - sb.WriteString(g.Name) - sb.WriteString("<") + sb.WriteString(g.Name) + sb.WriteString("<") - for i, p := range g.TypeParams { - if i > 0 { - sb.WriteString(", ") - } + for i, p := range g.TypeParams { + if i > 0 { + sb.WriteString(", ") + } - sb.WriteString(p.String()) - } + sb.WriteString(p.String()) + } - sb.WriteString(">") + sb.WriteString(">") - return sb.String() + return sb.String() + }) } func (g *Generic) Hash() uint64 { return g.hash } func (g *Generic) Equals(other Type) bool { @@ -125,39 +128,47 @@ func (g *Generic) Equals(other Type) bool { // type is created with Generic=Array and TypeArgs=[number]. The body can // be expanded by substituting type parameters with arguments. type Instantiated struct { - Generic *Generic // The generic being instantiated - TypeArgs []Type // Concrete types for each type parameter - hash uint64 + Generic *Generic // The generic being instantiated + TypeArgs []Type // Concrete types for each type parameter + hash uint64 + softPrunable bool + strCache stringCache } // Instantiate creates an instantiated generic type with the given arguments. func Instantiate(g *Generic, args ...Type) *Instantiated { h := internal.HashCombine(uint64(kind.Instantiated), g.Hash()) + softPrunable := false for _, a := range args { h = internal.HashCombine(h, a.Hash()) + if !softPrunable && softPruneMayRewrite(a) { + softPrunable = true + } } - return &Instantiated{Generic: g, TypeArgs: args, hash: h} + return &Instantiated{Generic: g, TypeArgs: args, hash: h, softPrunable: softPrunable} } func (i *Instantiated) Kind() kind.Kind { return kind.Instantiated } func (i *Instantiated) String() string { - var sb strings.Builder + return i.strCache.get(func() string { + var sb strings.Builder - sb.WriteString(i.Generic.Name) - sb.WriteString("<") + sb.WriteString(i.Generic.Name) + sb.WriteString("<") - for j, a := range i.TypeArgs { - if j > 0 { - sb.WriteString(", ") - } + for j, a := range i.TypeArgs { + if j > 0 { + sb.WriteString(", ") + } - sb.WriteString(a.String()) - } + sb.WriteString(a.String()) + } - sb.WriteString(">") + sb.WriteString(">") - return sb.String() + return sb.String() + }) } func (i *Instantiated) Hash() uint64 { return i.hash } func (i *Instantiated) Equals(other Type) bool { diff --git a/types/typ/info.go b/types/typ/info.go index 83cae961..874b853b 100644 --- a/types/typ/info.go +++ b/types/typ/info.go @@ -17,7 +17,7 @@ type SpecInfo interface { } // RefinementInfo describes type refinements from function calls. -// Implemented by *constraint.FunctionEffect. +// Implemented by *constraint.FunctionRefinement. type RefinementInfo interface { internal.Equaler IsRefinementInfo() diff --git a/types/typ/intersection.go b/types/typ/intersection.go index cd8553c0..7e084229 100644 --- a/types/typ/intersection.go +++ b/types/typ/intersection.go @@ -19,8 +19,9 @@ import ( // // Members are sorted by hash for deterministic comparison. type Intersection struct { - Members []Type - hash uint64 + Members []Type + hash uint64 + strCache stringCache } // NewIntersection creates a normalized intersection type. @@ -119,16 +120,17 @@ func NewIntersection(members ...Type) Type { func (i *Intersection) Kind() kind.Kind { return kind.Intersection } func (i *Intersection) String() string { - parts := make([]string, len(i.Members)) - for j, m := range i.Members { - if m == nil { - parts[j] = "unknown" - } else { - parts[j] = m.String() + return i.strCache.get(func() string { + parts := make([]string, len(i.Members)) + for j, m := range i.Members { + if m == nil { + parts[j] = "unknown" + } else { + parts[j] = m.String() + } } - } - - return strings.Join(parts, " & ") + return strings.Join(parts, " & ") + }) } func (i *Intersection) Hash() uint64 { return i.hash } diff --git a/types/typ/literal.go b/types/typ/literal.go index 2a296676..3f3f9097 100644 --- a/types/typ/literal.go +++ b/types/typ/literal.go @@ -21,6 +21,7 @@ type Literal struct { Base kind.Kind // Boolean, Number, Integer, or String Value any // bool, float64, int64, or string hash uint64 + str string } // True and False are singleton boolean literals. @@ -28,8 +29,8 @@ var ( trueHash = internal.HashCombine(internal.HashCombine(uint64(kind.Literal), uint64(kind.Boolean)), 1) falseHash = internal.HashCombine(uint64(kind.Literal), uint64(kind.Boolean)) - True = &Literal{Base: kind.Boolean, Value: true, hash: trueHash} - False = &Literal{Base: kind.Boolean, Value: false, hash: falseHash} + True = &Literal{Base: kind.Boolean, Value: true, hash: trueHash, str: "true"} + False = &Literal{Base: kind.Boolean, Value: false, hash: falseHash, str: "false"} ) // LiteralBool returns the canonical boolean literal type. @@ -46,7 +47,7 @@ func LiteralInt(v int64) *Literal { h := internal.HashCombine(uint64(kind.Literal), uint64(kind.Integer)) h = internal.HashCombine(h, uint64(v)) - return &Literal{Base: kind.Integer, Value: v, hash: h} + return &Literal{Base: kind.Integer, Value: v, hash: h, str: strconv.FormatInt(v, 10)} } // LiteralNumber creates a number literal type. @@ -54,7 +55,7 @@ func LiteralNumber(v float64) *Literal { h := internal.HashCombine(uint64(kind.Literal), uint64(kind.Number)) h = internal.HashCombine(h, uint64(v)) - return &Literal{Base: kind.Number, Value: v, hash: h} + return &Literal{Base: kind.Number, Value: v, hash: h, str: strconv.FormatFloat(v, 'g', -1, 64)} } // LiteralString creates a string literal type. @@ -62,12 +63,15 @@ func LiteralString(v string) *Literal { h := internal.HashCombine(uint64(kind.Literal), uint64(kind.String)) h = internal.HashCombine(h, internal.FnvString(v)) - return &Literal{Base: kind.String, Value: v, hash: h} + return &Literal{Base: kind.String, Value: v, hash: h, str: strconv.Quote(v)} } func (l *Literal) Kind() kind.Kind { return kind.Literal } func (l *Literal) String() string { + if l.str != "" { + return l.str + } switch l.Base { case kind.Boolean: if l.Value.(bool) { diff --git a/types/typ/nominal.go b/types/typ/nominal.go index 7322c5de..67f07eff 100644 --- a/types/typ/nominal.go +++ b/types/typ/nominal.go @@ -24,6 +24,7 @@ type Sum struct { Name string // Type name for display Variants []Variant // Possible cases hash uint64 + strCache stringCache } // NewSum creates a sum type. @@ -45,37 +46,39 @@ func NewSum(name string, variants []Variant) *Sum { func (s *Sum) Kind() kind.Kind { return kind.Sum } func (s *Sum) String() string { - var sb strings.Builder + return s.strCache.get(func() string { + var sb strings.Builder - sb.WriteString("enum ") - sb.WriteString(s.Name) - sb.WriteString(" { ") + sb.WriteString("enum ") + sb.WriteString(s.Name) + sb.WriteString(" { ") - for i, v := range s.Variants { - if i > 0 { - sb.WriteString(" | ") - } + for i, v := range s.Variants { + if i > 0 { + sb.WriteString(" | ") + } + + sb.WriteString(v.Tag) - sb.WriteString(v.Tag) + if len(v.Types) > 0 { + sb.WriteString("(") - if len(v.Types) > 0 { - sb.WriteString("(") + for j, t := range v.Types { + if j > 0 { + sb.WriteString(", ") + } - for j, t := range v.Types { - if j > 0 { - sb.WriteString(", ") + sb.WriteString(t.String()) } - sb.WriteString(t.String()) + sb.WriteString(")") } - - sb.WriteString(")") } - } - sb.WriteString(" }") + sb.WriteString(" }") - return sb.String() + return sb.String() + }) } func (s *Sum) Hash() uint64 { return s.hash } @@ -121,9 +124,10 @@ type Method struct { // Named interfaces (Name != "") use nominal identity for marker interfaces // (interfaces with no methods, like Channel). type Interface struct { - Name string // Interface name (empty for anonymous) - Methods []Method // Required methods - hash uint64 + Name string // Interface name (empty for anonymous) + Methods []Method // Required methods + hash uint64 + strCache stringCache } // NewInterface creates an interface type. @@ -143,28 +147,28 @@ func NewInterface(name string, methods []Method) *Interface { func (i *Interface) Kind() kind.Kind { return kind.Interface } func (i *Interface) String() string { - // Named interfaces render as just the name - if i.Name != "" { - return i.Name - } + return i.strCache.get(func() string { + if i.Name != "" { + return i.Name + } - // Anonymous interfaces expand methods - var sb strings.Builder - sb.WriteString("interface { ") + var sb strings.Builder + sb.WriteString("interface { ") - for j, m := range i.Methods { - if j > 0 { - sb.WriteString("; ") - } + for j, m := range i.Methods { + if j > 0 { + sb.WriteString("; ") + } - sb.WriteString(m.Name) - sb.WriteString(": ") - sb.WriteString(m.Type.String()) - } + sb.WriteString(m.Name) + sb.WriteString(": ") + sb.WriteString(m.Type.String()) + } - sb.WriteString(" }") + sb.WriteString(" }") - return sb.String() + return sb.String() + }) } func (i *Interface) Hash() uint64 { return i.hash } diff --git a/types/typ/optional.go b/types/typ/optional.go index cd3c1adc..c9f983f4 100644 --- a/types/typ/optional.go +++ b/types/typ/optional.go @@ -13,8 +13,10 @@ import ( // The Inner field holds the non-nil type. An Optional never contains // another Optional (they are flattened during construction). type Optional struct { - Inner Type - hash uint64 + Inner Type + hash uint64 + softPrunable bool + strCache stringCache } // NewOptional creates an optional type (T | nil). @@ -48,16 +50,18 @@ func NewOptional(inner Type) Type { h := internal.HashCombine(uint64(kind.Optional), inner.Hash()) - return &Optional{Inner: inner, hash: h} + return &Optional{Inner: inner, hash: h, softPrunable: softPruneMayRewrite(inner)} } func (o *Optional) Kind() kind.Kind { return kind.Optional } func (o *Optional) String() string { - if o.Inner == nil { - return "nil?" - } - return o.Inner.String() + "?" + return o.strCache.get(func() string { + if o.Inner == nil { + return "nil?" + } + return o.Inner.String() + "?" + }) } func (o *Optional) Hash() uint64 { return o.hash } diff --git a/types/typ/rebuild.go b/types/typ/rebuild.go index 981cc371..8c580461 100644 --- a/types/typ/rebuild.go +++ b/types/typ/rebuild.go @@ -50,16 +50,18 @@ func buildFunctionType( copy(paramsCopy, params) returnsCopy := make([]Type, len(returns)) copy(returnsCopy, returns) + softPrunable := softPruneParams(paramsCopy) || softPruneAny(variadic) || softPruneAny(returnsCopy...) return &Function{ - TypeParams: typeParamsCopy, - Params: paramsCopy, - Variadic: variadic, - Returns: returnsCopy, - Effects: effects, - Spec: spec, - Refinement: refinement, - hash: h, + TypeParams: typeParamsCopy, + Params: paramsCopy, + Variadic: variadic, + Returns: returnsCopy, + Effects: effects, + Spec: spec, + Refinement: refinement, + hash: h, + softPrunable: softPrunable, } } @@ -110,15 +112,17 @@ func buildRecordType(fields []Field, metatable, mapKey, mapValue Type, open bool h = internal.HashCombine(h, recordMapValueHash) h = internal.HashCombine(h, mapValue.Hash()) } + softPrunable := softPruneFields(sorted) || softPruneAny(metatable, mapKey, mapValue) return &Record{ - Fields: sorted, - Metatable: metatable, - MapKey: mapKey, - MapValue: mapValue, - Open: open, - sorted: true, - hash: h, + Fields: sorted, + Metatable: metatable, + MapKey: mapKey, + MapValue: mapValue, + Open: open, + sorted: true, + hash: h, + softPrunable: softPrunable, } } diff --git a/types/typ/record.go b/types/typ/record.go index 716e8e58..333acaea 100644 --- a/types/typ/record.go +++ b/types/typ/record.go @@ -27,13 +27,15 @@ type Field struct { // // Fields are sorted by name for deterministic hashing and comparison. type Record struct { - Fields []Field - Metatable Type // Metatable type for metamethod lookup - MapKey Type // Map component key type (nil if no map component) - MapValue Type // Map component value type (nil if no map component) - Open bool // Allow access to undefined fields - sorted bool - hash uint64 + Fields []Field + Metatable Type // Metatable type for metamethod lookup + MapKey Type // Map component key type (nil if no map component) + MapValue Type // Map component value type (nil if no map component) + Open bool // Allow access to undefined fields + sorted bool + hash uint64 + softPrunable bool + strCache stringCache } // RecordBuilder provides a fluent API for constructing record types. @@ -119,61 +121,63 @@ func (b *RecordBuilder) Build() *Record { func (r *Record) Kind() kind.Kind { return kind.Record } func (r *Record) String() string { - var sb strings.Builder + return r.strCache.get(func() string { + var sb strings.Builder - sb.WriteString("{") + sb.WriteString("{") - for i, f := range r.Fields { - if i > 0 { - sb.WriteString(", ") - } + for i, f := range r.Fields { + if i > 0 { + sb.WriteString(", ") + } - if f.Readonly { - sb.WriteString("readonly ") - } + if f.Readonly { + sb.WriteString("readonly ") + } - sb.WriteString(f.Name) + sb.WriteString(f.Name) - if f.Optional { - sb.WriteString("?") - } + if f.Optional { + sb.WriteString("?") + } - sb.WriteString(": ") - if f.Type != nil { - sb.WriteString(f.Type.String()) - } else { - sb.WriteString("unknown") + sb.WriteString(": ") + if f.Type != nil { + sb.WriteString(f.Type.String()) + } else { + sb.WriteString("unknown") + } } - } - if r.HasMapComponent() { - if len(r.Fields) > 0 { - sb.WriteString(", ") + if r.HasMapComponent() { + if len(r.Fields) > 0 { + sb.WriteString(", ") + } + sb.WriteString("[") + if r.MapKey != nil { + sb.WriteString(r.MapKey.String()) + } else { + sb.WriteString("unknown") + } + sb.WriteString("]: ") + if r.MapValue != nil { + sb.WriteString(r.MapValue.String()) + } else { + sb.WriteString("unknown") + } } - sb.WriteString("[") - if r.MapKey != nil { - sb.WriteString(r.MapKey.String()) - } else { - sb.WriteString("unknown") - } - sb.WriteString("]: ") - if r.MapValue != nil { - sb.WriteString(r.MapValue.String()) - } else { - sb.WriteString("unknown") - } - } - if r.Open { - if len(r.Fields) > 0 || r.HasMapComponent() { - sb.WriteString(", ") + if r.Open { + if len(r.Fields) > 0 || r.HasMapComponent() { + sb.WriteString(", ") + } + sb.WriteString("...") } - sb.WriteString("...") - } - sb.WriteString("}") + sb.WriteString("}") - return sb.String() + return sb.String() + }) } func (r *Record) Hash() uint64 { return r.hash } diff --git a/types/typ/ref.go b/types/typ/ref.go index 9453a234..a2f1a1b5 100644 --- a/types/typ/ref.go +++ b/types/typ/ref.go @@ -17,6 +17,7 @@ type Ref struct { Module string // Module path (empty for local references) Name string // Type name hash uint64 + str string } // NewRef creates a type reference. @@ -24,17 +25,17 @@ func NewRef(module, name string) *Ref { h := internal.HashCombine(uint64(kind.Ref), internal.FnvString(module)) h = internal.HashCombine(h, internal.FnvString(name)) - return &Ref{Module: module, Name: name, hash: h} + str := name + if module != "" { + str = module + "." + name + } + return &Ref{Module: module, Name: name, hash: h, str: str} } func (r *Ref) Kind() kind.Kind { return kind.Ref } func (r *Ref) String() string { - if r.Module == "" { - return r.Name - } - - return r.Module + "." + r.Name + return r.str } func (r *Ref) Hash() uint64 { return r.hash } @@ -57,9 +58,11 @@ func (r *Ref) Equals(other Type) bool { // // Example: type UserId = number creates Alias{Name: "UserId", Target: number} type Alias struct { - Name string // Alias name - Target Type // Underlying type - hash uint64 + Name string // Alias name + Target Type // Underlying type + unaliased Type + hash uint64 + softPrunable bool } // NewAlias creates a type alias. @@ -67,18 +70,50 @@ func NewAlias(name string, target Type) *Alias { h := internal.HashCombine(uint64(kind.Alias), internal.FnvString(name)) h = internal.HashCombine(h, target.Hash()) - return &Alias{Name: name, Target: target, hash: h} + return &Alias{ + Name: name, + Target: target, + unaliased: flattenAliasTarget(target), + hash: h, + softPrunable: softPruneMayRewrite(target), + } } func (a *Alias) Kind() kind.Kind { return kind.Alias } func (a *Alias) String() string { return a.Name } func (a *Alias) Hash() uint64 { return a.hash } +func (a *Alias) UnaliasedTarget() Type { + if a == nil || a.unaliased == nil { + return a.Target + } + return a.unaliased +} + // Equals compares structurally through the alias target. func (a *Alias) Equals(other Type) bool { return TypeEquals(a.Target, other) } +func flattenAliasTarget(target Type) Type { + current := target + for depth := 0; depth < DefaultRecursionDepth; depth++ { + alias, ok := current.(*Alias) + if !ok || alias == nil { + return current + } + next := alias.Target + if alias.unaliased != nil { + next = alias.unaliased + } + if next == nil || next == current { + return current + } + current = next + } + return current +} + // Platform represents a platform-specific opaque type. // // Platform types are provided by the runtime environment and have @@ -116,8 +151,9 @@ func (p *Platform) Equals(other Type) bool { // // Example: typeof(Point) has type Meta{Of: Point} type Meta struct { - Of Type // The type being wrapped - hash uint64 + Of Type // The type being wrapped + hash uint64 + strCache stringCache } // NewMeta creates a metatype. @@ -127,8 +163,12 @@ func NewMeta(of Type) *Meta { } func (m *Meta) Kind() kind.Kind { return kind.Meta } -func (m *Meta) String() string { return "typeof(" + m.Of.String() + ")" } -func (m *Meta) Hash() uint64 { return m.hash } +func (m *Meta) String() string { + return m.strCache.get(func() string { + return "typeof(" + m.Of.String() + ")" + }) +} +func (m *Meta) Hash() uint64 { return m.hash } func (m *Meta) Equals(other Type) bool { if other.Kind() != kind.Meta { return false diff --git a/types/typ/soft.go b/types/typ/soft.go index 53e4123a..3a9bb4b0 100644 --- a/types/typ/soft.go +++ b/types/typ/soft.go @@ -4,7 +4,6 @@ import ( "sync" "github.com/wippyai/go-lua/internal" - "github.com/wippyai/go-lua/types/kind" ) // SoftPolicy controls how soft-placeholder detection behaves. @@ -92,7 +91,7 @@ func PruneSoftUnionMembers(t Type) Type { if t == nil { return nil } - if !softPruneCanDescend(t) { + if !softPruneMayRewrite(t) { return t } state := getSoftPruneState() @@ -145,26 +144,6 @@ func putSoftPruneState(state *softPruneState) { softPruneStatePool.Put(state) } -func softPruneCanDescend(t Type) bool { - if t == nil { - return false - } - switch t.Kind() { - case kind.Optional, - kind.Union, - kind.Array, - kind.Map, - kind.Tuple, - kind.Function, - kind.Record, - kind.Alias, - kind.Instantiated: - return true - default: - return false - } -} - func pruneSoftUnionMembersMemo( t Type, guard internal.RecursionGuard, @@ -175,7 +154,7 @@ func pruneSoftUnionMembersMemo( if t == nil { return t } - if !softPruneCanDescend(t) { + if !softPruneMayRewrite(t) { return t } if cached, ok := memo[t]; ok { @@ -197,6 +176,32 @@ func pruneSoftUnionMembersMemo( case *Record: out = pruneSoftRecord(node, t, next, memo, visiting, softMemo) case *Union: + // Fast path: if this union has no soft members and no nested changes needed, + // skip the expensive isSoftWithMemo checks entirely. + if !node.HasSoftMember() { + anyChildChanged := false + var rewrittenFast []Type + for idx, m := range node.Members { + pm := pruneSoftUnionMembersMemo(m, next, memo, visiting, softMemo) + if pm != m { + if rewrittenFast == nil { + rewrittenFast = make([]Type, len(node.Members)) + copy(rewrittenFast, node.Members) + } + rewrittenFast[idx] = pm + anyChildChanged = true + } else if rewrittenFast != nil { + rewrittenFast[idx] = m + } + } + if !anyChildChanged { + out = t + } else { + out = NewUnion(rewrittenFast...) + } + break + } + var rewritten []Type softCount := 0 changed := false diff --git a/types/typ/soft_flags.go b/types/typ/soft_flags.go new file mode 100644 index 00000000..163c60cd --- /dev/null +++ b/types/typ/soft_flags.go @@ -0,0 +1,59 @@ +package typ + +// softPruneMayRewrite reports whether PruneSoftUnionMembers can rewrite t or +// any of its descendants. The flag is computed once at type construction time +// for immutable structural types so hot paths can skip recursive descent. +func softPruneMayRewrite(t Type) bool { + if t == nil { + return false + } + switch node := unwrapTransparentSoft(t).(type) { + case *Union: + return node.softPrunable + case *Optional: + return node.softPrunable + case *Array: + return node.softPrunable + case *Map: + return node.softPrunable + case *Tuple: + return node.softPrunable + case *Function: + return node.softPrunable + case *Record: + return node.softPrunable + case *Alias: + return node.softPrunable + case *Instantiated: + return node.softPrunable + default: + return false + } +} + +func softPruneAny(types ...Type) bool { + for _, t := range types { + if softPruneMayRewrite(t) { + return true + } + } + return false +} + +func softPruneParams(params []Param) bool { + for _, p := range params { + if softPruneMayRewrite(p.Type) { + return true + } + } + return false +} + +func softPruneFields(fields []Field) bool { + for _, f := range fields { + if softPruneMayRewrite(f.Type) { + return true + } + } + return false +} diff --git a/types/typ/string_cache.go b/types/typ/string_cache.go new file mode 100644 index 00000000..af291879 --- /dev/null +++ b/types/typ/string_cache.go @@ -0,0 +1,15 @@ +package typ + +import "sync" + +type stringCache struct { + once sync.Once + value string +} + +func (c *stringCache) get(build func() string) string { + c.once.Do(func() { + c.value = build() + }) + return c.value +} diff --git a/types/typ/subst/subst.go b/types/typ/subst/subst.go index f5979d76..d4b1845c 100644 --- a/types/typ/subst/subst.go +++ b/types/typ/subst/subst.go @@ -5,6 +5,8 @@ package subst import ( + "sync" + "github.com/wippyai/go-lua/internal" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/typ" @@ -66,18 +68,55 @@ func Self(t typ.Type, selfType typ.Type) typ.Type { // // Does not enforce generic constraints; use subtype checking for that. func ExpandInstantiated(t typ.Type) typ.Type { - return expandInstantiatedWithDepth(t, typ.DeepRecursionDepth) + if t == nil || !expandInstantiatedCanDescend(t) { + return t + } + memo := getExpandMemo() + defer putExpandMemo(memo) + guard := typ.GuardForDepth(typ.DefaultRecursionDepth) + return expandInstantiatedGuard(t, guard, memo) +} + +const expandMemoMaxEntries = 2048 + +var expandMemoPool = sync.Pool{ + New: func() any { + return make(map[typ.Type]typ.Type, 32) + }, +} + +func getExpandMemo() map[typ.Type]typ.Type { + return expandMemoPool.Get().(map[typ.Type]typ.Type) +} + +func putExpandMemo(m map[typ.Type]typ.Type) { + if len(m) > expandMemoMaxEntries { + expandMemoPool.Put(make(map[typ.Type]typ.Type, 32)) + return + } + clear(m) + expandMemoPool.Put(m) } func expandInstantiatedWithDepth(t typ.Type, maxDepth int) typ.Type { + if t == nil || !expandInstantiatedCanDescend(t) { + return t + } + memo := getExpandMemo() + defer putExpandMemo(memo) guard := typ.GuardForDepth(maxDepth) - return expandInstantiatedGuard(t, guard) + return expandInstantiatedGuard(t, guard, memo) } -func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type { +func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard, memo map[typ.Type]typ.Type) typ.Type { if t == nil || !expandInstantiatedCanDescend(t) { return t } + + if cached, ok := memo[t]; ok { + return cached + } + next, ok := guard.Enter(t) if !ok { return t @@ -92,6 +131,12 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type t = ann.Inner } + result := expandInstantiatedCore(t, orig, next, memo) + memo[orig] = result + return result +} + +func expandInstantiatedCore(t typ.Type, orig typ.Type, guard internal.RecursionGuard, memo map[typ.Type]typ.Type) typ.Type { switch v := t.(type) { case *typ.Instantiated: if v.Generic == nil || len(v.TypeArgs) != len(v.Generic.TypeParams) || v.Generic.Body == nil { @@ -99,9 +144,9 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type } body := Params(v.Generic.Body, v.Generic.TypeParams, v.TypeArgs) body = Self(body, orig) - return expandInstantiatedGuard(body, next) + return expandInstantiatedGuard(body, guard, memo) case *typ.Optional: - inner := expandInstantiatedGuard(v.Inner, next) + inner := expandInstantiatedGuard(v.Inner, guard, memo) if inner == v.Inner { return orig } @@ -109,7 +154,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type case *typ.Union: var members []typ.Type for i, m := range v.Members { - newMember := expandInstantiatedGuard(m, next) + newMember := expandInstantiatedGuard(m, guard, memo) if newMember != m { if members == nil { members = make([]typ.Type, len(v.Members)) @@ -127,7 +172,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type case *typ.Intersection: var members []typ.Type for i, m := range v.Members { - newMember := expandInstantiatedGuard(m, next) + newMember := expandInstantiatedGuard(m, guard, memo) if newMember != m { if members == nil { members = make([]typ.Type, len(v.Members)) @@ -143,14 +188,14 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type } return typ.NewIntersection(members...) case *typ.Array: - elem := expandInstantiatedGuard(v.Element, next) + elem := expandInstantiatedGuard(v.Element, guard, memo) if elem == v.Element { return orig } return typ.NewArray(elem) case *typ.Map: - key := expandInstantiatedGuard(v.Key, next) - value := expandInstantiatedGuard(v.Value, next) + key := expandInstantiatedGuard(v.Key, guard, memo) + value := expandInstantiatedGuard(v.Value, guard, memo) if key == v.Key && value == v.Value { return orig } @@ -158,7 +203,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type case *typ.Tuple: var elems []typ.Type for i, e := range v.Elements { - newElem := expandInstantiatedGuard(e, next) + newElem := expandInstantiatedGuard(e, guard, memo) if newElem != e { if elems == nil { elems = make([]typ.Type, len(v.Elements)) @@ -179,7 +224,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type for i, p := range v.Params { newType := p.Type if _, isInst := p.Type.(*typ.Instantiated); !isInst { - newType = expandInstantiatedGuard(p.Type, next) + newType = expandInstantiatedGuard(p.Type, guard, memo) } if newType != p.Type { if params == nil { @@ -195,7 +240,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type var returns []typ.Type for i, r := range v.Returns { - newRet := expandInstantiatedGuard(r, next) + newRet := expandInstantiatedGuard(r, guard, memo) if newRet != r { if returns == nil { returns = make([]typ.Type, len(v.Returns)) @@ -210,7 +255,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type variadic := v.Variadic if v.Variadic != nil { - newVariadic := expandInstantiatedGuard(v.Variadic, next) + newVariadic := expandInstantiatedGuard(v.Variadic, guard, memo) if newVariadic != v.Variadic { changed = true variadic = newVariadic @@ -257,7 +302,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type changed := false var fields []typ.Field for i, f := range v.Fields { - newType := expandInstantiatedGuard(f.Type, next) + newType := expandInstantiatedGuard(f.Type, guard, memo) if newType != f.Type { if fields == nil { fields = make([]typ.Field, len(v.Fields)) @@ -272,7 +317,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type metatable := v.Metatable if v.Metatable != nil { - newMetatable := expandInstantiatedGuard(v.Metatable, next) + newMetatable := expandInstantiatedGuard(v.Metatable, guard, memo) if newMetatable != v.Metatable { changed = true metatable = newMetatable @@ -282,11 +327,11 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type mapKey := v.MapKey mapValue := v.MapValue if v.HasMapComponent() { - mapKey = expandInstantiatedGuard(v.MapKey, next) + mapKey = expandInstantiatedGuard(v.MapKey, guard, memo) if mapKey != v.MapKey { changed = true } - mapValue = expandInstantiatedGuard(v.MapValue, next) + mapValue = expandInstantiatedGuard(v.MapValue, guard, memo) if mapValue != v.MapValue { changed = true } @@ -324,7 +369,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type } return builder.Build() case *typ.Alias: - target := expandInstantiatedGuard(v.Target, next) + target := expandInstantiatedGuard(v.Target, guard, memo) if target == v.Target { return orig } @@ -334,7 +379,7 @@ func expandInstantiatedGuard(t typ.Type, guard internal.RecursionGuard) typ.Type var methods []typ.Method for idx := range v.Methods { m := v.Methods[idx] - newType := expandInstantiatedGuard(m.Type, next) + newType := expandInstantiatedGuard(m.Type, guard, memo) fn, ok := newType.(*typ.Function) if !ok { fn = m.Type diff --git a/types/typ/union.go b/types/typ/union.go index d0d56fb6..a8090118 100644 --- a/types/typ/union.go +++ b/types/typ/union.go @@ -20,8 +20,11 @@ import ( // // Members are sorted by hash for deterministic comparison and serialization. type Union struct { - Members []Type - hash uint64 + Members []Type + hash uint64 + hasSoftMbr bool // true if any member is a soft placeholder + softPrunable bool + strCache stringCache } // NewUnion creates a normalized union type from the given members. @@ -188,28 +191,92 @@ func NewUnion(members ...Type) Type { return unique[0] } - // Compute hash + // Compute hash and check for soft members h := uint64(kind.Union) + hasSoft := false + hasNonSoft := false + softPrunable := false for _, m := range unique { h = internal.HashCombine(h, m.Hash()) + if softPruneMayRewrite(m) { + softPrunable = true + } + if memberIsSoft(m) { + hasSoft = true + } else { + hasNonSoft = true + } + } + if hasSoft && hasNonSoft { + softPrunable = true } - return &Union{Members: unique, hash: h} + return &Union{Members: unique, hash: h, hasSoftMbr: hasSoft, softPrunable: softPrunable} } +// HasSoftMember reports whether any union member is a soft placeholder type. +// Computed at construction time for O(1) access. +func (u *Union) HasSoftMember() bool { return u.hasSoftMbr } + func (u *Union) Kind() kind.Kind { return kind.Union } -func (u *Union) String() string { - parts := make([]string, len(u.Members)) - for i, m := range u.Members { - if m == nil { - parts[i] = "nil" - } else { - parts[i] = m.String() +// memberIsSoft checks if a type is a soft placeholder. +// Used at Union construction to set the hasSoftMbr flag. +// Mirrors isSoft logic but without recursion guard (unions are already flat). +func memberIsSoft(t Type) bool { + if t == nil { + return false + } + for { + if ann, ok := t.(*Annotated); ok && ann.Inner != nil { + t = ann.Inner + continue + } + break + } + if t.Kind().IsPlaceholder() { + return true + } + switch v := t.(type) { + case *Optional: + return memberIsSoft(v.Inner) + case *Alias: + return memberIsSoft(v.Target) + case *Array: + return memberIsSoft(v.Element) + case *Map: + return memberIsSoft(v.Value) + case *Record: + if len(v.Fields) == 0 && !v.HasMapComponent() { + return true + } + if v.HasMapComponent() && len(v.Fields) == 0 { + return memberIsSoft(v.MapValue) } + return false + case *Union: + for _, m := range v.Members { + if !memberIsSoft(m) { + return false + } + } + return len(v.Members) > 0 } + return false +} - return strings.Join(parts, " | ") +func (u *Union) String() string { + return u.strCache.get(func() string { + parts := make([]string, len(u.Members)) + for i, m := range u.Members { + if m == nil { + parts[i] = "nil" + } else { + parts[i] = m.String() + } + } + return strings.Join(parts, " | ") + }) } func (u *Union) Hash() uint64 { return u.hash } diff --git a/types/typ/unwrap/unwrap.go b/types/typ/unwrap/unwrap.go index 09fce4ac..e057b703 100644 --- a/types/typ/unwrap/unwrap.go +++ b/types/typ/unwrap/unwrap.go @@ -22,7 +22,7 @@ func underlyingDepth(t typ.Type, guard internal.RecursionGuard) typ.Type { return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { return typ.Visitor[typ.Type]{ Alias: func(a *typ.Alias) typ.Type { - return underlyingDepth(a.Target, next) + return underlyingDepth(a.UnaliasedTarget(), next) }, Optional: func(o *typ.Optional) typ.Type { return underlyingDepth(o.Inner, next) @@ -43,7 +43,7 @@ func unwrapAliasDepth(t typ.Type, guard internal.RecursionGuard) typ.Type { return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { return typ.Visitor[typ.Type]{ Alias: func(a *typ.Alias) typ.Type { - return unwrapAliasDepth(a.Target, next) + return unwrapAliasDepth(a.UnaliasedTarget(), next) }, Default: func(t typ.Type) typ.Type { return t @@ -62,7 +62,7 @@ func unwrapOptionalDepth(t typ.Type, guard internal.RecursionGuard) typ.Type { return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { return typ.Visitor[typ.Type]{ Alias: func(a *typ.Alias) typ.Type { - return unwrapOptionalDepth(a.Target, next) + return unwrapOptionalDepth(a.UnaliasedTarget(), next) }, Optional: func(o *typ.Optional) typ.Type { return unwrapOptionalDepth(o.Inner, next) @@ -171,8 +171,14 @@ func unwrapFunctionDepth(t typ.Type, guard internal.RecursionGuard) *typ.Functio Optional: func(o *typ.Optional) *typ.Function { return unwrapFunctionDepth(o.Inner, next) }, + Recursive: func(rec *typ.Recursive) *typ.Function { + if rec.Body == nil || rec.Body == rec { + return nil + } + return unwrapFunctionDepth(rec.Body, next) + }, Alias: func(a *typ.Alias) *typ.Function { - return unwrapFunctionDepth(a.Target, next) + return unwrapFunctionDepth(a.UnaliasedTarget(), next) }, Default: func(t typ.Type) *typ.Function { return nil @@ -192,8 +198,14 @@ func unwrapRecordDepth(t typ.Type, guard internal.RecursionGuard) *typ.Record { Record: func(rec *typ.Record) *typ.Record { return rec }, + Recursive: func(rec *typ.Recursive) *typ.Record { + if rec.Body == nil || rec.Body == rec { + return nil + } + return unwrapRecordDepth(rec.Body, next) + }, Alias: func(a *typ.Alias) *typ.Record { - return unwrapRecordDepth(a.Target, next) + return unwrapRecordDepth(a.UnaliasedTarget(), next) }, Optional: func(o *typ.Optional) *typ.Record { return unwrapRecordDepth(o.Inner, next) @@ -223,8 +235,14 @@ func unwrapUnionDepth(t typ.Type, guard internal.RecursionGuard) *typ.Union { Union: func(u *typ.Union) *typ.Union { return u }, + Recursive: func(rec *typ.Recursive) *typ.Union { + if rec.Body == nil || rec.Body == rec { + return nil + } + return unwrapUnionDepth(rec.Body, next) + }, Alias: func(a *typ.Alias) *typ.Union { - return unwrapUnionDepth(a.Target, next) + return unwrapUnionDepth(a.UnaliasedTarget(), next) }, Optional: func(o *typ.Optional) *typ.Union { return unwrapUnionDepth(o.Inner, next) @@ -269,8 +287,14 @@ func unwrapToKindDepth(t typ.Type, k kind.Kind, guard internal.RecursionGuard) t } return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { return typ.Visitor[typ.Type]{ + Recursive: func(rec *typ.Recursive) typ.Type { + if rec.Body == nil || rec.Body == rec { + return nil + } + return unwrapToKindDepth(rec.Body, k, next) + }, Alias: func(a *typ.Alias) typ.Type { - return unwrapToKindDepth(a.Target, k, next) + return unwrapToKindDepth(a.UnaliasedTarget(), k, next) }, Default: func(t typ.Type) typ.Type { return nil