Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions compiler/check/synth/ops/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,19 @@ func CheckTable(fields []FieldDef, arrayElems []typ.Type, expected typ.Type) Che
return CheckResult{Type: expected, Errors: errors}
}

switch expected.Kind() {
unwrapped := typ.UnwrapAnnotated(expected)

switch unwrapped.Kind() {
case kind.Array:
return checkTableAsArray(fields, arrayElems, expected.(*typ.Array))
return checkTableAsArray(fields, arrayElems, unwrapped.(*typ.Array))
case kind.Map:
return checkTableAsMap(fields, arrayElems, expected.(*typ.Map))
return checkTableAsMap(fields, arrayElems, unwrapped.(*typ.Map))
case kind.Record:
return checkTableAsRecord(fields, arrayElems, expected.(*typ.Record))
return checkTableAsRecord(fields, arrayElems, unwrapped.(*typ.Record))
case kind.Tuple:
return checkTableAsTuple(arrayElems, expected.(*typ.Tuple))
return checkTableAsTuple(arrayElems, unwrapped.(*typ.Tuple))
case kind.Union:
return checkTableAsUnion(fields, arrayElems, expected.(*typ.Union))
return checkTableAsUnion(fields, arrayElems, unwrapped.(*typ.Union))
default:
// Try synthesis and check compatibility
synthesized := tableConstructor(fields, arrayElems)
Expand Down
10 changes: 10 additions & 0 deletions ltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,10 @@ func validateValue(val LValue, t typ.Type, resolver *typeResolver) bool {
}
return false

case *typ.Interface:
_, ok := val.(*LTable)
return ok

case *typ.Generic:
// Generic types need to be instantiated before validation
return false
Expand Down Expand Up @@ -536,6 +540,12 @@ func validateWithErrorResolver(val LValue, t typ.Type, resolver *typeResolver, p
}
return false, formatValidationError(path, typeName, luaTypeName(val))

case *typ.Interface:
if _, ok := val.(*LTable); ok {
return true, ""
}
return false, formatValidationError(path, "table", luaTypeName(val))

case *typ.Instantiated:
expanded := subst.ExpandInstantiated(tt)
if expanded != nil && expanded != tt {
Expand Down
82 changes: 82 additions & 0 deletions ltype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1506,3 +1506,85 @@ func TestLTypeCallAsLastArg(t *testing.T) {
t.Errorf("typeCall as non-last arg: %v", err)
}
}

func TestLTypeInterfaceTable(t *testing.T) {
L := NewState()
defer L.Close()

tableType := &LType{inner: typ.NewInterface("table", nil)}

plainTable := L.NewTable()
plainTable.RawSetString("key", LString("value"))

tableWithArray := L.NewTable()
tableWithArray.RawSetString("tags", L.NewTable())
tableWithArray.RawGet(LString("tags")).(*LTable).Append(LString("a"))
tableWithArray.RawGet(LString("tags")).(*LTable).Append(LString("b"))

tests := []struct {
name string
value LValue
expected bool
}{
{"table accepts plain table", plainTable, true},
{"table accepts table with nested arrays", tableWithArray, true},
{"table accepts empty table", L.NewTable(), true},
{"table rejects string", LString("hello"), false},
{"table rejects number", LNumber(42), false},
{"table rejects nil", LNil, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tableType.Validate(L, tt.value)
if result != tt.expected {
t.Errorf("Validate() = %v, want %v", result, tt.expected)
}
})
}
}

func TestLTypeRecordWithTableField(t *testing.T) {
L := NewState()
defer L.Close()

// {kind: string, data: table}
commandType := &LType{
inner: typ.NewRecord().
Field("kind", typ.String).
Field("data", typ.NewInterface("table", nil)).
Build(),
}

valid := L.NewTable()
valid.RawSetString("kind", LString("create"))
data := L.NewTable()
data.RawSetString("name", LString("test"))
tags := L.NewTable()
tags.Append(LString("a"))
tags.Append(LString("b"))
data.RawSetString("tags", tags)
valid.RawSetString("data", data)

invalid := L.NewTable()
invalid.RawSetString("kind", LString("create"))
invalid.RawSetString("data", LString("not a table"))

tests := []struct {
name string
value LValue
expected bool
}{
{"record with table field accepts nested arrays", valid, true},
{"record with table field rejects non-table data", invalid, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := commandType.Validate(L, tt.value)
if result != tt.expected {
t.Errorf("Validate() = %v, want %v", result, tt.expected)
}
})
}
}
112 changes: 112 additions & 0 deletions pcall_yield_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -969,3 +969,115 @@ func TestPcallErrorWithGoFunctionCallAfter(t *testing.T) {
}
})
}

