Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions lib/lua/compiler/codegen.ex
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,9 @@ defmodule Lua.Compiler.Codegen do
{[loop_instruction], ctx}
end

defp gen_statement(%Statement.ForNum{var: var, start: start_expr, limit: limit_expr, step: step_expr, body: body}, ctx) do
# Get the loop variable's register from scope
loop_var_reg = ctx.scope.locals[var]
defp gen_statement(%Statement.ForNum{start: start_expr, limit: limit_expr, step: step_expr, body: body} = node, ctx) do
# Get the loop variable's register from var_map (locals are restored after scope isolation)
loop_var_reg = Map.get(ctx.scope.var_map, {:for_var, node})

# Allocate 3 internal registers for: counter, limit, step
base = ctx.next_reg
Expand Down Expand Up @@ -510,9 +510,9 @@ defmodule Lua.Compiler.Codegen do
[loop_instruction], ctx}
end

defp gen_statement(%Statement.ForIn{vars: vars, iterators: iterators, body: body}, ctx) do
# Look up loop variable registers from scope
var_regs = Enum.map(vars, fn name -> ctx.scope.locals[name] end)
defp gen_statement(%Statement.ForIn{iterators: iterators, body: body} = node, ctx) do
# Look up loop variable registers from var_map (locals are restored after scope isolation)
var_regs = Map.get(ctx.scope.var_map, {:for_in_vars, node})

# Allocate 3 internal registers for: iterator function, invariant state, control variable
base = ctx.next_reg
Expand Down Expand Up @@ -611,7 +611,31 @@ defmodule Lua.Compiler.Codegen do

case name do
[single_name] ->
{closure_instructions ++ [Instruction.set_global(single_name, closure_reg)], ctx}
# Check if scope resolution determined this should assign to a local
case Map.get(ctx.scope.var_map, {:func_decl_target, decl}) do
{:local, local_reg, local_name} ->
# Assign to local variable
move_instructions =
if closure_reg == local_reg do
[]
else
[Instruction.move(local_reg, closure_reg)]
end

# If the local is captured, also update its upvalue cell
update_upvalue =
if MapSet.member?(ctx.scope.captured_locals, local_name) do
[Instruction.set_open_upvalue(local_reg, closure_reg)]
else
[]
end

{closure_instructions ++ move_instructions ++ update_upvalue, ctx}

nil ->
# No local in scope at this point, assign to global
{closure_instructions ++ [Instruction.set_global(single_name, closure_reg)], ctx}
end

[first | rest] ->
# Dotted name: get the table chain, then set the final field
Expand Down
73 changes: 56 additions & 17 deletions lib/lua/compiler/scope.ex
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,25 @@ defmodule Lua.Compiler.Scope do
# Resolve the main condition
state = resolve_expr(condition, state)

# Resolve the then block
# Resolve the then block (save/restore locals so block-local vars don't leak)
saved_locals = state.locals
state = resolve_block(then_block, state)
state = %{state | locals: saved_locals}

# Resolve all elseif clauses
state =
Enum.reduce(elseifs, state, fn {elseif_cond, elseif_block}, state ->
state = resolve_expr(elseif_cond, state)
resolve_block(elseif_block, state)
saved = state.locals
state = resolve_block(elseif_block, state)
%{state | locals: saved}
end)

# Resolve the else block if present
if else_block do
resolve_block(else_block, state)
saved = state.locals
state = resolve_block(else_block, state)
%{state | locals: saved}
else
state
end
Expand All @@ -159,19 +165,24 @@ defmodule Lua.Compiler.Scope do
defp resolve_statement(%Statement.While{condition: condition, body: body}, state) do
# Resolve the condition
state = resolve_expr(condition, state)
# Resolve the body
resolve_block(body, state)
# Resolve the body (save/restore locals so body-local vars don't leak)
saved_locals = state.locals
state = resolve_block(body, state)
%{state | locals: saved_locals}
end

defp resolve_statement(%Statement.Repeat{body: body, condition: condition}, state) do
# Resolve the body first (in Lua, the condition can reference variables declared in the body)
saved_locals = state.locals
state = resolve_block(body, state)
# Resolve the condition
resolve_expr(condition, state)
# Resolve the condition (can see body locals per Lua 5.3 spec)
state = resolve_expr(condition, state)
# Restore locals after repeat block
%{state | locals: saved_locals}
end

defp resolve_statement(
%Statement.ForNum{var: var, start: start_expr, limit: limit_expr, step: step_expr, body: body},
%Statement.ForNum{var: var, start: start_expr, limit: limit_expr, step: step_expr, body: body} = node,
state
) do
# Resolve start, limit, and step expressions with current scope
Expand All @@ -187,20 +198,40 @@ defmodule Lua.Compiler.Scope do
# plus limit and step registers (codegen allocates base, base+1, base+2)
state = %{state | next_register: loop_var_reg + 3}

