diff --git a/compiler/check/synth/ops/check.go b/compiler/check/synth/ops/check.go index ae02a30e..e7366150 100644 --- a/compiler/check/synth/ops/check.go +++ b/compiler/check/synth/ops/check.go @@ -98,17 +98,19 @@ func CheckTable(fields []FieldDef, arrayElems []typ.Type, expected typ.Type) Che return CheckResult{Type: expected, Errors: errors} } - switch expected.Kind() { + unwrapped := typ.UnwrapAnnotated(expected) + + switch unwrapped.Kind() { case kind.Array: - return checkTableAsArray(fields, arrayElems, expected.(*typ.Array)) + return checkTableAsArray(fields, arrayElems, unwrapped.(*typ.Array)) case kind.Map: - return checkTableAsMap(fields, arrayElems, expected.(*typ.Map)) + return checkTableAsMap(fields, arrayElems, unwrapped.(*typ.Map)) case kind.Record: - return checkTableAsRecord(fields, arrayElems, expected.(*typ.Record)) + return checkTableAsRecord(fields, arrayElems, unwrapped.(*typ.Record)) case kind.Tuple: - return checkTableAsTuple(arrayElems, expected.(*typ.Tuple)) + return checkTableAsTuple(arrayElems, unwrapped.(*typ.Tuple)) case kind.Union: - return checkTableAsUnion(fields, arrayElems, expected.(*typ.Union)) + return checkTableAsUnion(fields, arrayElems, unwrapped.(*typ.Union)) default: // Try synthesis and check compatibility synthesized := tableConstructor(fields, arrayElems) diff --git a/ltype.go b/ltype.go index 39e23e45..f1586dc3 100644 --- a/ltype.go +++ b/ltype.go @@ -327,6 +327,10 @@ func validateValue(val LValue, t typ.Type, resolver *typeResolver) bool { } return false + case *typ.Interface: + _, ok := val.(*LTable) + return ok + case *typ.Generic: // Generic types need to be instantiated before validation return false @@ -536,6 +540,12 @@ func validateWithErrorResolver(val LValue, t typ.Type, resolver *typeResolver, p } return false, formatValidationError(path, typeName, luaTypeName(val)) + case *typ.Interface: + if _, ok := val.(*LTable); ok { + return true, "" + } + return false, formatValidationError(path, "table", luaTypeName(val)) + case *typ.Instantiated: expanded := subst.ExpandInstantiated(tt) if expanded != nil && expanded != tt { diff --git a/ltype_test.go b/ltype_test.go index d9f914f1..78e29a29 100644 --- a/ltype_test.go +++ b/ltype_test.go @@ -1506,3 +1506,85 @@ func TestLTypeCallAsLastArg(t *testing.T) { t.Errorf("typeCall as non-last arg: %v", err) } } + +func TestLTypeInterfaceTable(t *testing.T) { + L := NewState() + defer L.Close() + + tableType := <ype{inner: typ.NewInterface("table", nil)} + + plainTable := L.NewTable() + plainTable.RawSetString("key", LString("value")) + + tableWithArray := L.NewTable() + tableWithArray.RawSetString("tags", L.NewTable()) + tableWithArray.RawGet(LString("tags")).(*LTable).Append(LString("a")) + tableWithArray.RawGet(LString("tags")).(*LTable).Append(LString("b")) + + tests := []struct { + name string + value LValue + expected bool + }{ + {"table accepts plain table", plainTable, true}, + {"table accepts table with nested arrays", tableWithArray, true}, + {"table accepts empty table", L.NewTable(), true}, + {"table rejects string", LString("hello"), false}, + {"table rejects number", LNumber(42), false}, + {"table rejects nil", LNil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tableType.Validate(L, tt.value) + if result != tt.expected { + t.Errorf("Validate() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestLTypeRecordWithTableField(t *testing.T) { + L := NewState() + defer L.Close() + + // {kind: string, data: table} + commandType := <ype{ + inner: typ.NewRecord(). + Field("kind", typ.String). + Field("data", typ.NewInterface("table", nil)). + Build(), + } + + valid := L.NewTable() + valid.RawSetString("kind", LString("create")) + data := L.NewTable() + data.RawSetString("name", LString("test")) + tags := L.NewTable() + tags.Append(LString("a")) + tags.Append(LString("b")) + data.RawSetString("tags", tags) + valid.RawSetString("data", data) + + invalid := L.NewTable() + invalid.RawSetString("kind", LString("create")) + invalid.RawSetString("data", LString("not a table")) + + tests := []struct { + name string + value LValue + expected bool + }{ + {"record with table field accepts nested arrays", valid, true}, + {"record with table field rejects non-table data", invalid, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := commandType.Validate(L, tt.value) + if result != tt.expected { + t.Errorf("Validate() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/pcall_yield_test.go b/pcall_yield_test.go index 61dbbafb..2edc9cea 100644 --- a/pcall_yield_test.go +++ b/pcall_yield_test.go @@ -969,3 +969,115 @@ func TestPcallErrorWithGoFunctionCallAfter(t *testing.T) { } }) } + +// TestPooledStateYieldedFlagReset_ResetLState verifies that resetLState +// (called during Close/pool return) clears the yielded flag. +func TestPooledStateYieldedFlagReset_ResetLState(t *testing.T) { + L := NewState() + L.yielded = true + resetLState(L) + if L.yielded { + t.Fatal("resetLState must clear yielded flag") + } +} + +// TestPooledStateYieldedFlagReset_NewLState verifies that newLState +// clears yielded on a state retrieved from the pool. +func TestPooledStateYieldedFlagReset_NewLState(t *testing.T) { + for statePool.Get() != nil { + } + + dirty := NewState() + dirty.yielded = true + statePool.Put(dirty) + + reused := NewState() + defer reused.Close() + if reused.yielded { + t.Fatal("newLState must reset yielded on pooled state") + } +} + +// TestPooledStateYieldedFlagReset_NewLStateWithGAndAlloc verifies that +// newLStateWithGAndAlloc clears yielded on a state retrieved from the pool. +func TestPooledStateYieldedFlagReset_NewLStateWithGAndAlloc(t *testing.T) { + for statePool.Get() != nil { + } + + // Put two dirty states: one for NewState (parent), one for NewThreadWithContext + d1 := NewState() + d1.yielded = true + d2 := NewState() + d2.yielded = true + statePool.Put(d1) + statePool.Put(d2) + + parent := NewState() + defer parent.Close() + thread := parent.NewThreadWithContext(context.TODO()) + if thread.yielded { + t.Fatal("newLStateWithGAndAlloc must reset yielded on pooled state") + } +} + +// TestPooledStateYieldedFlagReset_EndToEnd simulates the real-world scenario: +// a coroutine yields, its state is pooled, then reused for a new execution. +func TestPooledStateYieldedFlagReset_EndToEnd(t *testing.T) { + for statePool.Get() != nil { + } + + // Phase 1: Yield inside a coroutine, then close the parent to pool everything + func() { + L := NewState() + if err := L.DoString(` + function yielder() + coroutine.yield("y") + return "done" + end + `); err != nil { + t.Fatal(err) + } + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("yielder").(*LFunction) + state, _, err := L.Resume(co, fn) + if err != nil { + t.Fatalf("resume: %v", err) + } + if state != ResumeYield { + t.Fatalf("expected ResumeYield, got %v", state) + } + if !co.yielded { + t.Fatal("co.yielded must be true after yield") + } + L.Close() + }() + + // Phase 2: Reuse pooled states - everything must work correctly + L := NewState() + defer L.Close() + + if err := L.DoString(` + function compute() + local ok, val = pcall(function() + return "result" + end) + if not ok then error("pcall failed") end + return val + end + `); err != nil { + t.Fatal(err) + } + + co := L.NewThreadWithContext(context.TODO()) + fn := L.GetGlobal("compute").(*LFunction) + state, results, err := L.Resume(co, fn) + if err != nil { + t.Fatalf("resume: %v", err) + } + if state != ResumeOK { + t.Fatalf("expected ResumeOK, got %v", state) + } + if len(results) < 1 || results[0].String() != "result" { + t.Fatalf("expected 'result', got %v", results) + } +} diff --git a/state.go b/state.go index ffb068c1..2e0b3dae 100644 --- a/state.go +++ b/state.go @@ -455,6 +455,7 @@ func newLState(options Options) *LState { ls.stop = 0 ls.currentFrame = nil ls.wrapped = false + ls.yielded = false ls.uvcache = nil ls.hasErrorFunc = false ls.mainLoop = mainLoop diff --git a/state_pool.go b/state_pool.go index 5b03ead2..ba9e8768 100644 --- a/state_pool.go +++ b/state_pool.go @@ -43,6 +43,7 @@ func resetLState(ls *LState) { ls.G = nil ls.hasErrorFunc = false ls.wrapped = false + ls.yielded = false // Clear frame extensions to prevent stale continuations from being invoked ls.frameExt = nil @@ -79,6 +80,7 @@ func newLStateWithGAndAlloc(options Options, G *Global, env *LTable, parentAlloc ls.mainLoop = mainLoop ls.alloc = parentAlloc ls.stop = 0 + ls.yielded = false ls.ctx = nil ls.ctxDone = nil diff --git a/types/subtype/normalize.go b/types/subtype/normalize.go index 63af3192..829a8aa7 100644 --- a/types/subtype/normalize.go +++ b/types/subtype/normalize.go @@ -226,12 +226,14 @@ func flattenUnion(acc []typ.Type, types []typ.Type) []typ.Type { continue } - switch t.Kind() { + unwrapped := typ.UnwrapAnnotated(t) + + switch unwrapped.Kind() { case kind.Union: - acc = flattenUnion(acc, t.(*typ.Union).Members) + acc = flattenUnion(acc, unwrapped.(*typ.Union).Members) case kind.Optional: acc = append(acc, typ.Nil) - acc = flattenUnion(acc, []typ.Type{t.(*typ.Optional).Inner}) + acc = flattenUnion(acc, []typ.Type{unwrapped.(*typ.Optional).Inner}) default: acc = append(acc, t) } @@ -248,7 +250,7 @@ func flattenIntersection(acc []typ.Type, types []typ.Type) []typ.Type { continue } - if i, ok := t.(*typ.Intersection); ok { + if i, ok := typ.UnwrapAnnotated(t).(*typ.Intersection); ok { acc = flattenIntersection(acc, i.Members) } else { acc = append(acc, t) diff --git a/types/typ/intersection.go b/types/typ/intersection.go index bd3119d4..0fa37916 100644 --- a/types/typ/intersection.go +++ b/types/typ/intersection.go @@ -45,7 +45,9 @@ func NewIntersection(members ...Type) Type { continue } - switch m.Kind() { + unwrapped := UnwrapAnnotated(m) + + switch unwrapped.Kind() { case kind.Any: continue // Any is identity for intersection case kind.Never: @@ -53,7 +55,7 @@ func NewIntersection(members ...Type) Type { case kind.Nil: hasNil = true case kind.Intersection: - flat = append(flat, m.(*Intersection).Members...) + flat = append(flat, unwrapped.(*Intersection).Members...) default: flat = append(flat, m) } diff --git a/types/typ/optional.go b/types/typ/optional.go index 8d16636d..153eb450 100644 --- a/types/typ/optional.go +++ b/types/typ/optional.go @@ -38,7 +38,7 @@ func NewOptional(inner Type) Type { } if inner.Kind() == kind.Union { - u := inner.(*Union) + u := UnwrapAnnotated(inner).(*Union) members := make([]Type, 0, len(u.Members)+1) members = append(members, Nil) members = append(members, u.Members...) diff --git a/types/typ/optional_test.go b/types/typ/optional_test.go index 028b662c..674bd9f6 100644 --- a/types/typ/optional_test.go +++ b/types/typ/optional_test.go @@ -132,3 +132,14 @@ func TestOptionalOfUnionWithoutNil(t *testing.T) { t.Error("Normalized union should contain Nil") } } + +func TestOptionalAnnotatedUnion(t *testing.T) { + // NewOptional with Annotated wrapping Union should not panic + inner := NewUnion(String, Number) + annotated := NewAnnotated(inner, []Annotation{{Name: "max_len", Arg: int64(255)}}) + o := NewOptional(annotated) + + if o == nil { + t.Fatal("optional should not be nil") + } +} diff --git a/types/typ/union.go b/types/typ/union.go index 1e101c21..9c79f9a3 100644 --- a/types/typ/union.go +++ b/types/typ/union.go @@ -48,7 +48,13 @@ func NewUnion(members ...Type) Type { return } - switch m.Kind() { + // Unwrap Annotated to access structural type for flattening. + // Annotations delegate Kind() to their inner type, so type + // assertions on concrete wrappers (Union, Optional) require + // operating on the unwrapped type. + unwrapped := UnwrapAnnotated(m) + + switch unwrapped.Kind() { case kind.Never: return // Never is identity for union case kind.Unknown: @@ -59,12 +65,12 @@ func NewUnion(members ...Type) Type { case kind.Nil: hasNil = true case kind.Union: - for _, member := range m.(*Union).Members { + for _, member := range unwrapped.(*Union).Members { addMember(member) } case kind.Optional: hasNil = true - addMember(m.(*Optional).Inner) + addMember(unwrapped.(*Optional).Inner) default: flat = append(flat, m) } diff --git a/types/typ/union_test.go b/types/typ/union_test.go index 78525907..c0ab77b1 100644 --- a/types/typ/union_test.go +++ b/types/typ/union_test.go @@ -335,3 +335,24 @@ func TestUnionNestedOptionalAndUnionNilDedups(t *testing.T) { t.Fatalf("expected exactly one nil in union, got %d in %v", nilCount, u) } } + +func TestUnionAnnotatedOptionalMember(t *testing.T) { + // Annotated wrapping Optional should not panic during union construction + annotatedOpt := NewAnnotated(NewOptional(String), []Annotation{{Name: "min_len", Arg: int64(1)}}) + u := NewUnion(annotatedOpt, Number) + + if u == nil { + t.Fatal("union should not be nil") + } +} + +func TestUnionAnnotatedUnionMember(t *testing.T) { + // Annotated wrapping Union should not panic during union construction + inner := NewUnion(String, Number) + annotatedUnion := NewAnnotated(inner, []Annotation{{Name: "max_len", Arg: int64(255)}}) + u := NewUnion(annotatedUnion, Boolean) + + if u == nil { + t.Fatal("union should not be nil") + } +}