// TestPooledStateYieldedFlagReset_ResetLState verifies that resetLState
// (called during Close/pool return) clears the yielded flag.
func TestPooledStateYieldedFlagReset_ResetLState(t *testing.T) {
L := NewState()
L.yielded = true
resetLState(L)
if L.yielded {
t.Fatal("resetLState must clear yielded flag")
}
}

// TestPooledStateYieldedFlagReset_NewLState verifies that newLState
// clears yielded on a state retrieved from the pool.
func TestPooledStateYieldedFlagReset_NewLState(t *testing.T) {
for statePool.Get() != nil {
}

dirty := NewState()
dirty.yielded = true
statePool.Put(dirty)

reused := NewState()
defer reused.Close()
if reused.yielded {
t.Fatal("newLState must reset yielded on pooled state")
}
}

// TestPooledStateYieldedFlagReset_NewLStateWithGAndAlloc verifies that
// newLStateWithGAndAlloc clears yielded on a state retrieved from the pool.
func TestPooledStateYieldedFlagReset_NewLStateWithGAndAlloc(t *testing.T) {
for statePool.Get() != nil {
}

// Put two dirty states: one for NewState (parent), one for NewThreadWithContext
d1 := NewState()
d1.yielded = true
d2 := NewState()
d2.yielded = true
statePool.Put(d1)
statePool.Put(d2)

parent := NewState()
defer parent.Close()
thread := parent.NewThreadWithContext(context.TODO())
if thread.yielded {
t.Fatal("newLStateWithGAndAlloc must reset yielded on pooled state")
}
}