# Store in var_map so codegen can find it after locals are restored
state = %{state | var_map: Map.put(state.var_map, {:for_var, node}, loop_var_reg)}

# Update max_register
func_scope = state.functions[state.current_function]
func_scope = %{func_scope | max_register: max(func_scope.max_register, state.next_register)}
state = %{state | functions: Map.put(state.functions, state.current_function, func_scope)}

# Resolve the body with the loop variable in scope
# Resolve the body with the loop variable in scope (save/restore locals)
saved_locals = state.locals
state = resolve_block(body, state)

# Remove the loop variable from scope after the loop
# (In real implementation, we'd need scope stack management, but for now this is fine)
state
%{state | locals: saved_locals}
end

defp resolve_statement(%Statement.FuncDecl{params: params, body: body, is_method: is_method} = decl, state) do
defp resolve_statement(%Statement.FuncDecl{name: name, params: params, body: body, is_method: is_method} = decl, state) do
# Record if this FuncDecl should assign to a local (if a local with this name exists at this point)
state =
case name do
[single_name] ->
case Map.get(state.locals, single_name) do
nil ->
# No local in scope, will assign to global
state

local_reg ->
# Local in scope, record it so codegen assigns to the local
%{state | var_map: Map.put(state.var_map, {:func_decl_target, decl}, {:local, local_reg, single_name})}
end

_ ->
# Dotted name, always uses table assignment
state
end

all_params = if is_method, do: ["self" | params], else: params
resolve_function_scope(decl, all_params, body, state)
end
Expand All @@ -209,25 +240,33 @@ defmodule Lua.Compiler.Scope do
resolve_expr(call, state)
end

defp resolve_statement(%Statement.ForIn{vars: vars, iterators: iterators, body: body}, state) do
defp resolve_statement(%Statement.ForIn{vars: vars, iterators: iterators, body: body} = node, state) do
# Resolve iterator expressions with current scope
state = Enum.reduce(iterators, state, &resolve_expr/2)

# Assign registers for loop variables (same pattern as ForNum)
first_reg = state.next_register

{state, _} =
Enum.reduce(vars, {state, state.next_register}, fn name, {state, reg} ->
state = %{state | locals: Map.put(state.locals, name, reg)}
state = %{state | next_register: reg + 1}
{state, reg + 1}
end)

# Store in var_map so codegen can find registers after locals are restored
var_regs = Enum.with_index(vars, fn _name, i -> first_reg + i end)
state = %{state | var_map: Map.put(state.var_map, {:for_in_vars, node}, var_regs)}

# Update max_register
func_scope = state.functions[state.current_function]
func_scope = %{func_scope | max_register: max(func_scope.max_register, state.next_register)}
state = %{state | functions: Map.put(state.functions, state.current_function, func_scope)}

# Resolve the body with loop variables in scope
resolve_block(body, state)
# Resolve the body with loop variables in scope (save/restore locals)
saved_locals = state.locals
state = resolve_block(body, state)
%{state | locals: saved_locals}
end

defp resolve_statement(%Statement.LocalFunc{name: name, params: params, body: body} = local_func, state) do
Expand Down
42 changes: 26 additions & 16 deletions lib/lua/vm/executor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ defmodule Lua.VM.Executor do

{result, new_state} =
try_binary_metamethod("__band", val_a, val_b, state, fn ->
Bitwise.band(to_integer!(val_a), to_integer!(val_b))
val_a |> to_integer!() |> Bitwise.band(to_integer!(val_b)) |> to_signed_int64()
end)

regs = put_elem(regs, dest, result)
Expand All @@ -809,7 +809,7 @@ defmodule Lua.VM.Executor do

{result, new_state} =
try_binary_metamethod("__bor", val_a, val_b, state, fn ->
Bitwise.bor(to_integer!(val_a), to_integer!(val_b))
val_a |> to_integer!() |> Bitwise.bor(to_integer!(val_b)) |> to_signed_int64()
end)

regs = put_elem(regs, dest, result)
Expand All @@ -822,7 +822,7 @@ defmodule Lua.VM.Executor do

{result, new_state} =
try_binary_metamethod("__bxor", val_a, val_b, state, fn ->
Bitwise.bxor(to_integer!(val_a), to_integer!(val_b))
val_a |> to_integer!() |> Bitwise.bxor(to_integer!(val_b)) |> to_signed_int64()
end)

