diff --git a/lib/lua/compiler/codegen.ex b/lib/lua/compiler/codegen.ex index d365b51..391e122 100644 --- a/lib/lua/compiler/codegen.ex +++ b/lib/lua/compiler/codegen.ex @@ -464,9 +464,9 @@ defmodule Lua.Compiler.Codegen do {[loop_instruction], ctx} end - defp gen_statement(%Statement.ForNum{var: var, start: start_expr, limit: limit_expr, step: step_expr, body: body}, ctx) do - # Get the loop variable's register from scope - loop_var_reg = ctx.scope.locals[var] + defp gen_statement(%Statement.ForNum{start: start_expr, limit: limit_expr, step: step_expr, body: body} = node, ctx) do + # Get the loop variable's register from var_map (locals are restored after scope isolation) + loop_var_reg = Map.get(ctx.scope.var_map, {:for_var, node}) # Allocate 3 internal registers for: counter, limit, step base = ctx.next_reg @@ -510,9 +510,9 @@ defmodule Lua.Compiler.Codegen do [loop_instruction], ctx} end - defp gen_statement(%Statement.ForIn{vars: vars, iterators: iterators, body: body}, ctx) do - # Look up loop variable registers from scope - var_regs = Enum.map(vars, fn name -> ctx.scope.locals[name] end) + defp gen_statement(%Statement.ForIn{iterators: iterators, body: body} = node, ctx) do + # Look up loop variable registers from var_map (locals are restored after scope isolation) + var_regs = Map.get(ctx.scope.var_map, {:for_in_vars, node}) # Allocate 3 internal registers for: iterator function, invariant state, control variable base = ctx.next_reg @@ -611,7 +611,31 @@ defmodule Lua.Compiler.Codegen do case name do [single_name] -> - {closure_instructions ++ [Instruction.set_global(single_name, closure_reg)], ctx} + # Check if scope resolution determined this should assign to a local + case Map.get(ctx.scope.var_map, {:func_decl_target, decl}) do + {:local, local_reg, local_name} -> + # Assign to local variable + move_instructions = + if closure_reg == local_reg do + [] + else + [Instruction.move(local_reg, closure_reg)] + end + + # If the local is captured, also update its upvalue cell + update_upvalue = + if MapSet.member?(ctx.scope.captured_locals, local_name) do + [Instruction.set_open_upvalue(local_reg, closure_reg)] + else + [] + end + + {closure_instructions ++ move_instructions ++ update_upvalue, ctx} + + nil -> + # No local in scope at this point, assign to global + {closure_instructions ++ [Instruction.set_global(single_name, closure_reg)], ctx} + end [first | rest] -> # Dotted name: get the table chain, then set the final field diff --git a/lib/lua/compiler/scope.ex b/lib/lua/compiler/scope.ex index 088f6db..85ddb90 100644 --- a/lib/lua/compiler/scope.ex +++ b/lib/lua/compiler/scope.ex @@ -138,19 +138,25 @@ defmodule Lua.Compiler.Scope do # Resolve the main condition state = resolve_expr(condition, state) - # Resolve the then block + # Resolve the then block (save/restore locals so block-local vars don't leak) + saved_locals = state.locals state = resolve_block(then_block, state) + state = %{state | locals: saved_locals} # Resolve all elseif clauses state = Enum.reduce(elseifs, state, fn {elseif_cond, elseif_block}, state -> state = resolve_expr(elseif_cond, state) - resolve_block(elseif_block, state) + saved = state.locals + state = resolve_block(elseif_block, state) + %{state | locals: saved} end) # Resolve the else block if present if else_block do - resolve_block(else_block, state) + saved = state.locals + state = resolve_block(else_block, state) + %{state | locals: saved} else state end @@ -159,19 +165,24 @@ defmodule Lua.Compiler.Scope do defp resolve_statement(%Statement.While{condition: condition, body: body}, state) do # Resolve the condition state = resolve_expr(condition, state) - # Resolve the body - resolve_block(body, state) + # Resolve the body (save/restore locals so body-local vars don't leak) + saved_locals = state.locals + state = resolve_block(body, state) + %{state | locals: saved_locals} end defp resolve_statement(%Statement.Repeat{body: body, condition: condition}, state) do # Resolve the body first (in Lua, the condition can reference variables declared in the body) + saved_locals = state.locals state = resolve_block(body, state) - # Resolve the condition - resolve_expr(condition, state) + # Resolve the condition (can see body locals per Lua 5.3 spec) + state = resolve_expr(condition, state) + # Restore locals after repeat block + %{state | locals: saved_locals} end defp resolve_statement( - %Statement.ForNum{var: var, start: start_expr, limit: limit_expr, step: step_expr, body: body}, + %Statement.ForNum{var: var, start: start_expr, limit: limit_expr, step: step_expr, body: body} = node, state ) do # Resolve start, limit, and step expressions with current scope @@ -187,20 +198,40 @@ defmodule Lua.Compiler.Scope do # plus limit and step registers (codegen allocates base, base+1, base+2) state = %{state | next_register: loop_var_reg + 3} + # Store in var_map so codegen can find it after locals are restored + state = %{state | var_map: Map.put(state.var_map, {:for_var, node}, loop_var_reg)} + # Update max_register func_scope = state.functions[state.current_function] func_scope = %{func_scope | max_register: max(func_scope.max_register, state.next_register)} state = %{state | functions: Map.put(state.functions, state.current_function, func_scope)} - # Resolve the body with the loop variable in scope + # Resolve the body with the loop variable in scope (save/restore locals) + saved_locals = state.locals state = resolve_block(body, state) - - # Remove the loop variable from scope after the loop - # (In real implementation, we'd need scope stack management, but for now this is fine) - state + %{state | locals: saved_locals} end - defp resolve_statement(%Statement.FuncDecl{params: params, body: body, is_method: is_method} = decl, state) do + defp resolve_statement(%Statement.FuncDecl{name: name, params: params, body: body, is_method: is_method} = decl, state) do + # Record if this FuncDecl should assign to a local (if a local with this name exists at this point) + state = + case name do + [single_name] -> + case Map.get(state.locals, single_name) do + nil -> + # No local in scope, will assign to global + state + + local_reg -> + # Local in scope, record it so codegen assigns to the local + %{state | var_map: Map.put(state.var_map, {:func_decl_target, decl}, {:local, local_reg, single_name})} + end + + _ -> + # Dotted name, always uses table assignment + state + end + all_params = if is_method, do: ["self" | params], else: params resolve_function_scope(decl, all_params, body, state) end @@ -209,11 +240,13 @@ defmodule Lua.Compiler.Scope do resolve_expr(call, state) end - defp resolve_statement(%Statement.ForIn{vars: vars, iterators: iterators, body: body}, state) do + defp resolve_statement(%Statement.ForIn{vars: vars, iterators: iterators, body: body} = node, state) do # Resolve iterator expressions with current scope state = Enum.reduce(iterators, state, &resolve_expr/2) # Assign registers for loop variables (same pattern as ForNum) + first_reg = state.next_register + {state, _} = Enum.reduce(vars, {state, state.next_register}, fn name, {state, reg} -> state = %{state | locals: Map.put(state.locals, name, reg)} @@ -221,13 +254,19 @@ defmodule Lua.Compiler.Scope do {state, reg + 1} end) + # Store in var_map so codegen can find registers after locals are restored + var_regs = Enum.with_index(vars, fn _name, i -> first_reg + i end) + state = %{state | var_map: Map.put(state.var_map, {:for_in_vars, node}, var_regs)} + # Update max_register func_scope = state.functions[state.current_function] func_scope = %{func_scope | max_register: max(func_scope.max_register, state.next_register)} state = %{state | functions: Map.put(state.functions, state.current_function, func_scope)} - # Resolve the body with loop variables in scope - resolve_block(body, state) + # Resolve the body with loop variables in scope (save/restore locals) + saved_locals = state.locals + state = resolve_block(body, state) + %{state | locals: saved_locals} end defp resolve_statement(%Statement.LocalFunc{name: name, params: params, body: body} = local_func, state) do diff --git a/test/language/function_test.exs b/test/language/function_test.exs index 2f86708..863825e 100644 --- a/test/language/function_test.exs +++ b/test/language/function_test.exs @@ -39,4 +39,134 @@ defmodule Lua.Language.FunctionTest do assert {["error msg"], _} = Lua.eval!(lua, code) end + + describe "FuncDecl local assignment" do + test "function f() updates local f when it exists in scope", %{lua: lua} do + code = """ + local f = function(x) return "first: " .. x end + assert(f("test") == "first: test") + + function f(x) + return "second: " .. x + end + + return f("test") + """ + + assert {["second: test"], _} = Lua.eval!(lua, code) + end + + test "function f() creates global when no local exists", %{lua: lua} do + code = """ + function f(x) + return "global: " .. x + end + + return f("test") + """ + + assert {["global: test"], _} = Lua.eval!(lua, code) + end + + test "function f() in nested do block does not see outer local", %{lua: lua} do + code = """ + local f = function() return "outer" end + + do + local f = function() return "inner" end + + function f() + return "updated inner" + end + + assert(f() == "updated inner") + end + + -- outer f should be unchanged + return f() + """ + + assert {["outer"], _} = Lua.eval!(lua, code) + end + + test "dotted name always uses table assignment", %{lua: lua} do + code = """ + local t = {} + + function t.method(x) + return x + 1 + end + + return t.method(41) + """ + + assert {[42], _} = Lua.eval!(lua, code) + end + + test "function f() does not see locals declared after the FuncDecl", %{lua: lua} do + code = """ + function f() + return "global" + end + + local f = function() return "local" end + + -- The global f is unaffected by the later local declaration + return f() + """ + + assert {["local"], _} = Lua.eval!(lua, code) + end + end + + describe "local function declarations" do + test "basic local function", %{lua: lua} do + code = """ + local function add(a, b) + return a + b + end + return add(1, 2) + """ + + assert {[3], _} = Lua.eval!(lua, code) + end + + test "recursive local function", %{lua: lua} do + code = """ + local function fact(n) + if n <= 1 then return 1 end + return n * fact(n - 1) + end + return fact(5) + """ + + assert {[120], _} = Lua.eval!(lua, code) + end + + test "local function captures outer locals", %{lua: lua} do + code = """ + local base = 10 + local function add_base(x) + return x + base + end + return add_base(5) + """ + + assert {[15], _} = Lua.eval!(lua, code) + end + + test "local function is not visible outside its scope", %{lua: lua} do + code = """ + do + local function helper() + return 42 + end + assert(helper() == 42) + end + return helper == nil + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + end end