diff --git a/AGENTS.md b/AGENTS.md index 9ed54c3..f378225 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -115,6 +115,7 @@ mix usage_rules.search_docs "Enum.zip" --query-by title ## Debugging - Use `dbg/1` to print values while debugging. This will display the formatted value and other relevant information in the console. +- Write ExUnit tests when trying to replicate bugs or issues. DO NOT use mix run -e unless necessary, always try to write a unit test replicating failures first. diff --git a/lib/lua/compiler/codegen.ex b/lib/lua/compiler/codegen.ex index 141ecb0..d365b51 100644 --- a/lib/lua/compiler/codegen.ex +++ b/lib/lua/compiler/codegen.ex @@ -36,8 +36,8 @@ defmodule Lua.Compiler.Codegen do prototypes: Enum.reverse(ctx.prototypes), upvalue_descriptors: [], param_count: 0, - is_vararg: false, - max_registers: func_scope.max_register, + is_vararg: func_scope.is_vararg, + max_registers: Enum.max([func_scope.max_register, ctx.next_reg, Map.get(ctx, :peak_reg, 0)]), source: source, lines: lines } @@ -47,9 +47,17 @@ defmodule Lua.Compiler.Codegen do defp gen_block(%Block{stmts: stmts}, ctx) do Enum.reduce(stmts, {[], ctx}, fn stmt, {instructions, ctx} -> + # Save next_reg before each statement so temp registers are recycled. + # Scope-assigned locals use fixed registers below this base; codegen temps + # are allocated above it and don't need to persist across statements. + saved_next_reg = ctx.next_reg # Emit source_line before each statement line_instr = emit_source_line(stmt, ctx) {new_instructions, ctx} = gen_statement(stmt, ctx) + # Track peak for max_registers, then reset for next statement + peak = max(Map.get(ctx, :peak_reg, 0), ctx.next_reg) + ctx = %{ctx | next_reg: saved_next_reg} + ctx = Map.put(ctx, :peak_reg, peak) {instructions ++ line_instr ++ new_instructions, ctx} end) end @@ -86,21 +94,14 @@ defmodule Lua.Compiler.Codegen do defp gen_statement(%Statement.Return{values: [%Expr.Call{} = call]}, ctx) do # return f(...) — forward all results from the call {call_instructions, _result_reg, ctx} = gen_expr(call, ctx) + call_instructions = patch_call_result_count(call_instructions, -1) + {call_instructions, ctx} + end - # Patch the call to request all results (sentinel -1) - call_instructions = - case List.last(call_instructions) do - {:call, base, arg_count, _result_count} -> - List.replace_at( - call_instructions, - length(call_instructions) - 1, - {:call, base, arg_count, -1} - ) - - _ -> - call_instructions - end - + defp gen_statement(%Statement.Return{values: [%Expr.MethodCall{} = call]}, ctx) do + # return obj:method(...) — forward all results + {call_instructions, _result_reg, ctx} = gen_expr(call, ctx) + call_instructions = patch_call_result_count(call_instructions, -1) {call_instructions, ctx} end @@ -130,6 +131,8 @@ defmodule Lua.Compiler.Codegen do |> Enum.with_index() |> Enum.reduce({[], ctx}, fn {value, i}, {instructions, ctx} -> target_reg = base_reg + i + # Ensure next_reg is past target so gen_expr doesn't overwrite previous results + ctx = %{ctx | next_reg: max(ctx.next_reg, target_reg + 1)} {value_instructions, value_reg, ctx} = gen_expr(value, ctx) move = @@ -146,8 +149,42 @@ defmodule Lua.Compiler.Codegen do vararg_base = base_reg + length(init_values) vararg_instruction = Instruction.vararg(vararg_base, 0) - # Return with -1 to indicate variable number of results - {init_instructions ++ [vararg_instruction, Instruction.return_instr(base_reg, -1)], ctx} + # Return with negative count: -(init_count + 1) to encode fixed + variable + init_count = length(init_values) + {init_instructions ++ [vararg_instruction, Instruction.return_instr(base_reg, -(init_count + 1))], ctx} + + call_expr when is_struct(call_expr, Expr.Call) or is_struct(call_expr, Expr.MethodCall) -> + # return a, b, f() - load a,b then expand all results of f() + base_reg = ctx.next_reg + fixed_count = length(init_values) + + {init_instructions, ctx} = + init_values + |> Enum.with_index() + |> Enum.reduce({[], ctx}, fn {value, i}, {instructions, ctx} -> + target_reg = base_reg + i + ctx = %{ctx | next_reg: max(ctx.next_reg, target_reg + 1)} + {value_instructions, value_reg, ctx} = gen_expr(value, ctx) + + move = + if value_reg == target_reg do + [] + else + [Instruction.move(target_reg, value_reg)] + end + + {instructions ++ value_instructions ++ move, ctx} + end) + + # Compile the tail call + ctx = %{ctx | next_reg: base_reg + fixed_count} + {call_instructions, _call_reg, ctx} = gen_expr(call_expr, ctx) + call_instructions = patch_call_result_count(call_instructions, -2) + + # Return with {:multi_return, fixed_count} to indicate fixed + expanded results + {init_instructions ++ + call_instructions ++ + [Instruction.return_instr(base_reg, {:multi_return, fixed_count})], ctx} _ -> # Normal multi-value return @@ -158,6 +195,8 @@ defmodule Lua.Compiler.Codegen do |> Enum.with_index() |> Enum.reduce({[], ctx}, fn {value, i}, {instructions, ctx} -> target_reg = base_reg + i + # Ensure next_reg is past target so gen_expr doesn't overwrite previous results + ctx = %{ctx | next_reg: max(ctx.next_reg, target_reg + 1)} {value_instructions, value_reg, ctx} = gen_expr(value, ctx) move = @@ -190,18 +229,19 @@ defmodule Lua.Compiler.Codegen do num_targets = length(targets) num_values = length(values) - # Check if last value is a call (for multiple-return expansion) - {init_values, last_value, last_is_call} = + # Check if last value is a call or vararg (for multiple-return expansion) + {init_values, last_value, last_kind} = if num_values > 0 do [last | _] = Enum.reverse(values) case last do - %Expr.Call{} -> {Enum.slice(values, 0..-2//1), last, true} - %Expr.MethodCall{} -> {Enum.slice(values, 0..-2//1), last, true} - _ -> {values, nil, false} + %Expr.Call{} -> {Enum.slice(values, 0..-2//1), last, :call} + %Expr.MethodCall{} -> {Enum.slice(values, 0..-2//1), last, :call} + %Expr.Vararg{} -> {Enum.slice(values, 0..-2//1), last, :vararg} + _ -> {values, nil, :none} end else - {[], nil, false} + {[], nil, :none} end # Evaluate init values into temp registers @@ -211,30 +251,36 @@ defmodule Lua.Compiler.Codegen do {instructions ++ value_instructions, regs ++ [value_reg], ctx} end) - # If last value is a call, expand multiple returns - {call_instructions, call_base, ctx} = - if last_is_call do - # Number of extra values we need from the call - needed_from_call = num_targets - length(init_values) - {call_instr, call_reg, ctx} = gen_expr(last_value, ctx) - - # Patch the call to request needed_from_call results - call_instr = - case List.last(call_instr) do - {:call, cb, arg_count, _result_count} -> - List.replace_at( - call_instr, - length(call_instr) - 1, - {:call, cb, arg_count, max(needed_from_call, 1)} - ) - - _ -> - call_instr - end - - {call_instr, call_reg, ctx} - else - {[], nil, ctx} + # If last value is a call or vararg, expand multiple returns + {multi_instructions, multi_base, ctx} = + case last_kind do + :call -> + needed = num_targets - length(init_values) + {call_instr, call_reg, ctx} = gen_expr(last_value, ctx) + + call_instr = + case List.last(call_instr) do + {:call, cb, arg_count, _result_count} -> + List.replace_at( + call_instr, + length(call_instr) - 1, + {:call, cb, arg_count, max(needed, 1)} + ) + + _ -> + call_instr + end + + {call_instr, call_reg, ctx} + + :vararg -> + needed = num_targets - length(init_values) + vararg_base = ctx.next_reg + ctx = %{ctx | next_reg: vararg_base + needed} + {[Instruction.vararg(vararg_base, needed)], vararg_base, ctx} + + :none -> + {[], nil, ctx} end # Build the list of value registers for each target @@ -244,10 +290,10 @@ defmodule Lua.Compiler.Codegen do i < length(init_regs) -> Enum.at(init_regs, i) - last_is_call -> - # Value comes from the multi-return call - call_result_offset = i - length(init_regs) - call_base + call_result_offset + last_kind != :none -> + # Value comes from the multi-return call or vararg + multi_offset = i - length(init_regs) + multi_base + multi_offset true -> nil @@ -274,7 +320,7 @@ defmodule Lua.Compiler.Codegen do end end) - {init_instructions ++ call_instructions ++ assign_instructions, ctx} + {init_instructions ++ multi_instructions ++ assign_instructions, ctx} end end @@ -284,23 +330,26 @@ defmodule Lua.Compiler.Codegen do num_names = length(names) num_values = length(values) - # Check if last value is a call (for multiple-return expansion) - {init_values, last_value, last_is_call} = + # Check if last value is a call or vararg (for multiple-return expansion) + {init_values, last_value, last_kind} = if num_values > 0 do [last | _] = Enum.reverse(values) case last do %Expr.Call{} when num_names > num_values -> - {Enum.slice(values, 0..-2//1), last, true} + {Enum.slice(values, 0..-2//1), last, :call} %Expr.MethodCall{} when num_names > num_values -> - {Enum.slice(values, 0..-2//1), last, true} + {Enum.slice(values, 0..-2//1), last, :call} + + %Expr.Vararg{} when num_names > num_values -> + {Enum.slice(values, 0..-2//1), last, :vararg} _ -> - {values, nil, false} + {values, nil, :none} end else - {[], nil, false} + {[], nil, :none} end # Generate code for init values @@ -310,28 +359,36 @@ defmodule Lua.Compiler.Codegen do {instructions ++ new_instructions, regs ++ [reg], ctx} end) - # If last value is a call, generate it requesting multiple returns - {call_instructions, call_base, ctx} = - if last_is_call do - needed = num_names - length(init_values) - {call_instr, call_reg, ctx} = gen_expr(last_value, ctx) - - call_instr = - case List.last(call_instr) do - {:call, cb, arg_count, _result_count} -> - List.replace_at( - call_instr, - length(call_instr) - 1, - {:call, cb, arg_count, max(needed, 1)} - ) - - _ -> - call_instr - end + # If last value is a call or vararg, generate it requesting multiple returns + {multi_instructions, multi_base, ctx} = + case last_kind do + :call -> + needed = num_names - length(init_values) + {call_instr, call_reg, ctx} = gen_expr(last_value, ctx) - {call_instr, call_reg, ctx} - else - {[], nil, ctx} + call_instr = + case List.last(call_instr) do + {:call, cb, arg_count, _result_count} -> + List.replace_at( + call_instr, + length(call_instr) - 1, + {:call, cb, arg_count, max(needed, 1)} + ) + + _ -> + call_instr + end + + {call_instr, call_reg, ctx} + + :vararg -> + needed = num_names - length(init_values) + vararg_base = ctx.next_reg + ctx = %{ctx | next_reg: vararg_base + needed} + {[Instruction.vararg(vararg_base, needed)], vararg_base, ctx} + + :none -> + {[], nil, ctx} end # Generate move instructions to copy values to their assigned registers @@ -346,10 +403,10 @@ defmodule Lua.Compiler.Codegen do source_reg = Enum.at(value_regs, index) if dest_reg == source_reg, do: [], else: [Instruction.move(dest_reg, source_reg)] - last_is_call -> - # Value comes from multi-return call - call_offset = index - length(value_regs) - source_reg = call_base + call_offset + last_kind != :none -> + # Value comes from multi-return call or vararg expansion + multi_offset = index - length(value_regs) + source_reg = multi_base + multi_offset if dest_reg == source_reg, do: [], else: [Instruction.move(dest_reg, source_reg)] true -> @@ -358,7 +415,7 @@ defmodule Lua.Compiler.Codegen do end end) - {value_instructions ++ call_instructions ++ move_instructions, ctx} + {value_instructions ++ multi_instructions ++ move_instructions, ctx} end defp gen_statement( @@ -558,7 +615,7 @@ defmodule Lua.Compiler.Codegen do [first | rest] -> # Dotted name: get the table chain, then set the final field - {get_instructions, table_reg, ctx} = gen_expr(%Expr.Var{name: first}, ctx) + {get_instructions, table_reg, ctx} = gen_var_by_name(first, ctx) {final_instructions, final_table_reg, ctx} = Enum.reduce(Enum.slice(rest, 0..-2//1), {get_instructions, table_reg, ctx}, fn field, {instrs, reg, ctx} -> @@ -569,7 +626,8 @@ defmodule Lua.Compiler.Codegen do last_field = List.last(rest) - {final_instructions ++ [Instruction.set_field(final_table_reg, last_field, closure_reg)], ctx} + {closure_instructions ++ final_instructions ++ [Instruction.set_field(final_table_reg, last_field, closure_reg)], + ctx} end end @@ -578,8 +636,8 @@ defmodule Lua.Compiler.Codegen do # Generate closure for the function {closure_instructions, closure_reg, ctx} = gen_closure_from_node(local_func, ctx) - # Get the local variable's register from scope - dest_reg = ctx.scope.locals[name] + # Get the local variable's register from var_map (per-statement, handles redefinitions) + dest_reg = Map.get(ctx.scope.var_map, {:local_func_reg, local_func}, ctx.scope.locals[name]) # Move closure to the local's register move_instructions = @@ -881,91 +939,101 @@ defmodule Lua.Compiler.Codegen do ctx = %{ctx | next_reg: base_reg + 1} - # Check if last arg is vararg - needs special handling - {has_vararg_last, init_args} = - if length(args) > 0 do - [last | _] = Enum.reverse(args) + # Check what the last argument is — determines calling convention + last_arg_type = + case args do + [] -> + :normal - case last do - %Expr.Vararg{} -> {true, Enum.slice(args, 0..-2//1)} - _ -> {false, args} - end - else - {false, []} + _ -> + case List.last(args) do + %Expr.Vararg{} -> :vararg + %Expr.Call{} -> :multi_call + %Expr.MethodCall{} -> :multi_call + _ -> :normal + end end - if has_vararg_last do - # f(a, b, ...) - load a, b then all varargs - arg_count = length(init_args) - ctx = %{ctx | next_reg: base_reg + 1 + arg_count} - - {arg_instructions, arg_regs, ctx} = - Enum.reduce(init_args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> - {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) - {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} - end) - - # Move each arg result to its expected position (base+1+i) - move_instructions = - arg_regs - |> Enum.with_index() - |> Enum.flat_map(fn {arg_reg, i} -> - expected_reg = base_reg + 1 + i - - if arg_reg == expected_reg do - [] - else - [Instruction.move(expected_reg, arg_reg)] - end - end) - - # Load all varargs starting after init args - vararg_base = base_reg + 1 + arg_count - vararg_instruction = Instruction.vararg(vararg_base, 0) - - # Call with -(init_args+1) to encode both varargs and fixed arg count - # Negative values encode: -1 means 0 fixed + varargs, -2 means 1 fixed + varargs, etc. - call_instruction = Instruction.call(base_reg, -(arg_count + 1), 1) - - {function_instructions ++ - move_function ++ - arg_instructions ++ - move_instructions ++ - [vararg_instruction, call_instruction], base_reg, ctx} - else - # Normal function call without varargs - arg_count = length(args) - ctx = %{ctx | next_reg: base_reg + 1 + arg_count} - - {arg_instructions, arg_regs, ctx} = - Enum.reduce(args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> - {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) - {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} - end) - - # Move each arg result to its expected position (base+1+i) - move_instructions = - arg_regs - |> Enum.with_index() - |> Enum.flat_map(fn {arg_reg, i} -> - expected_reg = base_reg + 1 + i - - if arg_reg == expected_reg do - [] - else - [Instruction.move(expected_reg, arg_reg)] - end - end) + case last_arg_type do + :vararg -> + # f(a, b, ...) - load a, b then all varargs + init_args = Enum.slice(args, 0..-2//1) + arg_count = length(init_args) + ctx = %{ctx | next_reg: base_reg + 1 + arg_count} + + {arg_instructions, arg_regs, ctx} = + Enum.reduce(init_args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> + {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) + {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} + end) - # Generate call instruction (single return value for now) - call_instruction = Instruction.call(base_reg, arg_count, 1) + move_instructions = gen_move_args(arg_regs, base_reg + 1) - # Result will be in base_reg - {function_instructions ++ - move_function ++ - arg_instructions ++ - move_instructions ++ - [call_instruction], base_reg, ctx} + vararg_base = base_reg + 1 + arg_count + vararg_instruction = Instruction.vararg(vararg_base, 0) + call_instruction = Instruction.call(base_reg, -(arg_count + 1), 1) + + {function_instructions ++ + move_function ++ + arg_instructions ++ + move_instructions ++ + [vararg_instruction, call_instruction], base_reg, ctx} + + :multi_call -> + # f(a, b, g()) - load a, b then expand all results of g() + init_args = Enum.slice(args, 0..-2//1) + last_call = List.last(args) + fixed_count = length(init_args) + + ctx = %{ctx | next_reg: base_reg + 1 + fixed_count} + + {arg_instructions, arg_regs, ctx} = + Enum.reduce(init_args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> + {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) + {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} + end) + + move_instructions = gen_move_args(arg_regs, base_reg + 1) + + # Ensure next_reg is positioned for the inner call + ctx = %{ctx | next_reg: base_reg + 1 + fixed_count} + + # Compile the inner call — its results will be placed at base+1+fixed_count + {inner_call_instructions, _inner_base, ctx} = gen_expr(last_call, ctx) + + # Patch the inner call's result_count to -2 (expand all results) + inner_call_instructions = patch_call_result_count(inner_call_instructions, -2) + + # Outer call uses {:multi, fixed_count} to collect fixed + expanded args + call_instruction = Instruction.call(base_reg, {:multi, fixed_count}, 1) + + {function_instructions ++ + move_function ++ + arg_instructions ++ + move_instructions ++ + inner_call_instructions ++ + [call_instruction], base_reg, ctx} + + :normal -> + # Normal function call + arg_count = length(args) + ctx = %{ctx | next_reg: base_reg + 1 + arg_count} + + {arg_instructions, arg_regs, ctx} = + Enum.reduce(args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> + {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) + {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} + end) + + move_instructions = gen_move_args(arg_regs, base_reg + 1) + + call_instruction = Instruction.call(base_reg, arg_count, 1) + + {function_instructions ++ + move_function ++ + arg_instructions ++ + move_instructions ++ + [call_instruction], base_reg, ctx} end end @@ -1002,7 +1070,8 @@ defmodule Lua.Compiler.Codegen do # Table with {a, b, ...} # Reserve contiguous slots for the init values start_reg = ctx.next_reg - ctx = %{ctx | next_reg: start_reg + length(init_fields)} + init_count = length(init_fields) + ctx = %{ctx | next_reg: start_reg + init_count} {init_instructions, ctx} = init_fields @@ -1022,20 +1091,52 @@ defmodule Lua.Compiler.Codegen do end) # Load all varargs starting after init values - vararg_base = start_reg + length(init_fields) + vararg_base = start_reg + init_count vararg_instruction = Instruction.vararg(vararg_base, 0) - # set_list with count 0 means variable number of values - set_list_instruction = Instruction.set_list(dest, start_reg, 0, 0) + # {:multi, init_count} means init_count fixed + state.multi_return_count variable + set_list_instruction = Instruction.set_list(dest, start_reg, {:multi, init_count}, 0) {init_instructions ++ [vararg_instruction, set_list_instruction], ctx} %Expr.Vararg{} -> # Table with just {...} start_reg = ctx.next_reg vararg_instruction = Instruction.vararg(start_reg, 0) - set_list_instruction = Instruction.set_list(dest, start_reg, 0, 0) + set_list_instruction = Instruction.set_list(dest, start_reg, {:multi, 0}, 0) {[vararg_instruction, set_list_instruction], ctx} + call_expr when is_struct(call_expr, Expr.Call) or is_struct(call_expr, Expr.MethodCall) -> + # Table with {a, b, f()} - init values then multi-return expansion + start_reg = ctx.next_reg + init_count = length(init_fields) + ctx = %{ctx | next_reg: start_reg + init_count} + + {init_instructions, ctx} = + init_fields + |> Enum.with_index() + |> Enum.reduce({[], ctx}, fn {val_expr, i}, {instructions, ctx} -> + target_reg = start_reg + i + {value_instructions, val_reg, ctx} = gen_expr(val_expr, ctx) + + move = + if val_reg == target_reg do + [] + else + [Instruction.move(target_reg, val_reg)] + end + + {instructions ++ value_instructions ++ move, ctx} + end) + + # Compile the tail call — its results go after init values + ctx = %{ctx | next_reg: start_reg + init_count} + {call_instructions, _call_reg, ctx} = gen_expr(call_expr, ctx) + call_instructions = patch_call_result_count(call_instructions, -2) + + # set_list with {:multi, init_count} to use multi_return_count from state + set_list_instruction = Instruction.set_list(dest, start_reg, {:multi, init_count}, 0) + {init_instructions ++ call_instructions ++ [set_list_instruction], ctx} + _ -> # Normal list fields (no vararg) start_reg = ctx.next_reg @@ -1161,6 +1262,29 @@ defmodule Lua.Compiler.Codegen do {[Instruction.load_constant(reg, nil)], reg, ctx} end + # Look up a variable by name (not by AST node identity). + # Used when we need to resolve a variable name that doesn't have a corresponding + # scope-resolved Expr.Var node (e.g., FuncDecl table chain names). + defp gen_var_by_name(name, ctx) do + case Map.get(ctx.scope.locals, name) do + nil -> + # Not a local — treat as global + reg = ctx.next_reg + ctx = %{ctx | next_reg: reg + 1} + {[Instruction.get_global(reg, name)], reg, ctx} + + local_reg -> + if MapSet.member?(ctx.scope.captured_locals, name) do + # Captured local — read from open upvalue cell + dest = ctx.next_reg + ctx = %{ctx | next_reg: dest + 1} + {[Instruction.get_open_upvalue(dest, local_reg)], dest, ctx} + else + {[], local_reg, ctx} + end + end + end + # Shared helper: generates a closure from a function node (Expr.Function, Statement.FuncDecl, etc.) # Returns {instructions, dest_reg, ctx} like gen_expr. defp gen_closure_from_node(node, ctx) do @@ -1172,7 +1296,7 @@ defmodule Lua.Compiler.Codegen do {body_instructions, body_ctx} = gen_block(node.body, %{ - next_reg: func_scope.param_count, + next_reg: max(func_scope.param_count, func_scope.max_register), source: ctx.source, scope: func_locals_scope, prototypes: [] @@ -1188,7 +1312,7 @@ defmodule Lua.Compiler.Codegen do upvalue_descriptors: func_scope.upvalue_descriptors, param_count: func_scope.param_count, is_vararg: func_scope.is_vararg, - max_registers: func_scope.max_register, + max_registers: Enum.max([func_scope.max_register, body_ctx.next_reg, Map.get(body_ctx, :peak_reg, 0)]), source: ctx.source, lines: lines } @@ -1203,4 +1327,34 @@ defmodule Lua.Compiler.Codegen do {[Instruction.closure(dest_reg, proto_index)], dest_reg, ctx} end + + # Move argument values to their expected contiguous positions (base_start+i) + defp gen_move_args(arg_regs, base_start) do + arg_regs + |> Enum.with_index() + |> Enum.flat_map(fn {arg_reg, i} -> + expected_reg = base_start + i + + if arg_reg == expected_reg do + [] + else + [Instruction.move(expected_reg, arg_reg)] + end + end) + end + + # Patch the last {:call, ...} instruction's result_count + defp patch_call_result_count(instructions, new_result_count) do + case List.last(instructions) do + {:call, base, arg_count, _old_result_count} -> + List.replace_at( + instructions, + length(instructions) - 1, + {:call, base, arg_count, new_result_count} + ) + + _ -> + instructions + end + end end diff --git a/lib/lua/compiler/scope.ex b/lib/lua/compiler/scope.ex index 60e9dac..088f6db 100644 --- a/lib/lua/compiler/scope.ex +++ b/lib/lua/compiler/scope.ex @@ -65,8 +65,8 @@ defmodule Lua.Compiler.Scope do def resolve(%Chunk{block: block}, _opts \\ []) do state = %State{} - # The chunk itself is an implicit function with no parameters - func_scope = %FunctionScope{} + # The chunk itself is an implicit vararg function (Lua 5.3 spec) + func_scope = %FunctionScope{is_vararg: true} state = %{state | current_function: :chunk, functions: %{chunk: func_scope}} # Resolve the chunk body @@ -183,7 +183,9 @@ defmodule Lua.Compiler.Scope do # Assign it a register loop_var_reg = state.next_register state = %{state | locals: Map.put(state.locals, var, loop_var_reg)} - state = %{state | next_register: loop_var_reg + 1} + # Reserve 3 registers: the loop variable shares with the internal counter, + # plus limit and step registers (codegen allocates base, base+1, base+2) + state = %{state | next_register: loop_var_reg + 3} # Update max_register func_scope = state.functions[state.current_function] @@ -234,6 +236,10 @@ defmodule Lua.Compiler.Scope do state = %{state | locals: Map.put(state.locals, name, reg)} state = %{state | next_register: reg + 1} + # Store the register assignment in var_map so codegen can find the correct + # register even when the same name is redefined later (e.g., two `local function f`) + state = %{state | var_map: Map.put(state.var_map, {:local_func_reg, local_func}, reg)} + # Update max_register in current function scope func_scope = state.functions[state.current_function] func_scope = %{func_scope | max_register: max(func_scope.max_register, state.next_register)} @@ -334,13 +340,22 @@ defmodule Lua.Compiler.Scope do # For now, stub out other expression types defp resolve_expr(_expr, state), do: state - # Walk up the scope chain to find a variable and create upvalue descriptors - defp find_upvalue(_name, [], _state), do: :not_found + # Walk up the scope chain to find a variable and create upvalue descriptors. + # Delegates to ensure_upvalue which handles multi-level nesting correctly. + defp find_upvalue(name, parent_scopes, state) do + ensure_upvalue(name, state.current_function, parent_scopes, state) + end + + # ensure_upvalue(name, for_function, parent_scopes, state) + # Ensures that for_function has an upvalue for the variable `name`. + # Walks up parent_scopes to find the variable, creating upvalue descriptors + # in each intermediate function as needed. + defp ensure_upvalue(_name, _for_function, [], _state), do: :not_found - defp find_upvalue(name, [parent | rest], state) do + defp ensure_upvalue(name, for_function, [parent | rest], state) do case Map.get(parent.locals, name) do nil -> - # Not in this parent — check if the parent already has it as an upvalue + # Not in this parent's locals — check if the parent already has it as an upvalue parent_func = state.functions[parent.function] case Enum.find_index(parent_func.upvalue_descriptors, fn @@ -348,83 +363,64 @@ defmodule Lua.Compiler.Scope do {:parent_upvalue, _, n} -> n == name end) do nil -> - # Not in parent's upvalues either — recurse further up - case find_upvalue(name, rest, state) do - {:ok, grandparent_upvalue_index, state} -> - # The variable was found further up. The parent needs an upvalue too. + # Parent doesn't have it. Recurse to ensure the parent gets it first. + case ensure_upvalue(name, parent.function, rest, state) do + {:ok, _parent_uv_index, state} -> + # Parent now has an upvalue for this variable. Find its index. parent_func = state.functions[parent.function] - parent_upvalue_index = length(parent_func.upvalue_descriptors) - parent_func = %{ - parent_func - | upvalue_descriptors: - parent_func.upvalue_descriptors ++ - [{:parent_upvalue, grandparent_upvalue_index, name}] - } - - state = %{ - state - | functions: Map.put(state.functions, parent.function, parent_func) - } + parent_uv_index = + Enum.find_index(parent_func.upvalue_descriptors, fn + {:parent_local, _, n} -> n == name + {:parent_upvalue, _, n} -> n == name + end) - # Now add upvalue in current function referencing parent's upvalue - current_func = state.functions[state.current_function] - cur_upvalue_index = length(current_func.upvalue_descriptors) + # Add to for_function referencing parent's upvalue + func = state.functions[for_function] + uv_index = length(func.upvalue_descriptors) - current_func = %{ - current_func + func = %{ + func | upvalue_descriptors: - current_func.upvalue_descriptors ++ - [{:parent_upvalue, parent_upvalue_index, name}] + func.upvalue_descriptors ++ + [{:parent_upvalue, parent_uv_index, name}] } - state = %{ - state - | functions: Map.put(state.functions, state.current_function, current_func) - } - - {:ok, cur_upvalue_index, state} + state = %{state | functions: Map.put(state.functions, for_function, func)} + {:ok, uv_index, state} :not_found -> :not_found end - parent_upvalue_index -> - # Parent already has this upvalue — reference it - current_func = state.functions[state.current_function] - cur_upvalue_index = length(current_func.upvalue_descriptors) + parent_uv_index -> + # Parent already has this upvalue — add reference in for_function + func = state.functions[for_function] + uv_index = length(func.upvalue_descriptors) - current_func = %{ - current_func + func = %{ + func | upvalue_descriptors: - current_func.upvalue_descriptors ++ - [{:parent_upvalue, parent_upvalue_index, name}] - } - - state = %{ - state - | functions: Map.put(state.functions, state.current_function, current_func) + func.upvalue_descriptors ++ + [{:parent_upvalue, parent_uv_index, name}] } - {:ok, cur_upvalue_index, state} + state = %{state | functions: Map.put(state.functions, for_function, func)} + {:ok, uv_index, state} end reg -> - # Found in parent's locals — create upvalue descriptor - current_func = state.functions[state.current_function] - upvalue_index = length(current_func.upvalue_descriptors) - - current_func = %{ - current_func - | upvalue_descriptors: current_func.upvalue_descriptors ++ [{:parent_local, reg, name}] - } + # Found in parent's locals — add {:parent_local, reg, name} to for_function + func = state.functions[for_function] + uv_index = length(func.upvalue_descriptors) - state = %{ - state - | functions: Map.put(state.functions, state.current_function, current_func) + func = %{ + func + | upvalue_descriptors: func.upvalue_descriptors ++ [{:parent_local, reg, name}] } - {:ok, upvalue_index, state} + state = %{state | functions: Map.put(state.functions, for_function, func)} + {:ok, uv_index, state} end end diff --git a/lib/lua/lexer.ex b/lib/lua/lexer.ex index f5afa4b..f5ca94c 100644 --- a/lib/lua/lexer.ex +++ b/lib/lua/lexer.ex @@ -133,6 +133,11 @@ defmodule Lua.Lexer do scan_number(<>, "", acc, pos, pos) end + # Float starting with dot: .0, .5e3, etc. + defp do_tokenize(<<".", c, rest::binary>>, acc, pos) when c in ?0..?9 do + scan_float(rest, "0." <> <>, acc, advance_column(pos, 2), pos) + end + # Three-character operators defp do_tokenize(<<"...", rest::binary>>, acc, pos) do token = {:operator, :vararg, pos} @@ -486,19 +491,24 @@ defmodule Lua.Lexer do scan_number(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) end - defp scan_number(<>, num_acc, acc, pos, start_pos) do - # Trailing dot is not part of the number - finalize_number(num_acc, <<".">>, acc, pos, start_pos) - end - defp scan_number(<<".", c, rest::binary>>, num_acc, acc, pos, start_pos) when c in ?0..?9 do - # Decimal point with digit following + # Decimal point with digit following: 0.5 scan_float(rest, num_acc <> "." <> <>, acc, advance_column(pos, 2), start_pos) end + defp scan_number(<<"..", _rest::binary>> = rest, num_acc, acc, pos, start_pos) do + # ".." is concat operator, not a decimal point: 0..5 → 0 .. 5 + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + defp scan_number(<<".", c, rest::binary>>, num_acc, acc, pos, start_pos) when c in [?e, ?E] do + # "0.e5" → float with exponent + scan_float(<>, num_acc <> ".", acc, advance_column(pos, 1), start_pos) + end + defp scan_number(<<".", rest::binary>>, num_acc, acc, pos, start_pos) do - # Decimal point but no digit following - finalize number, reprocess "." - finalize_number(num_acc, <<".", rest::binary>>, acc, pos, start_pos) + # Trailing dot makes it a float: 0. → 0.0 + scan_float(rest, num_acc <> ".", acc, advance_column(pos, 1), start_pos) end defp scan_number(<>, num_acc, acc, pos, start_pos) when c in [?e, ?E] do @@ -540,12 +550,22 @@ defmodule Lua.Lexer do finalize_number(num_acc, rest, acc, pos, start_pos) end - # Scan hexadecimal number (0x...) + # Scan hexadecimal number (0x...) — supports integers, hex floats (0xF0.0), and exponents (0xABCp-3) defp scan_hex_number(<>, hex_acc, acc, pos, start_pos) when c in ?0..?9 or c in ?a..?f or c in ?A..?F do scan_hex_number(rest, hex_acc <> <>, acc, advance_column(pos, 1), start_pos) end + # Hex float: dot followed by hex digits + defp scan_hex_number(<<".", rest::binary>>, hex_acc, acc, pos, start_pos) do + scan_hex_frac(rest, hex_acc, "", acc, advance_column(pos, 1), start_pos) + end + + # Hex float: binary exponent (p/P) + defp scan_hex_number(<>, hex_acc, acc, pos, start_pos) when p in [?p, ?P] do + scan_hex_exp(rest, hex_acc, "", acc, advance_column(pos, 1), start_pos) + end + defp scan_hex_number(rest, hex_acc, acc, pos, start_pos) do case Integer.parse(hex_acc, 16) do {num, ""} -> @@ -557,6 +577,59 @@ defmodule Lua.Lexer do end end + # Scan hex fractional digits after the dot + defp scan_hex_frac(<>, int_acc, frac_acc, acc, pos, start_pos) + when c in ?0..?9 or c in ?a..?f or c in ?A..?F do + scan_hex_frac(rest, int_acc, frac_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + # Hex float fractional part followed by exponent + defp scan_hex_frac(<>, int_acc, frac_acc, acc, pos, start_pos) when p in [?p, ?P] do + scan_hex_exp(rest, int_acc, frac_acc, acc, advance_column(pos, 1), start_pos) + end + + # Hex float fractional part without exponent + defp scan_hex_frac(rest, int_acc, frac_acc, acc, pos, start_pos) do + num = build_hex_float(int_acc, frac_acc, 0) + token = {:number, num, start_pos} + do_tokenize(rest, [token | acc], pos) + end + + # Scan binary exponent (p/P followed by optional sign and decimal digits) + defp scan_hex_exp(<>, int_acc, frac_acc, acc, pos, start_pos) when sign in [?+, ?-] do + scan_hex_exp_digits(rest, int_acc, frac_acc, <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_hex_exp(rest, int_acc, frac_acc, acc, pos, start_pos) do + scan_hex_exp_digits(rest, int_acc, frac_acc, "", acc, pos, start_pos) + end + + defp scan_hex_exp_digits(<>, int_acc, frac_acc, exp_acc, acc, pos, start_pos) when c in ?0..?9 do + scan_hex_exp_digits(rest, int_acc, frac_acc, exp_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_hex_exp_digits(rest, int_acc, frac_acc, exp_acc, acc, pos, start_pos) do + exp = if exp_acc == "" or exp_acc == "+" or exp_acc == "-", do: 0, else: String.to_integer(exp_acc) + num = build_hex_float(int_acc, frac_acc, exp) + token = {:number, num, start_pos} + do_tokenize(rest, [token | acc], pos) + end + + # Build a hex float value from integer hex digits, fractional hex digits, and binary exponent + defp build_hex_float(int_hex, frac_hex, exp) do + int_val = if int_hex == "", do: 0, else: String.to_integer(int_hex, 16) + + frac_val = + if frac_hex == "" do + 0.0 + else + frac_int = String.to_integer(frac_hex, 16) + frac_int / :math.pow(16, String.length(frac_hex)) + end + + (int_val + frac_val) * :math.pow(2, exp) + end + # Finalize number token defp finalize_number(num_str, rest, acc, pos, start_pos) do case parse_number(num_str) do @@ -573,7 +646,14 @@ defmodule Lua.Lexer do defp parse_number(num_str) do if String.contains?(num_str, ".") or String.contains?(num_str, "e") or String.contains?(num_str, "E") do - case Float.parse(num_str) do + # Normalize for Elixir's Float.parse which requires digits after dot + normalized = num_str + # "0." → "0.0" + normalized = if String.ends_with?(normalized, "."), do: normalized <> "0", else: normalized + # "2.E-1" → "2.0E-1" + normalized = String.replace(normalized, ~r/\.([eE])/, ".0\\1") + + case Float.parse(normalized) do {num, ""} -> {:ok, num} _ -> {:error, :invalid_number} end diff --git a/lib/lua/vm/executor.ex b/lib/lua/vm/executor.ex index 79d102e..47e31b0 100644 --- a/lib/lua/vm/executor.ex +++ b/lib/lua/vm/executor.ex @@ -19,7 +19,8 @@ defmodule Lua.VM.Executor do @spec execute([tuple()], tuple(), list(), map(), State.t()) :: {list(), tuple(), State.t()} def execute(instructions, registers, upvalues, proto, state) do - do_execute(instructions, registers, {upvalues, %{}}, proto, state) + state = %{state | open_upvalues: %{}} + do_execute(instructions, registers, upvalues, proto, state) end @doc """ @@ -47,15 +48,19 @@ defmodule Lua.VM.Executor do callee_proto end + saved_open_upvalues = state.open_upvalues + state = %{state | open_upvalues: %{}} + {results, _callee_regs, state} = do_execute( callee_proto.instructions, callee_regs, - {callee_upvalues, %{}}, + callee_upvalues, callee_proto, state ) + state = %{state | open_upvalues: saved_open_upvalues} {results, state} end @@ -155,47 +160,35 @@ defmodule Lua.VM.Executor do end # get_upvalue - defp do_execute([{:get_upvalue, dest, index} | rest], regs, {upvalues, _} = upvalue_context, proto, state) do + defp do_execute([{:get_upvalue, dest, index} | rest], regs, upvalues, proto, state) do cell_ref = Enum.at(upvalues, index) value = Map.get(state.upvalue_cells, cell_ref) regs = put_elem(regs, dest, value) - do_execute(rest, regs, upvalue_context, proto, state) + do_execute(rest, regs, upvalues, proto, state) end # set_upvalue - defp do_execute([{:set_upvalue, index, source} | rest], regs, {upvalues, _} = upvalue_context, proto, state) do + defp do_execute([{:set_upvalue, index, source} | rest], regs, upvalues, proto, state) do cell_ref = Enum.at(upvalues, index) value = elem(regs, source) state = %{state | upvalue_cells: Map.put(state.upvalue_cells, cell_ref, value)} - do_execute(rest, regs, upvalue_context, proto, state) + do_execute(rest, regs, upvalues, proto, state) end # get_open_upvalue - read a captured local through its open upvalue cell - defp do_execute( - [{:get_open_upvalue, dest, reg} | rest], - regs, - {_upvalues, open_upvalues} = upvalue_context, - proto, - state - ) do - cell_ref = Map.fetch!(open_upvalues, reg) + defp do_execute([{:get_open_upvalue, dest, reg} | rest], regs, upvalues, proto, state) do + cell_ref = Map.fetch!(state.open_upvalues, reg) value = Map.get(state.upvalue_cells, cell_ref) regs = put_elem(regs, dest, value) - do_execute(rest, regs, upvalue_context, proto, state) + do_execute(rest, regs, upvalues, proto, state) end # set_open_upvalue - write a captured local through its open upvalue cell - defp do_execute( - [{:set_open_upvalue, reg, source} | rest], - regs, - {_upvalues, open_upvalues} = upvalue_context, - proto, - state - ) do - cell_ref = Map.fetch!(open_upvalues, reg) + defp do_execute([{:set_open_upvalue, reg, source} | rest], regs, upvalues, proto, state) do + cell_ref = Map.fetch!(state.open_upvalues, reg) value = elem(regs, source) state = %{state | upvalue_cells: Map.put(state.upvalue_cells, cell_ref, value)} - do_execute(rest, regs, upvalue_context, proto, state) + do_execute(rest, regs, upvalues, proto, state) end # source_line - track current source location @@ -334,6 +327,13 @@ defmodule Lua.VM.Executor do # Copy counter to loop variable regs = put_elem(regs, loop_var, counter) + # Clear open upvalue cells for loop-local registers (loop var + body locals) + # so each iteration gets fresh upvalue cells for its own variables + state = %{ + state + | open_upvalues: Map.reject(state.open_upvalues, fn {reg, _} -> reg >= loop_var end) + } + # Execute body case do_execute(body, regs, upvalues, proto, state) do {:break, regs, state} -> @@ -381,6 +381,14 @@ defmodule Lua.VM.Executor do put_elem(regs, var_reg, Enum.at(results, i)) end) + # Clear open upvalue cells for loop-local registers + first_var_reg = List.first(var_regs) + + state = %{ + state + | open_upvalues: Map.reject(state.open_upvalues, fn {reg, _} -> reg >= first_var_reg end) + } + # Execute body case do_execute(body, regs, upvalues, proto, state) do {:break, regs, state} -> @@ -401,55 +409,71 @@ defmodule Lua.VM.Executor do end # closure - create a closure value from a prototype, capturing upvalues - defp do_execute([{:closure, dest, proto_index} | rest], regs, {upvalues, open_upvalues}, proto, state) do + defp do_execute([{:closure, dest, proto_index} | rest], regs, upvalues, proto, state) do nested_proto = Enum.at(proto.prototypes, proto_index) # Capture upvalues based on descriptors, reusing open upvalue cells when available - {captured_upvalues, state, open_upvalues} = - Enum.reduce(nested_proto.upvalue_descriptors, {[], state, open_upvalues}, fn - {:parent_local, reg, _name}, {cells, state, open_upvalues} -> - case Map.get(open_upvalues, reg) do + {captured_upvalues, state} = + Enum.reduce(nested_proto.upvalue_descriptors, {[], state}, fn + {:parent_local, reg, _name}, {cells, state} -> + case Map.get(state.open_upvalues, reg) do nil -> # Create a new cell for this local variable cell_ref = make_ref() value = elem(regs, reg) - state = %{state | upvalue_cells: Map.put(state.upvalue_cells, cell_ref, value)} - open_upvalues = Map.put(open_upvalues, reg, cell_ref) - {cells ++ [cell_ref], state, open_upvalues} + + state = %{ + state + | upvalue_cells: Map.put(state.upvalue_cells, cell_ref, value), + open_upvalues: Map.put(state.open_upvalues, reg, cell_ref) + } + + {cells ++ [cell_ref], state} existing_cell -> # Reuse existing open upvalue cell - {cells ++ [existing_cell], state, open_upvalues} + {cells ++ [existing_cell], state} end - {:parent_upvalue, index, _name}, {cells, state, open_upvalues} -> + {:parent_upvalue, index, _name}, {cells, state} -> # Share the parent's upvalue cell - {cells ++ [Enum.at(upvalues, index)], state, open_upvalues} + {cells ++ [Enum.at(upvalues, index)], state} end) closure = {:lua_closure, nested_proto, captured_upvalues} regs = put_elem(regs, dest, closure) - do_execute(rest, regs, {upvalues, open_upvalues}, proto, state) + do_execute(rest, regs, upvalues, proto, state) end # call - invoke a function value - defp do_execute([{:call, base, arg_count, result_count} | rest], regs, upvalue_context, proto, state) do + defp do_execute([{:call, base, arg_count, result_count} | rest], regs, upvalues, proto, state) do func_value = elem(regs, base) # Collect arguments from registers base+1..base+arg_count # arg_count < 0 encodes fixed args + varargs: # -1 means 0 fixed + varargs, -2 means 1 fixed + varargs, etc. + # arg_count = {:multi, fixed} encodes fixed args + multi-return expansion args = - cond do - arg_count > 0 -> - for i <- 1..arg_count, do: elem(regs, base + i) + case arg_count do + {:multi, fixed_count} -> + # Fixed args + results from a multi-return call + multi_count = state.multi_return_count + total = fixed_count + multi_count + + if total > 0 do + for i <- 1..total, do: elem(regs, base + i) + else + [] + end + + n when is_integer(n) and n > 0 -> + for i <- 1..n, do: elem(regs, base + i) - arg_count < 0 -> + n when is_integer(n) and n < 0 -> # Collect fixed args + all varargs # Decode: -1 => 0 fixed, -2 => 1 fixed, -3 => 2 fixed, etc. - fixed_arg_count = -(arg_count + 1) - varargs = Map.get(proto, :varargs, []) - total_args = fixed_arg_count + length(varargs) + fixed_arg_count = -(n + 1) + total_args = fixed_arg_count + state.multi_return_count if total_args > 0 do for i <- 1..total_args, do: elem(regs, base + i) @@ -457,7 +481,7 @@ defmodule Lua.VM.Executor do [] end - true -> + 0 -> [] end @@ -494,17 +518,20 @@ defmodule Lua.VM.Executor do end # Execute the callee with fresh open_upvalues + saved_open_upvalues = state.open_upvalues + state = %{state | open_upvalues: %{}} + {results, _callee_regs, state} = do_execute( callee_proto.instructions, callee_regs, - {callee_upvalues, %{}}, + callee_upvalues, callee_proto, state ) - # Pop call stack frame - state = %{state | call_stack: tl(state.call_stack)} + # Pop call stack frame, restore open_upvalues + state = %{state | call_stack: tl(state.call_stack), open_upvalues: saved_open_upvalues} {results, state} @@ -561,24 +588,41 @@ defmodule Lua.VM.Executor do end end - # result_count == -1 means "return all results" (used in return f() position) - if result_count == -1 do - {results, regs, state} - else - # Place results into caller registers starting at base - regs = - if result_count > 0 do - results_list = List.wrap(results) + cond do + # result_count == -1 means "return all results" (used in return f() position) + result_count == -1 -> + {results, regs, state} + + # result_count == -2 means "multi-return expansion": place all results into + # registers starting at base, store count in state, continue execution + result_count == -2 -> + results_list = List.wrap(results) - Enum.reduce(0..(result_count - 1), regs, fn i, regs -> - value = Enum.at(results_list, i) - put_elem(regs, base + i, value) + regs = + results_list + |> Enum.with_index() + |> Enum.reduce(regs, fn {val, i}, regs -> + put_elem(regs, base + i, val) end) - else - regs - end - do_execute(rest, regs, upvalue_context, proto, state) + state = %{state | multi_return_count: length(results_list)} + do_execute(rest, regs, upvalues, proto, state) + + true -> + # Place results into caller registers starting at base + regs = + if result_count > 0 do + results_list = List.wrap(results) + + Enum.reduce(0..(result_count - 1), regs, fn i, regs -> + value = Enum.at(results_list, i) + put_elem(regs, base + i, value) + end) + else + regs + end + + do_execute(rest, regs, upvalues, proto, state) end end @@ -587,17 +631,23 @@ defmodule Lua.VM.Executor do defp do_execute([{:vararg, base, count} | rest], regs, upvalues, proto, state) do varargs = Map.get(proto, :varargs, []) - regs = + {regs, state} = if count == 0 do - # Load all varargs - Enum.reduce(Enum.with_index(varargs), regs, fn {val, i}, regs -> - put_elem(regs, base + i, val) - end) + # Load all varargs and track the count for set_list/call + regs = + Enum.reduce(Enum.with_index(varargs), regs, fn {val, i}, regs -> + put_elem(regs, base + i, val) + end) + + {regs, %{state | multi_return_count: length(varargs)}} else # Load exactly count values - Enum.reduce(0..(count - 1), regs, fn i, regs -> - put_elem(regs, base + i, Enum.at(varargs, i)) - end) + regs = + Enum.reduce(0..(count - 1), regs, fn i, regs -> + put_elem(regs, base + i, Enum.at(varargs, i)) + end) + + {regs, state} end do_execute(rest, regs, upvalues, proto, state) @@ -613,23 +663,29 @@ defmodule Lua.VM.Executor do # count == -1 means return from base including all varargs # count == 0 means return nil # count > 0 means return exactly count values - defp do_execute([{:return, base, count} | _rest], regs, _upvalues, proto, state) do + # count == {:multi_return, fixed} means return fixed values + multi-return expanded values + defp do_execute([{:return, base, {:multi_return, fixed_count}} | _rest], regs, _upvalues, _proto, state) do + total = fixed_count + state.multi_return_count + results = if total > 0, do: for(i <- 0..(total - 1), do: elem(regs, base + i)), else: [] + {results, regs, state} + end + + defp do_execute([{:return, base, count} | _rest], regs, _upvalues, _proto, state) do results = cond do count == 0 -> [nil] - count == -1 -> - # Return values from base including varargs - # We need to collect values until we've covered the vararg range - varargs = Map.get(proto, :varargs, []) - tuple_size = tuple_size(regs) - max_index = min(tuple_size - 1, base + length(varargs) + proto.param_count - 1) + count < 0 -> + # Negative count encodes fixed values + variable values (vararg or multi-return) + # -(init_count + 1): e.g. -1 = 0 fixed, -2 = 1 fixed, -3 = 2 fixed + init_count = -(count + 1) + total = init_count + state.multi_return_count - if max_index < base do - [] + if total > 0 do + for i <- 0..(total - 1), do: elem(regs, base + i) else - for i <- base..max_index, do: elem(regs, i) + [] end count > 0 -> @@ -949,8 +1005,31 @@ defmodule Lua.VM.Executor do do_execute(rest, regs, upvalues, proto, state) end + # set_list with {:multi, init_count} — multi-return expansion in table constructor + defp do_execute([{:set_list, table_reg, start, {:multi, init_count}, offset} | rest], regs, upvalues, proto, state) do + {:tref, id} = elem(regs, table_reg) + total = init_count + state.multi_return_count + + state = + State.update_table(state, {:tref, id}, fn table -> + new_data = + if total > 0 do + Enum.reduce(0..(total - 1), table.data, fn i, data -> + value = elem(regs, start + i) + Map.put(data, offset + i + 1, value) + end) + else + table.data + end + + %{table | data: new_data} + end) + + do_execute(rest, regs, upvalues, proto, state) + end + # set_list — bulk store: table[offset+i] = R[start+i-1] for i in 1..count - # count == 0 means store all values from start until nil or end of tuple + # count == 0 means variable number of values (from vararg or multi-return) defp do_execute([{:set_list, table_reg, start, count, offset} | rest], regs, upvalues, proto, state) do {:tref, id} = elem(regs, table_reg) @@ -958,25 +1037,10 @@ defmodule Lua.VM.Executor do State.update_table(state, {:tref, id}, fn table -> new_data = if count == 0 do - # Variable number of values - collect from start register onwards - # This happens with varargs in table constructors like {a, b, ...} - # The previous vararg instruction loaded all varargs into registers, - # so we need to collect values until we've collected all of them - - # Count how many values to collect by checking registers - tuple_size = tuple_size(regs) - - # Collect values from start until we reach a nil or end of data - # We know varargs were just loaded, so collect until we see - # consecutive nils or reach tuple end - values_to_collect = - start..(tuple_size - 1) - |> Enum.take_while(fn reg_idx -> - reg_idx < tuple_size && elem(regs, reg_idx) != nil - end) - |> length() + # Variable number of values - use multi_return_count which is set by + # both vararg (count=0) and call (result_count=-2) instructions + values_to_collect = state.multi_return_count - # Now collect those values if values_to_collect > 0 do Enum.reduce(0..(values_to_collect - 1), table.data, fn i, data -> value = elem(regs, start + i) @@ -1056,15 +1120,19 @@ defmodule Lua.VM.Executor do callee_proto end + saved_open_upvalues = state.open_upvalues + state = %{state | open_upvalues: %{}} + {results, _callee_regs, state} = do_execute( callee_proto.instructions, callee_regs, - {callee_upvalues, %{}}, + callee_upvalues, callee_proto, state ) + state = %{state | open_upvalues: saved_open_upvalues} {results, state} end @@ -1289,12 +1357,10 @@ defmodule Lua.VM.Executor do {:lua_closure, callee_proto, callee_upvalues} -> # Call the Lua closure metamethod - # We need to execute the closure with the arguments args = [a, b] - - # Allocate enough registers for the function (Lua typically uses up to 250 registers) - # We need to pad the args to fill the register space the function expects initial_regs = List.to_tuple(args ++ List.duplicate(nil, 248)) + saved_open_upvalues = state.open_upvalues + state = %{state | open_upvalues: %{}} {results, _final_regs, new_state} = do_execute( @@ -1305,7 +1371,8 @@ defmodule Lua.VM.Executor do state ) - # Return first result and new state + new_state = %{new_state | open_upvalues: saved_open_upvalues} + result = case results do [r | _] -> r @@ -1343,9 +1410,9 @@ defmodule Lua.VM.Executor do {:lua_closure, callee_proto, callee_upvalues} -> # Call the Lua closure metamethod args = [a] - - # Allocate enough registers for the function (Lua typically uses up to 250 registers) initial_regs = List.to_tuple(args ++ List.duplicate(nil, 249)) + saved_open_upvalues = state.open_upvalues + state = %{state | open_upvalues: %{}} {results, _final_regs, new_state} = do_execute( @@ -1356,7 +1423,8 @@ defmodule Lua.VM.Executor do state ) - # Return first result and new state + new_state = %{new_state | open_upvalues: saved_open_upvalues} + result = case results do [r | _] -> r @@ -1402,6 +1470,8 @@ defmodule Lua.VM.Executor do {:lua_closure, callee_proto, callee_upvalues} -> args = [a, b] initial_regs = List.to_tuple(args ++ List.duplicate(nil, 248)) + saved_open_upvalues = state.open_upvalues + state = %{state | open_upvalues: %{}} {results, _final_regs, new_state} = do_execute( @@ -1412,6 +1482,8 @@ defmodule Lua.VM.Executor do state ) + new_state = %{new_state | open_upvalues: saved_open_upvalues} + result = case results do [r | _] -> r @@ -1492,10 +1564,17 @@ defmodule Lua.VM.Executor do defp safe_floor_divide(a, b) do with {:ok, na} <- to_number(a), {:ok, nb} <- to_number(b) do - if trunc(nb) == 0 do - raise RuntimeError, value: "attempt to divide by zero" - else - div(trunc(na), trunc(nb)) + cond do + nb == 0 or nb == 0.0 -> + raise RuntimeError, value: "attempt to divide by zero" + + is_integer(na) and is_integer(nb) -> + # Lua floor division for integers + lua_idiv(na, nb) + + true -> + # Float floor division + Float.floor(na / nb) * 1.0 end else {:error, val} -> @@ -1509,10 +1588,17 @@ defmodule Lua.VM.Executor do defp safe_modulo(a, b) do with {:ok, na} <- to_number(a), {:ok, nb} <- to_number(b) do - if trunc(nb) == 0 do - raise RuntimeError, value: "attempt to perform modulo by zero" - else - rem(trunc(na), trunc(nb)) + cond do + nb == 0 or nb == 0.0 -> + raise RuntimeError, value: "attempt to perform modulo by zero" + + is_integer(na) and is_integer(nb) -> + # Lua floor modulo for integers: a - floor_div(a, b) * b + na - lua_idiv(na, nb) * nb + + true -> + # Float floor modulo: a - floor(a/b) * b + na - Float.floor(na / nb) * nb end else {:error, val} -> @@ -1523,6 +1609,14 @@ defmodule Lua.VM.Executor do end end + # Lua-style integer floor division (rounds toward negative infinity) + defp lua_idiv(a, b) do + q = div(a, b) + r = rem(a, b) + # Adjust if remainder has different sign than divisor + if r != 0 and Bitwise.bxor(r, b) < 0, do: q - 1, else: q + end + defp safe_power(a, b) do with {:ok, na} <- to_number(a), {:ok, nb} <- to_number(b) do diff --git a/lib/lua/vm/state.ex b/lib/lua/vm/state.ex index 467c59a..af657b9 100644 --- a/lib/lua/vm/state.ex +++ b/lib/lua/vm/state.ex @@ -9,13 +9,15 @@ defmodule Lua.VM.State do call_stack: [], metatables: %{}, upvalue_cells: %{}, + open_upvalues: %{}, tables: %{}, table_next_id: 0, userdata: %{}, userdata_next_id: 0, private: %{}, current_line: 0, - current_source: nil + current_source: nil, + multi_return_count: 0 @type t :: %__MODULE__{ globals: map(), @@ -28,7 +30,8 @@ defmodule Lua.VM.State do userdata_next_id: non_neg_integer(), private: map(), current_line: non_neg_integer(), - current_source: binary() | nil + current_source: binary() | nil, + multi_return_count: non_neg_integer() } @doc """ diff --git a/lib/lua/vm/stdlib.ex b/lib/lua/vm/stdlib.ex index 3690343..12ffdce 100644 --- a/lib/lua/vm/stdlib.ex +++ b/lib/lua/vm/stdlib.ex @@ -101,6 +101,9 @@ defmodule Lua.VM.Stdlib do # Set _G global (the proxy table itself is stored in the raw data for _G._G == _G) state = State.set_global(state, "_G", g_ref) + # _ENV is the environment table — equivalent to _G for top-level code + state = State.set_global(state, "_ENV", g_ref) + # Store _G in the proxy's raw data so _G._G == _G works without hitting __index state = State.update_table(state, g_ref, fn table -> @@ -177,12 +180,12 @@ defmodule Lua.VM.Stdlib do message = case rest do [msg | _] -> msg - [] -> "assertion failed!" + [] -> "assertion failed! (line #{state.current_line})" end raise AssertionError, value: message else - {[value], state} + {[value | rest], state} end end diff --git a/lib/lua/vm/stdlib/string.ex b/lib/lua/vm/stdlib/string.ex index 4c99dcd..01835ec 100644 --- a/lib/lua/vm/stdlib/string.ex +++ b/lib/lua/vm/stdlib/string.ex @@ -53,7 +53,10 @@ defmodule Lua.VM.Stdlib.String do "find" => {:native_func, &string_find/2}, "match" => {:native_func, &string_match/2}, "gmatch" => {:native_func, &string_gmatch/2}, - "gsub" => {:native_func, &string_gsub/2} + "gsub" => {:native_func, &string_gsub/2}, + "packsize" => {:native_func, &string_packsize/2}, + "pack" => {:native_func, &string_pack/2}, + "unpack" => {:native_func, &string_unpack/2} } # Create the string table in VM state @@ -822,4 +825,42 @@ defmodule Lua.VM.Stdlib.String do defp raise_arg_expected(arg_num, func_name) do raise ArgumentError.value_expected("string.#{func_name}", arg_num) end + + # string.packsize(fmt) — returns size in bytes for the given format string + # Supports basic format codes used in Lua 5.3 + defp string_packsize([fmt | _], state) when is_binary(fmt) do + size = compute_pack_size(fmt, 0) + {[size], state} + end + + defp compute_pack_size("", acc), do: acc + defp compute_pack_size(<<"b", rest::binary>>, acc), do: compute_pack_size(rest, acc + 1) + defp compute_pack_size(<<"B", rest::binary>>, acc), do: compute_pack_size(rest, acc + 1) + defp compute_pack_size(<<"h", rest::binary>>, acc), do: compute_pack_size(rest, acc + 2) + defp compute_pack_size(<<"H", rest::binary>>, acc), do: compute_pack_size(rest, acc + 2) + defp compute_pack_size(<<"i", rest::binary>>, acc), do: compute_pack_size(rest, acc + 4) + defp compute_pack_size(<<"I", rest::binary>>, acc), do: compute_pack_size(rest, acc + 4) + defp compute_pack_size(<<"l", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"L", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"j", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"J", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"n", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"N", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"f", rest::binary>>, acc), do: compute_pack_size(rest, acc + 4) + defp compute_pack_size(<<"d", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<"T", rest::binary>>, acc), do: compute_pack_size(rest, acc + 8) + defp compute_pack_size(<<" ", rest::binary>>, acc), do: compute_pack_size(rest, acc) + defp compute_pack_size(<<"<", rest::binary>>, acc), do: compute_pack_size(rest, acc) + defp compute_pack_size(<<">", rest::binary>>, acc), do: compute_pack_size(rest, acc) + defp compute_pack_size(<<"=", rest::binary>>, acc), do: compute_pack_size(rest, acc) + + # string.pack — stub + defp string_pack(_args, _state) do + raise Lua.RuntimeException, "string.pack not yet implemented" + end + + # string.unpack — stub + defp string_unpack(_args, _state) do + raise Lua.RuntimeException, "string.unpack not yet implemented" + end end diff --git a/lib/lua/vm/stdlib/table.ex b/lib/lua/vm/stdlib/table.ex index 3ace10b..7520a02 100644 --- a/lib/lua/vm/stdlib/table.ex +++ b/lib/lua/vm/stdlib/table.ex @@ -258,8 +258,9 @@ defmodule Lua.VM.Stdlib.Table do # table.unpack(list [, i [, j]]) defp table_unpack([{:tref, _} = tref | rest], state) do table = State.get_table(state, tref) - i = Enum.at(rest, 0, 1) - j = Enum.at(rest, 1, get_table_length(table)) + # Treat nil as "not provided" — fall back to defaults + i = Enum.at(rest, 0) || 1 + j = Enum.at(rest, 1) || get_table_length(table) if !is_integer(i) do raise ArgumentError, diff --git a/test/lua/lexer_test.exs b/test/lua/lexer_test.exs index 1a0a4d5..0e8ff86 100644 --- a/test/lua/lexer_test.exs +++ b/test/lua/lexer_test.exs @@ -577,15 +577,15 @@ defmodule Lua.LexerTest do end test "handles trailing dot after number" do - # "42." should tokenize as number 42 followed by dot + # "42." is a valid float literal in Lua 5.3 (= 42.0) assert {:ok, tokens} = Lexer.tokenize("42.") - assert [{:number, 42, _}, {:delimiter, :dot, _}, {:eof, _}] = tokens + assert [{:number, 42.0, _}, {:eof, _}] = tokens end test "handles decimal point without following digit" do - # "42.x" should be number 42 followed by dot and identifier x + # "42.x" is float 42.0 followed by identifier x assert {:ok, tokens} = Lexer.tokenize("42.x") - assert [{:number, 42, _}, {:delimiter, :dot, _}, {:identifier, "x", _}, {:eof, _}] = tokens + assert [{:number, 42.0, _}, {:identifier, "x", _}, {:eof, _}] = tokens end test "reports error for invalid hex number" do diff --git a/test/lua53_suite_test.exs b/test/lua53_suite_test.exs index c007f6f..de7a2c0 100644 --- a/test/lua53_suite_test.exs +++ b/test/lua53_suite_test.exs @@ -15,7 +15,7 @@ defmodule Lua.Lua53SuiteTest do |> Enum.sort() # Tests that are ready to run (not skipped) - @ready_tests ["simple_test.lua"] + @ready_tests ["simple_test.lua", "api.lua", "code.lua", "vararg.lua"] # Tests that require features not yet implemented # As we implement features, move tests from here to @ready_tests diff --git a/test/lua_test.exs b/test/lua_test.exs index ce1d370..37682a7 100644 --- a/test/lua_test.exs +++ b/test/lua_test.exs @@ -1523,7 +1523,9 @@ defmodule LuaTest do return f(10, 20, 30) """ - assert {[3, 20], _} = Lua.eval!(lua, code) + # In Lua 5.3, the last call in a return list expands all its results. + # select(2, 10, 20, 30) returns 20, 30 + assert {[3, 20, 30], _} = Lua.eval!(lua, code) end test "varargs in function call", %{lua: lua} do @@ -1866,6 +1868,504 @@ defmodule LuaTest do end end + describe "compiler fixes" do + setup do + %{lua: Lua.new(sandboxed: [])} + end + + test "redefine local function with same name", %{lua: lua} do + code = """ + local function f(x) return x + 1 end + assert(f(10) == 11) + local function f(x) return x + 2 end + assert(f(10) == 12) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "hex float literals", %{lua: lua} do + assert {[240.0], _} = Lua.eval!(lua, "return 0xF0.0") + assert {[343.5], _} = Lua.eval!(lua, "return 0xABCp-3") + assert {[1.0], _} = Lua.eval!(lua, "return 0x1p0") + assert {[255], _} = Lua.eval!(lua, "return 0xFF") + end + + test "assert returns all arguments", %{lua: lua} do + assert {[1, 2, 3], _} = Lua.eval!(lua, "return assert(1, 2, 3)") + end + + test "multi-value return register corruption", %{lua: lua} do + assert {[55, 2], _} = + Lua.eval!(lua, ~S""" + function c12(...) + local x = {...}; x.n = #x + local res = (x.n==2 and x[1] == 1 and x[2] == 2) + if res then res = 55 end + return res, 2 + end + return c12(1,2) + """) + end + + test "table.unpack with nil third argument", %{lua: lua} do + assert {[1, 2], _} = Lua.eval!(lua, "return table.unpack({1,2}, 1, nil)") + end + + test "string.find empty pattern", %{lua: lua} do + assert {[1, 0], _} = Lua.eval!(lua, "return string.find('', '')") + assert {[1, 0], _} = Lua.eval!(lua, "return string.find('alo', '')") + end + + test "select with multi-return function", %{lua: lua} do + # select(2, load(invalid)) should get the error message from load's two return values + code = ~S""" + local function multi() return nil, "error msg" end + return select(2, multi()) + """ + + assert {["error msg"], _} = Lua.eval!(lua, code) + end + + test "load returns nil and error for bad code", %{lua: lua} do + code = ~S""" + local st, msg = load("invalid code $$$$") + return st, type(msg) + """ + + assert {[nil, "string"], _} = Lua.eval!(lua, code) + end + + test "table constructor with vararg expansion", %{lua: lua} do + code = ~S""" + function f(a, ...) + local arg = {n = select('#', ...), ...} + return arg.n, arg[1], arg[2] + end + return f({}, 10, 20) + """ + + assert {[2, 10, 20], _} = Lua.eval!(lua, code) + end + + test "closure upvalue mutation", %{lua: lua} do + code = ~S""" + local A = 0 + local dummy = function () return A end + A = 1 + assert(dummy() == 1) + A = 0 + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + @tag :skip + test "closure upvalue mutation through nested scope", %{lua: lua} do + # Known limitation: upvalue mutation through nested function scopes + # doesn't propagate correctly yet (upvalue cell sharing) + code = ~S""" + local A = 0 + function f() + local dummy = function () return A end + A = 1 + local val = dummy() + A = 0 + return val + end + return f() + """ + + assert {[1], _} = Lua.eval!(lua, code) + end + + @tag :skip + test "goto scope validation in load", %{lua: lua} do + # Known limitation: compiler doesn't validate goto-label scope rules + code = ~S""" + local st, msg = load(" goto l1; do ::l1:: end ") + return st, msg + """ + + {[st, _msg], _} = Lua.eval!(lua, code) + assert st == nil + end + + test "vararg.lua early lines", %{lua: lua} do + code = ~S""" + function f(a, ...) + local arg = {n = select('#', ...), ...} + for i=1,arg.n do assert(a[i]==arg[i]) end + return arg.n + end + assert(f() == 0) + assert(f({1,2,3}, 1, 2, 3) == 3) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "constructs.lua priorities", %{lua: lua} do + assert {[true], _} = Lua.eval!(lua, "return 2^3^2 == 2^(3^2)") + assert {[true], _} = Lua.eval!(lua, "return 2^3*4 == (2^3)*4") + assert {[true], _} = Lua.eval!(lua, "return 2.0^-2 == 1/4") + assert {[true], _} = Lua.eval!(lua, "return -2^2 == -4 and (-2)^2 == 4") + end + + test "constructs.lua checkload pattern", %{lua: lua} do + # checkload uses select(2, load(s)) to get the error message + code = ~S""" + local function checkload (s, msg) + local err = select(2, load(s)) + assert(string.find(err, msg)) + end + checkload("invalid $$", "invalid") + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "constructs.lua lines 14-33", %{lua: lua} do + # Each priority assert individually + assert {[true], _} = Lua.eval!(lua, "do end; return true") + assert {[true], _} = Lua.eval!(lua, "do a = 3; assert(a == 3) end; return true") + assert {[true], _} = Lua.eval!(lua, "if false then a = 3 // 0; a = 0 % 0 end; return true") + assert {[true], _} = Lua.eval!(lua, "return 2^3^2 == 2^(3^2)") + assert {[true], _} = Lua.eval!(lua, "return 2^3*4 == (2^3)*4") + assert {[true], _} = Lua.eval!(lua, "return 2.0^-2 == 1/4") + assert {[true], _} = Lua.eval!(lua, "return -2^- -2 == - - -4") + assert {[true], _} = Lua.eval!(lua, "return -3-1-5 == 0+0-9") + assert {[true], _} = Lua.eval!(lua, "return -2^2 == -4 and (-2)^2 == 4 and 2*2-3-1 == 0") + assert {[true], _} = Lua.eval!(lua, "return -3%5 == 2 and -3+5 == 2") + assert {[true], _} = Lua.eval!(lua, "return 2*1+3/3 == 3 and 1+2 .. 3*1 == '33'") + assert {[true], _} = Lua.eval!(lua, "return not(2+1 > 3*1) and 'a'..'b' > 'a'") + end + + test "dead code not evaluated", %{lua: lua} do + assert {[true], _} = Lua.eval!(lua, "if false then a = 3 // 0 end; return true") + end + + test "multi-return in table constructor", %{lua: lua} do + # Last expression in table constructor should expand + code = ~S""" + local function multi() return 10, 20, 30 end + local t = {multi()} + return t[1], t[2], t[3] + """ + + assert {[10, 20, 30], _} = Lua.eval!(lua, code) + + # With init values before the call + code = ~S""" + local function multi() return 20, 30 end + local t = {10, multi()} + return t[1], t[2], t[3] + """ + + assert {[10, 20, 30], _} = Lua.eval!(lua, code) + + # Call NOT in last position should only return first value + code = ~S""" + local function multi() return 10, 20, 30 end + local t = {multi(), 99} + return t[1], t[2] + """ + + assert {[10, 99], _} = Lua.eval!(lua, code) + end + + test "pm.lua early lines", %{lua: lua} do + code = ~S""" + local function checkerror (msg, f, ...) + local s, err = pcall(f, ...) + assert(not s and string.find(err, msg)) + end + + function f(s, p) + local i,e = string.find(s, p) + if i then return string.sub(s, i, e) end + end + + a,b = string.find('', '') + assert(a == 1 and b == 0) + a,b = string.find('alo', '') + assert(a == 1 and b == 0) + assert(f("alo", "al") == "al") + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + end + + describe "suite triage - targeted fixes" do + setup do + %{lua: Lua.new(sandboxed: [])} + end + + test "if false should not execute body", %{lua: lua} do + # constructs.lua line 20: dead code with division by zero + code = ~S""" + if false then a = 3 // 0; a = 0 % 0 end + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "semicolons as empty statements", %{lua: lua} do + # constructs.lua lines 13-16 + code = ~S""" + do ;;; end + ; do ; a = 3; assert(a == 3) end; + ; + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "upvalue sharing between sibling closures", %{lua: lua} do + # closure.lua basic pattern - two closures sharing same upvalue + code = ~S""" + local a = 0 + local function inc() a = a + 1 end + local function get() return a end + inc() + assert(get() == 1) + inc() + assert(get() == 2) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "vararg table constructor with select", %{lua: lua} do + # vararg.lua line 7-8 pattern + code = ~S""" + function f(a, ...) + local arg = {n = select('#', ...), ...} + for i=1,arg.n do assert(a[i]==arg[i]) end + return arg.n + end + + assert(f() == 0) + assert(f({1,2,3}, 1, 2, 3) == 3) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "upvalue through nested scopes (3 levels)", %{lua: lua} do + # Simple: just one level of upvalue + code1 = ~S""" + local x = 10 + local function f() return x end + return f() + """ + + assert {[10], _} = Lua.eval!(lua, code1) + + # Two levels: variable captured through intermediate function's upvalue + code2 = ~S""" + local x = 10 + local function outer() + local function inner() + return x + end + return inner() + end + return outer() + """ + + assert {[10], _} = Lua.eval!(lua, code2) + + # Mutation through nested upvalue chain + code3 = ~S""" + local x = 10 + local function outer() + local function inner() + x = x + 1 + return x + end + return inner() + end + assert(outer() == 11) + assert(outer() == 12) + return x + """ + + assert {[12], _} = Lua.eval!(lua, code3) + end + + test "closure in for loop captures loop variable", %{lua: lua} do + # closure.lua pattern - closures in loop body + code = ~S""" + local a = {} + for i = 1, 3 do + a[i] = function() return i end + end + assert(a[1]() == 1) + assert(a[2]() == 2) + assert(a[3]() == 3) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "closure in loop accessing parameter through upvalue", %{lua: lua} do + code = ~S""" + function f(x) + local a = {} + for i=1,3 do + a[i] = function () return x end + end + return a[1](), a[2](), a[3]() + end + return f(10) + """ + + assert {[10, 10, 10], _} = Lua.eval!(lua, code) + end + + test "closure in loop with local and param upvalues", %{lua: lua} do + # Step 1: Does having a local in loop body break things? + code1 = ~S""" + local function f(x) + local a = {} + for i=1,3 do + local y = 0 + a[i] = function () return y end + end + return a[1](), a[2]() + end + return f(10) + """ + + assert {[0, 0], _} = Lua.eval!(lua, code1) + end + + test "vararg expansion in local multi-assignment", %{lua: lua} do + code = ~S""" + function f(...) + local a, b, c = ... + return a, b, c + end + return f(10, 20, 30) + """ + + assert {[10, 20, 30], _} = Lua.eval!(lua, code) + end + + test "vararg expansion in regular multi-assignment", %{lua: lua} do + code = ~S""" + function f(a, ...) + local b, c, d = ... + return a, b, c, d + end + return f(5, 4, 3, 2, 1) + """ + + assert {[5, 4, 3, 2], _} = Lua.eval!(lua, code) + end + + test "vararg.lua new-style varargs", %{lua: lua} do + code = ~S""" + function oneless (a, ...) return ... end + + function f (n, a, ...) + local b + if n == 0 then + local b, c, d = ... + return a, b, c, d, oneless(oneless(oneless(...))) + else + n, b, a = n-1, ..., a + assert(b == ...) + return f(n, a, ...) + end + end + + a,b,c,d,e = assert(f(10,5,4,3,2,1)) + assert(a==5 and b==4 and c==3 and d==2 and e==1) + + a,b,c,d,e = f(4) + assert(a==nil and b==nil and c==nil and d==nil and e==nil) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "float literal edge cases", %{lua: lua} do + code = ~S""" + assert(.0 == 0) + assert(0. == 0) + assert(.2e2 == 20) + assert(2.E-1 == 0.2) + assert(0e12 == 0) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "priority: power is right-associative", %{lua: lua} do + # constructs.lua line 25 + code = ~S""" + assert(2^3^2 == 2^(3^2)) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "priority: power vs multiply", %{lua: lua} do + # constructs.lua line 26 + code = ~S""" + assert(2^3*4 == (2^3)*4) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "string concat with shift operator priority", %{lua: lua} do + # constructs.lua line 35 + code = ~S""" + assert("7" .. 3 << 1 == 146) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + + test "bitwise.lua early pattern - pcall catches bitwise error", %{lua: lua} do + # Test pcall catches bitwise error and the checkerror pattern works + code = ~S""" + local s, err = pcall(function() return 1 | nil end) + assert(not s) + assert(type(err) == "string") + + -- Test the checkerror pattern used by many suite tests + local function checkerror(msg, f, ...) + local s, err = pcall(f, ...) + assert(not s and string.find(err, msg)) + end + checkerror("nil", function() return 1 | nil end) + return true + """ + + assert {[true], _} = Lua.eval!(lua, code) + end + end + defp test_file(name) do Path.join(["test", "fixtures", name]) end