regs = put_elem(regs, dest, result)
Expand Down Expand Up @@ -860,7 +860,7 @@ defmodule Lua.VM.Executor do

{result, new_state} =
try_unary_metamethod("__bnot", val, state, fn ->
Bitwise.bnot(to_integer!(val))
val |> to_integer!() |> Bitwise.bnot() |> to_signed_int64()
end)

regs = put_elem(regs, dest, result)
Expand Down Expand Up @@ -1505,7 +1505,8 @@ defmodule Lua.VM.Executor do
defp safe_add(a, b) do
with {:ok, na} <- to_number(a),
{:ok, nb} <- to_number(b) do
na + nb
result = na + nb
if is_integer(result), do: to_signed_int64(result), else: result
else
{:error, val} ->
raise TypeError,
Expand All @@ -1518,7 +1519,8 @@ defmodule Lua.VM.Executor do
defp safe_subtract(a, b) do
with {:ok, na} <- to_number(a),
{:ok, nb} <- to_number(b) do
na - nb
result = na - nb
if is_integer(result), do: to_signed_int64(result), else: result
else
{:error, val} ->
raise TypeError,
Expand All @@ -1531,7 +1533,8 @@ defmodule Lua.VM.Executor do
defp safe_multiply(a, b) do
with {:ok, na} <- to_number(a),
{:ok, nb} <- to_number(b) do
na * nb
result = na * nb
if is_integer(result), do: to_signed_int64(result), else: result
else
{:error, val} ->
raise TypeError,
Expand All @@ -1544,9 +1547,6 @@ defmodule Lua.VM.Executor do
defp safe_divide(a, b) do
with {:ok, na} <- to_number(a),
{:ok, nb} <- to_number(b) do
# Check for division by zero
# Note: Standard Lua 5.3 returns inf/-inf/nan for float division by zero,
# but Elixir doesn't support creating these values easily, so we raise an error
if nb == 0 or nb == 0.0 do
raise RuntimeError, value: "attempt to divide by zero"
else
Expand Down Expand Up @@ -1594,7 +1594,7 @@ defmodule Lua.VM.Executor do

is_integer(na) and is_integer(nb) ->
# Lua floor modulo for integers: a - floor_div(a, b) * b
na - lua_idiv(na, nb) * nb
to_signed_int64(na - lua_idiv(na, nb) * nb)

true ->
# Float floor modulo: a - floor(a/b) * b
Expand All @@ -1614,7 +1614,8 @@ defmodule Lua.VM.Executor do
q = div(a, b)
r = rem(a, b)
# Adjust if remainder has different sign than divisor
if r != 0 and Bitwise.bxor(r, b) < 0, do: q - 1, else: q
result = if r != 0 and Bitwise.bxor(r, b) < 0, do: q - 1, else: q
to_signed_int64(result)
end

defp safe_power(a, b) do
Expand All @@ -1633,7 +1634,8 @@ defmodule Lua.VM.Executor do
defp safe_negate(a) do
case to_number(a) do
{:ok, na} ->
-na
result = -na
if is_integer(result), do: to_signed_int64(result), else: result

{:error, val} ->
raise TypeError,
Expand Down Expand Up @@ -1750,17 +1752,25 @@ defmodule Lua.VM.Executor do
defp lua_shift_left(val, shift) when shift < 0, do: lua_shift_right(val, -shift)

defp lua_shift_left(val, shift) do
Bitwise.band(Bitwise.bsl(val, shift), 0xFFFFFFFFFFFFFFFF)
val |> Bitwise.bsl(shift) |> to_signed_int64()
end

defp lua_shift_right(_val, shift) when shift >= 64, do: 0
defp lua_shift_right(_val, shift) when shift <= -64, do: 0
defp lua_shift_right(val, shift) when shift < 0, do: lua_shift_left(val, -shift)

defp lua_shift_right(val, shift) do
# Unsigned right shift - mask to 64-bit first
# Unsigned right shift - mask to 64-bit unsigned first
unsigned_val = Bitwise.band(val, 0xFFFFFFFFFFFFFFFF)
Bitwise.bsr(unsigned_val, shift)
unsigned_val |> Bitwise.bsr(shift) |> to_signed_int64()
end

# Wrap an arbitrary-precision integer to a signed 64-bit integer
@int64_max 0x7FFFFFFFFFFFFFFF
@int64_mod 0x10000000000000000
defp to_signed_int64(val) do
masked = Bitwise.band(val, 0xFFFFFFFFFFFFFFFF)
if masked > @int64_max, do: masked - @int64_mod, else: masked
end

# Helper to determine Lua type from Elixir value
Expand Down
Loading