// TestPooledStateYieldedFlagReset_EndToEnd simulates the real-world scenario:
// a coroutine yields, its state is pooled, then reused for a new execution.
func TestPooledStateYieldedFlagReset_EndToEnd(t *testing.T) {
for statePool.Get() != nil {
}

// Phase 1: Yield inside a coroutine, then close the parent to pool everything
func() {
L := NewState()
if err := L.DoString(`
function yielder()
coroutine.yield("y")
return "done"
end
`); err != nil {
t.Fatal(err)
}
co := L.NewThreadWithContext(context.TODO())
fn := L.GetGlobal("yielder").(*LFunction)
state, _, err := L.Resume(co, fn)
if err != nil {
t.Fatalf("resume: %v", err)
}
if state != ResumeYield {
t.Fatalf("expected ResumeYield, got %v", state)
}
if !co.yielded {
t.Fatal("co.yielded must be true after yield")
}
L.Close()
}()

// Phase 2: Reuse pooled states - everything must work correctly
L := NewState()
defer L.Close()

if err := L.DoString(`
function compute()
local ok, val = pcall(function()
return "result"
end)
if not ok then error("pcall failed") end
return val
end
`); err != nil {
t.Fatal(err)
}

co := L.NewThreadWithContext(context.TODO())
fn := L.GetGlobal("compute").(*LFunction)
state, results, err := L.Resume(co, fn)
if err != nil {
t.Fatalf("resume: %v", err)
}
if state != ResumeOK {
t.Fatalf("expected ResumeOK, got %v", state)
}
if len(results) < 1 || results[0].String() != "result" {
t.Fatalf("expected 'result', got %v", results)
}
}
1 change: 1 addition & 0 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ func newLState(options Options) *LState {
ls.stop = 0
ls.currentFrame = nil
ls.wrapped = false
ls.yielded = false
ls.uvcache = nil
ls.hasErrorFunc = false
ls.mainLoop = mainLoop
Expand Down
2 changes: 2 additions & 0 deletions state_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func resetLState(ls *LState) {
ls.G = nil
ls.hasErrorFunc = false
ls.wrapped = false
ls.yielded = false

// Clear frame extensions to prevent stale continuations from being invoked
ls.frameExt = nil
Expand Down Expand Up @@ -79,6 +80,7 @@ func newLStateWithGAndAlloc(options Options, G *Global, env *LTable, parentAlloc
ls.mainLoop = mainLoop
ls.alloc = parentAlloc
ls.stop = 0
ls.yielded = false
ls.ctx = nil
ls.ctxDone = nil

Expand Down
10 changes: 6 additions & 4 deletions types/subtype/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,14 @@ func flattenUnion(acc []typ.Type, types []typ.Type) []typ.Type {
continue
}

switch t.Kind() {
unwrapped := typ.UnwrapAnnotated(t)

switch unwrapped.Kind() {
case kind.Union:
acc = flattenUnion(acc, t.(*typ.Union).Members)
acc = flattenUnion(acc, unwrapped.(*typ.Union).Members)
case kind.Optional:
acc = append(acc, typ.Nil)
acc = flattenUnion(acc, []typ.Type{t.(*typ.Optional).Inner})
acc = flattenUnion(acc, []typ.Type{unwrapped.(*typ.Optional).Inner})
default:
acc = append(acc, t)
}
Expand All @@ -248,7 +250,7 @@ func flattenIntersection(acc []typ.Type, types []typ.Type) []typ.Type {
continue
}

if i, ok := t.(*typ.Intersection); ok {
if i, ok := typ.UnwrapAnnotated(t).(*typ.Intersection); ok {
acc = flattenIntersection(acc, i.Members)
} else {
acc = append(acc, t)
Expand Down
6 changes: 4 additions & 2 deletions types/typ/intersection.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ func NewIntersection(members ...Type) Type {
continue
}

switch m.Kind() {
unwrapped := UnwrapAnnotated(m)

switch unwrapped.Kind() {
case kind.Any:
continue // Any is identity for intersection
case kind.Never:
hasNever = true
case kind.Nil:
hasNil = true
case kind.Intersection:
flat = append(flat, m.(*Intersection).Members...)
flat = append(flat, unwrapped.(*Intersection).Members...)
default:
flat = append(flat, m)
}
Expand Down
2 changes: 1 addition & 1 deletion types/typ/optional.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func NewOptional(inner Type) Type {
}

if inner.Kind() == kind.Union {
u := inner.(*Union)
u := UnwrapAnnotated(inner).(*Union)
members := make([]Type, 0, len(u.Members)+1)
members = append(members, Nil)
members = append(members, u.Members...)
Expand Down
11 changes: 11 additions & 0 deletions types/typ/optional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,14 @@ func TestOptionalOfUnionWithoutNil(t *testing.T) {
t.Error("Normalized union should contain Nil")
}
}

func TestOptionalAnnotatedUnion(t *testing.T) {
// NewOptional with Annotated wrapping Union should not panic
inner := NewUnion(String, Number)
annotated := NewAnnotated(inner, []Annotation{{Name: "max_len", Arg: int64(255)}})
o := NewOptional(annotated)

if o == nil {
t.Fatal("optional should not be nil")
}
}
12 changes: 9 additions & 3 deletions types/typ/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ func NewUnion(members ...Type) Type {
return
}

switch m.Kind() {
// Unwrap Annotated to access structural type for flattening.
// Annotations delegate Kind() to their inner type, so type
// assertions on concrete wrappers (Union, Optional) require
// operating on the unwrapped type.
unwrapped := UnwrapAnnotated(m)

switch unwrapped.Kind() {
case kind.Never:
return // Never is identity for union
case kind.Unknown:
Expand All @@ -59,12 +65,12 @@ func NewUnion(members ...Type) Type {
case kind.Nil:
hasNil = true
case kind.Union:
for _, member := range m.(*Union).Members {
for _, member := range unwrapped.(*Union).Members {
addMember(member)
}
case kind.Optional:
hasNil = true
addMember(m.(*Optional).Inner)
addMember(unwrapped.(*Optional).Inner)
default:
flat = append(flat, m)
}
Expand Down
Loading