From 2d9c5468449ab2e565958b34335ff2084e0e6a0d Mon Sep 17 00:00:00 2001 From: Outcry <843648230@qq.com> Date: Wed, 25 Mar 2026 06:01:05 +0000 Subject: [PATCH] feat(other): support multi value for wasm --- .ci/run_test_suite.sh | 5 + .github/workflows/dtvm_wasm_test_x86.yml | 123 + CMakeLists.txt | 3 + src/CMakeLists.txt | 4 + src/action/bytecode_visitor.h | 227 +- src/action/function_loader.cpp | 26 +- src/action/function_loader.h | 8 + src/action/instantiator.cpp | 2 +- src/action/interpreter.cpp | 63 + src/action/interpreter.h | 31 +- src/action/loader_common.h | 23 + src/action/module_loader.cpp | 12 +- src/common/defines.h | 8 +- src/common/type.h | 3 + .../wasm_frontend/wasm_mir_compiler.cpp | 208 +- .../wasm_frontend/wasm_mir_compiler.h | 101 +- src/runtime/module.cpp | 8 +- src/runtime/module.h | 31 + src/runtime/runtime.cpp | 2 +- src/singlepass/common/codegen.h | 169 +- src/singlepass/common/datalayout.h | 2 +- src/singlepass/x64/codegen.h | 3343 +++++++++-------- src/singlepass/x64/operand.h | 1 + src/tests/CMakeLists.txt | 14 +- src/utils/wasm.cpp | 3 +- tests/wast/multi_value/basic_test.wast | 14 + tests/wast/multi_value/basic_test_main.wast | 19 + tests/wast/multi_value/multi_value.wast | 147 + 28 files changed, 3015 insertions(+), 1585 deletions(-) create mode 100644 tests/wast/multi_value/basic_test.wast create mode 100644 tests/wast/multi_value/basic_test_main.wast create mode 100644 tests/wast/multi_value/multi_value.wast diff --git a/.ci/run_test_suite.sh b/.ci/run_test_suite.sh index 668757cec..05f47dd4a 100644 --- a/.ci/run_test_suite.sh +++ b/.ci/run_test_suite.sh @@ -93,6 +93,11 @@ case $CPU_EXCEPTION_TYPE in ;; esac +# Multi-value support +if [ "${ENABLE_MULTI_VALUE:-false}" = true ]; then + CMAKE_OPTIONS="$CMAKE_OPTIONS -DZEN_ENABLE_WASI_MULTI_VALUE=ON" +fi + STACK_TYPES=("-DZEN_ENABLE_VIRTUAL_STACK=ON" "-DZEN_ENABLE_VIRTUAL_STACK=OFF") if [[ $RUN_MODE == "interpreter" ]]; then STACK_TYPES=("-DZEN_ENABLE_VIRTUAL_STACK=OFF") diff --git a/.github/workflows/dtvm_wasm_test_x86.yml b/.github/workflows/dtvm_wasm_test_x86.yml index 4fd9c4896..78a40f623 100644 --- a/.github/workflows/dtvm_wasm_test_x86.yml +++ b/.github/workflows/dtvm_wasm_test_x86.yml @@ -178,3 +178,126 @@ jobs: cmake --build build -j7 # use dtvm to test evm abi wasm files # ./build/dtvm -m 2 -f call counter.wasm + + build_test_multi_value_interp_on_x86: + name: Build and test DTVM multi-value (interpreter) on x86-64 + runs-on: ubuntu-latest + container: + image: dtvmdev1/dtvm-dev-x64:main + steps: + - name: Check out code + uses: actions/checkout@v3 + with: + submodules: "true" + - name: Code Format Check + run: | + ./tools/format.sh check + - name: Test Git clone + run: | + git clone https://github.com/asmjit/asmjit.git + - name: Install llvm + run: | + echo "current home is $HOME" + export CUR_PROJECT=$(pwd) + cd /opt + cd $CUR_PROJECT + export LLVM_SYS_150_PREFIX=/opt/llvm15 + export LLVM_DIR=$LLVM_SYS_150_PREFIX/lib/cmake/llvm + export PATH=$LLVM_SYS_150_PREFIX/bin:$PATH + cd tests/wast/spec + git apply ../spec.patch + cd $CUR_PROJECT + export CMAKE_BUILD_TARGET=Debug + export ENABLE_ASAN=true + export RUN_MODE=interpreter + export INPUT_FORMAT=wasm + export ENABLE_LAZY=true + export ENABLE_MULTITHREAD=true + export TestSuite=microsuite + export CPU_EXCEPTION_TYPE='check' + export ENABLE_GAS_METER=false + export ENABLE_MULTI_VALUE=true + + bash .ci/run_test_suite.sh + + build_test_multi_value_singlepass_on_x86: + name: Build and test DTVM multi-value (singlepass) on x86-64 + runs-on: ubuntu-latest + container: + image: dtvmdev1/dtvm-dev-x64:main + steps: + - name: Check out code + uses: actions/checkout@v3 + with: + submodules: "true" + - name: Code Format Check + run: | + ./tools/format.sh check + - name: Test Git clone + run: | + git clone https://github.com/asmjit/asmjit.git + - name: Install llvm + run: | + echo "current home is $HOME" + export CUR_PROJECT=$(pwd) + cd /opt + cd $CUR_PROJECT + export LLVM_SYS_150_PREFIX=/opt/llvm15 + export LLVM_DIR=$LLVM_SYS_150_PREFIX/lib/cmake/llvm + export PATH=$LLVM_SYS_150_PREFIX/bin:$PATH + cd tests/wast/spec + git apply ../spec.patch + cd $CUR_PROJECT + export CMAKE_BUILD_TARGET=Debug + export ENABLE_ASAN=true + export RUN_MODE=singlepass + export INPUT_FORMAT=wasm + export ENABLE_LAZY=true + export ENABLE_MULTITHREAD=true + export TestSuite=microsuite + export CPU_EXCEPTION_TYPE='check' + export ENABLE_GAS_METER=false + export ENABLE_MULTI_VALUE=true + + bash .ci/run_test_suite.sh + + build_test_multi_value_multipass_on_x86: + name: Build and test DTVM multi-value (multipass) on x86-64 + runs-on: ubuntu-latest + container: + image: dtvmdev1/dtvm-dev-x64:main + steps: + - name: Check out code + uses: actions/checkout@v3 + with: + submodules: "true" + - name: Code Format Check + run: | + ./tools/format.sh check + - name: Test Git clone + run: | + git clone https://github.com/asmjit/asmjit.git + - name: Install llvm + run: | + echo "current home is $HOME" + export CUR_PROJECT=$(pwd) + cd /opt + cd $CUR_PROJECT + export LLVM_SYS_150_PREFIX=/opt/llvm15 + export LLVM_DIR=$LLVM_SYS_150_PREFIX/lib/cmake/llvm + export PATH=$LLVM_SYS_150_PREFIX/bin:$PATH + cd tests/wast/spec + git apply ../spec.patch + cd $CUR_PROJECT + export CMAKE_BUILD_TARGET=Debug + export ENABLE_ASAN=true + export RUN_MODE=multipass + export INPUT_FORMAT=wasm + export ENABLE_LAZY=true + export ENABLE_MULTITHREAD=true + export TestSuite=microsuite + export CPU_EXCEPTION_TYPE='check' + export ENABLE_GAS_METER=false + export ENABLE_MULTI_VALUE=true + + bash .ci/run_test_suite.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 49275bea2..f65732e42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,9 @@ option(ZEN_ENABLE_DUMP_CALL_STACK "Enable exception call stack dump" OFF) option(ZEN_ENABLE_EVM_GAS_REGISTER "Enable gas register optimization for x86_64 multipass JIT" OFF ) +option(ZEN_ENABLE_WASI_MULTI_VALUE + "Enable WASI multi-value extension (multiple return values)" OFF +) # Blockchain options option(ZEN_ENABLE_CHECKED_ARITHMETIC "Enable checked arithmetic" OFF) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bbe49d442..b2cf11ccd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -41,6 +41,10 @@ if(ZEN_ENABLE_DWASM) add_definitions(-DZEN_ENABLE_DWASM) endif() +if(ZEN_ENABLE_WASI_MULTI_VALUE) + add_definitions(-DZEN_ENABLE_WASI_MULTI_VALUE) +endif() + if(ZEN_ENABLE_VIRTUAL_STACK) add_definitions(-DZEN_ENABLE_VIRTUAL_STACK) endif() diff --git a/src/action/bytecode_visitor.h b/src/action/bytecode_visitor.h index 5d22dd711..4d039fc7d 100644 --- a/src/action/bytecode_visitor.h +++ b/src/action/bytecode_visitor.h @@ -70,7 +70,9 @@ template class WASMByteCodeVisitor { private: void push(Operand Opnd) { - ZEN_ASSERT(!Opnd.isReg() || Opnd.isTempReg()); + // Note: We allow non-temp register operands (e.g., ABI return registers + // from calls) These are not managed by the temp register pool and won't be + // released ZEN_ASSERT(Opnd.getType() != WASMType::VOID); Stack.push(Opnd); } @@ -82,7 +84,10 @@ template class WASMByteCodeVisitor { return Opnd; } + Operand peek(uint8_t Depth = 0) { return Stack.peek(Depth); } + Operand getTop() { return Stack.getTop(); } + Operand getTop(uint32_t Depth) { return Stack.peek(Depth); } bool decode() { const uint8_t *Ip = CurFunc->CodePtr; @@ -108,19 +113,73 @@ template class WASMByteCodeVisitor { break; case Opcode::BLOCK: { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: block type can be a type index + WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip); + if (BlockType == WASMType::ERROR_TYPE) { + // Type index - read as signed LEB128 from current position + int32_t TypeIndex; + Ip = utils::readLEBNumber(Ip, IpEnd, TypeIndex); + if (TypeIndex >= 0 && CurMod->isValidType(TypeIndex)) { + const TypeEntry *Type = CurMod->getDeclaredType(TypeIndex); + // For multi-value blocks, we use VOID as marker and handle + // specially The actual type info is stored in the TypeEntry + handleBlockMultiValue(Type); + break; + } + } else { + Ip++; + } +#else WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip++); +#endif handleBlock(BlockType); break; } case Opcode::LOOP: { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: block type can be a type index + WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip); + if (BlockType == WASMType::ERROR_TYPE) { + // Type index - read as signed LEB128 from current position + int32_t TypeIndex; + Ip = utils::readLEBNumber(Ip, IpEnd, TypeIndex); + if (TypeIndex >= 0 && CurMod->isValidType(TypeIndex)) { + const TypeEntry *Type = CurMod->getDeclaredType(TypeIndex); + handleLoopMultiValue(Type); + break; + } + } else { + Ip++; + } +#else WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip++); +#endif handleLoop(BlockType); break; } case Opcode::IF: { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: block type can be a type index + WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip); + if (BlockType == WASMType::ERROR_TYPE) { + // Type index - read as signed LEB128 from current position + int32_t TypeIndex; + Ip = utils::readLEBNumber(Ip, IpEnd, TypeIndex); + if (TypeIndex >= 0 && CurMod->isValidType(TypeIndex)) { + const TypeEntry *Type = CurMod->getDeclaredType(TypeIndex); + Operand Cond = pop(); + handleIfMultiValue(Cond, Type); + break; + } + } else { + Ip++; + } +#else WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip++); +#endif handleIf(BlockType); break; } @@ -730,7 +789,9 @@ template class WASMByteCodeVisitor { // always emit return after function end, as branch instructions might // target a function's end and jump out +#ifndef ZEN_ENABLE_WASI_MULTI_VALUE handleReturn(); +#endif return true; } @@ -752,11 +813,35 @@ template class WASMByteCodeVisitor { Builder.handleIf(Cond, BlockType, Stack.getSize()); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: handle blocks with multiple return values + void handleBlockMultiValue(const TypeEntry *Type) { + Builder.handleBlockMultiValue(Type, Stack.getSize()); + } + + void handleLoopMultiValue(const TypeEntry *Type) { + Builder.handleLoopMultiValue(Type, Stack.getSize()); + } + + void handleIfMultiValue(Operand Cond, const TypeEntry *Type) { + Builder.handleIfMultiValue(Cond, Type, Stack.getSize()); + } +#endif + void handleElse() { const CtrlBlockInfo &Info = Builder.getCurrentBlockInfo(); ZEN_ASSERT(verifyCtrlInstValType(Info)); ZEN_ASSERT(Info.getKind() == CtrlBlockKind::IF); - if (Info.getType() != WASMType::VOID && Info.reachable()) { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + if (Info.isMultiValue() && Info.reachable()) { + // Handle multiple results for multi-value blocks + const auto &Results = Info.getResults(); + for (int I = Results.size() - 1; I >= 0; --I) { + Builder.makeAssignment(Results[I].getType(), Results[I], pop()); + } + } else +#endif + if (Info.getType() != WASMType::VOID && Info.reachable()) { // make an assignment to copy stack top to block info result Operand BlockResult = Info.getResult(); Builder.makeAssignment(Info.getType(), BlockResult, pop()); @@ -768,6 +853,45 @@ template class WASMByteCodeVisitor { const CtrlBlockInfo &Info = Builder.getCurrentBlockInfo(); ZEN_ASSERT(verifyCtrlInstValType(Info)); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // For function entry, handle the implicit return if reachable + // The return values remain on the stack for handleReturn to consume + if (Info.getKind() == CtrlBlockKind::FUNC_ENTRY) { + bool WasReachable = Info.reachable(); + Builder.handleEnd(Info); + // After handleEnd, the control stack is empty, so we check WasReachable + // For unreachable code, return is handled by finalizeFunctionBase + if (WasReachable) { + handleReturn(); + } + return; + } + if (Info.isMultiValue()) { + // Handle multiple results for multi-value blocks + // IMPORTANT: Copy results BEFORE handleEnd, because handleEnd pops the + // BlockInfo and the Results reference would become dangling + std::vector ResultsCopy = Info.getResults(); + if (Info.reachable()) { + ZEN_ASSERT(Stack.getSize() >= ResultsCopy.size()); + for (int I = ResultsCopy.size() - 1; I >= 0; --I) { + Builder.makeAssignment(ResultsCopy[I].getType(), ResultsCopy[I], + pop()); + } + } + // value stack may have excess elements after an unconditional branch; + // we need to pop them out before returing to the outer block + while (Stack.getSize() > Info.getStackSize()) { + Stack.pop(); + } + // NOTE: `info` is popped off its container after this call + Builder.handleEnd(Info); + // Push results back onto the stack + for (const auto &Result : ResultsCopy) { + push(Result); + } + return; + } +#endif Operand BlockResult = Info.getResult(); if (Info.getType() != WASMType::VOID && Info.reachable()) { // make an assignment to copy stack top to block info result @@ -792,7 +916,17 @@ template class WASMByteCodeVisitor { const CtrlBlockInfo &Info = Builder.getBlockInfo(Level); bool JumpBack = (Info.getKind() == CtrlBlockKind::LOOP); ZEN_ASSERT(verifyCtrlInstValType(Info, JumpBack)); - if (Info.getType() != WASMType::VOID && !JumpBack) { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + if (Info.isMultiValue() && !JumpBack) { + // Handle multiple results for multi-value blocks + const auto &Results = Info.getResults(); + for (int I = Results.size() - 1; I >= 0; --I) { + Builder.makeAssignment(Results[I].getType(), Results[I], + getTop(Results.size() - I - 1)); + } + } else +#endif + if (Info.getType() != WASMType::VOID && !JumpBack) { // make an assignment to copy stack top to block info result Operand BlockResult = Info.getResult(); Builder.makeAssignment(Info.getType(), BlockResult, getTop()); @@ -805,7 +939,17 @@ template class WASMByteCodeVisitor { const CtrlBlockInfo &Info = Builder.getBlockInfo(Level); bool JumpBack = (Info.getKind() == CtrlBlockKind::LOOP); ZEN_ASSERT(verifyCtrlInstValType(Info, JumpBack)); - if (Info.getType() != WASMType::VOID && !JumpBack) { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + if (Info.isMultiValue() && !JumpBack) { + // Handle multiple results for multi-value blocks + const auto &Results = Info.getResults(); + for (int I = Results.size() - 1; I >= 0; --I) { + Builder.makeAssignment(Results[I].getType(), Results[I], + getTop(Results.size() - I - 1)); + } + } else +#endif + if (Info.getType() != WASMType::VOID && !JumpBack) { // make an assignment to copy stack top to block info result Operand BlockResult = Info.getResult(); Builder.makeAssignment(Info.getType(), BlockResult, getTop()); @@ -843,12 +987,33 @@ template class WASMByteCodeVisitor { void handleReturn() { const TypeEntry &Type = Ctx->getWasmFuncType(); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // For multi-value, collect all return values + if (Type.NumReturns > 0 && Stack.getSize() > 0) { + if (Type.NumReturns == 1) { + Builder.handleReturn(pop()); + } else { + // Collect multiple return values (in reverse order from stack) + std::vector ReturnOps; + ReturnOps.reserve(Type.NumReturns); + for (uint32_t I = 0; I < Type.NumReturns && Stack.getSize() > 0; ++I) { + ReturnOps.push_back(pop()); + } + // Reverse to get correct order + std::reverse(ReturnOps.begin(), ReturnOps.end()); + Builder.handleReturnMultiValue(ReturnOps); + } + } else { + Builder.handleReturn(Operand()); + } +#else ZEN_ASSERT(Stack.getSize() >= Type.NumReturns); if (Type.NumReturns > 0 && Stack.getSize() > 0) { Builder.handleReturn(pop()); } else if (Type.NumReturns == 0) { Builder.handleReturn(Operand()); } +#endif } void handleCall(uint32_t FuncIdx, uint32_t CallOffset) { @@ -870,10 +1035,22 @@ template class WASMByteCodeVisitor { Args.resize(Type->NumParams); collectCallParams(Type, Args); - Operand Result = - Builder.handleCall(FuncIdx, Target, IsImport, FarCall, ArgInfo, Args); - if (Type->NumReturns > 0) { - push(Result); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + if (Type->NumReturns > 1) { + // Multi-value return + std::vector Results = Builder.handleCallMultiValue( + FuncIdx, Target, IsImport, FarCall, ArgInfo, Args); + for (const auto &Result : Results) { + push(Result); + } + } else +#endif + { + Operand Result = + Builder.handleCall(FuncIdx, Target, IsImport, FarCall, ArgInfo, Args); + if (Type->NumReturns > 0) { + push(Result); + } } } @@ -888,12 +1065,23 @@ template class WASMByteCodeVisitor { Args.resize(Type->NumParams); collectCallParams(Type, Args); TypeIdx = CurMod->getDeclaredType(TypeIdx)->SmallestTypeIdx; - Operand Result = Builder.handleCallIndirect(TypeIdx, IndirectFuncIdx, - TableIdx, ArgInfo, Args); - if (Type->NumReturns > 0) { - ZEN_ASSERT(Type->NumReturns == 1); - push(Result); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + if (Type->NumReturns > 1) { + // Multi-value return + std::vector Results = Builder.handleCallIndirectMultiValue( + TypeIdx, IndirectFuncIdx, TableIdx, ArgInfo, Args); + for (const auto &Result : Results) { + push(Result); + } + } else +#endif + { + Operand Result = Builder.handleCallIndirect(TypeIdx, IndirectFuncIdx, + TableIdx, ArgInfo, Args); + if (Type->NumReturns > 0) { + push(Result); + } } } @@ -1168,6 +1356,19 @@ template class WASMByteCodeVisitor { // value stack becomes unconstrained after an unconditional branch return true; } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + if (Info.isMultiValue()) { + // For multi-value blocks, check all result types + const auto &Results = Info.getResults(); + uint32_t NumResults = Results.size(); + ZEN_ASSERT(Info.getStackSize() + NumResults <= Stack.getSize()); + for (uint32_t I = 0; I < NumResults; ++I) { + ZEN_ASSERT(Results[I].getType() == + Stack.peek(NumResults - 1 - I).getType()); + } + return true; + } +#endif if (Info.getType() == WASMType::VOID || JumpBack) { // on an unconditional branch, value stack may have excess elements ZEN_ASSERT(Info.getStackSize() <= Stack.getSize()); diff --git a/src/action/function_loader.cpp b/src/action/function_loader.cpp index 3cb3c295b..37bbbbbf4 100644 --- a/src/action/function_loader.cpp +++ b/src/action/function_loader.cpp @@ -20,7 +20,7 @@ bool FunctionLoader::ControlBlockType::isBalanced() const { uint32_t NumParamTypes = Type->NumParams; uint32_t NumReturnTypes = Type->NumReturns; const WASMType *ParamTypes = Type->getParamTypes(); - const WASMType *ReturnTypes = Type->ReturnTypes; + const WASMType *ReturnTypes = Type->getReturnTypes(); return NumParamTypes == NumReturnTypes && std::memcmp(ParamTypes, ReturnTypes, NumParamTypes * sizeof(WASMType)) == 0; @@ -48,7 +48,7 @@ FunctionLoader::ControlBlockType::getReturnTypes() const { const TypeEntry *Type = std::get(TypeVariant); return { static_cast(Type->NumReturns), - Type->ReturnTypes, + Type->getReturnTypes(), }; } @@ -287,8 +287,24 @@ void FunctionLoader::load() { [[fallthrough]]; case BLOCK: case LOOP: { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + int32_t BlockTypeOrIndex = readBlockType(); + ControlBlockType BlockType; + if (BlockTypeOrIndex >= 0) { + // Simple type (VOID, I32, I64, F32, F64) + BlockType = static_cast(BlockTypeOrIndex); + } else { + // Type index (multi-value) + uint32_t TypeIndex = static_cast(-BlockTypeOrIndex - 1); + if (!Mod.isValidType(TypeIndex)) { + throw getError(ErrorCode::UnknownTypeIdx); + } + BlockType = Mod.getDeclaredType(TypeIndex); + } +#else WASMType Type = readBlockType(); ControlBlockType BlockType = Type; +#endif auto BlockLabelTy = static_cast(LABEL_BLOCK + Opcode - BLOCK); pushBlock(BlockLabelTy, BlockType, Ptr); @@ -782,7 +798,7 @@ void FunctionLoader::load() { case RETURN: { int32_t NumReturns = static_cast(FuncTypeEntry.NumReturns); for (int32_t I = NumReturns - 1; I >= 0; --I) { - popValueType(FuncTypeEntry.ReturnTypes[I]); + popValueType(FuncTypeEntry.getReturnTypes()[I]); } resetStack(); setStackPolymorphic(true); @@ -801,7 +817,7 @@ void FunctionLoader::load() { popValueType(ParamTypes[I - 1]); } for (uint32_t I = 0; I < CalleeFuncType->NumReturns; ++I) { - pushValueType(CalleeFuncType->ReturnTypes[I]); + pushValueType(CalleeFuncType->getReturnTypes()[I]); } #ifdef ZEN_ENABLE_MULTIPASS_JIT if (!CalleeIdxBitset[CalleeIdx]) { @@ -836,7 +852,7 @@ void FunctionLoader::load() { } for (uint32_t I = 0; I < CalleeFuncType->NumReturns; ++I) { - pushValueType(CalleeFuncType->ReturnTypes[I]); + pushValueType(CalleeFuncType->getReturnTypes()[I]); } #ifdef ZEN_ENABLE_MULTIPASS_JIT const auto &LikelyCalleeIdxs = Mod.TypedFuncRefs[TypeIdx]; diff --git a/src/action/function_loader.h b/src/action/function_loader.h index f828d11a5..210cf4f48 100644 --- a/src/action/function_loader.h +++ b/src/action/function_loader.h @@ -26,6 +26,14 @@ class FunctionLoader final : public LoaderCommon { return *this; } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Check if this block type is a type index (multi-value) + bool isTypeIndex() const { + return TypeVariant.index() == 1 && + std::holds_alternative(TypeVariant); + } +#endif + // Means the block has the same popped types and pushed types bool isBalanced() const; diff --git a/src/action/instantiator.cpp b/src/action/instantiator.cpp index 6d94f0b66..4e531ee35 100644 --- a/src/action/instantiator.cpp +++ b/src/action/instantiator.cpp @@ -106,7 +106,7 @@ void Instantiator::instantiateFunctions(Instance &Inst) { FuncInst.NumParamCells = Type.NumParamCells; FuncInst.NumReturns = Type.NumReturns; FuncInst.NumReturnCells = Type.NumReturnCells; - std::memcpy(FuncInst.ReturnTypes, Type.ReturnTypes, + std::memcpy(FuncInst.ReturnTypes, Type.getReturnTypes(), sizeof(FuncInst.ReturnTypes)); FuncInst.ParamTypes = Type.ParamTypes; FuncInst.FuncType = &Type; diff --git a/src/action/interpreter.cpp b/src/action/interpreter.cpp index dbfa494b4..6a5b3287e 100644 --- a/src/action/interpreter.cpp +++ b/src/action/interpreter.cpp @@ -1171,6 +1171,23 @@ void BaseInterpreterImpl::interpret() { BREAK; } CASE(BLOCK) : { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: check if block type is a type index + WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip); + if (BlockType == WASMType::ERROR_TYPE) { + // Type index - read as signed LEB128 from current position + int32_t TypeIndex; + Ip = utils::readLEBNumber(Ip, IpEnd, TypeIndex); + if (TypeIndex >= 0 && Mod->isValidType(TypeIndex)) { + const TypeEntry *Type = Mod->getDeclaredType(TypeIndex); + findBlockAddr(Ip, IpEnd, ElseAddr, EndAddr); + Frame->blockPush(ControlStackPtr, EndAddr, ValStackPtr, + Type->NumReturnCells, LABEL_BLOCK, Type, + Type->NumReturns, Type->NumReturnCells); + BREAK; + } + } +#endif uint32_t CellNum = getWASMTypeCellNumFromOpcode(*Ip++); findBlockAddr(Ip, IpEnd, ElseAddr, EndAddr); @@ -1179,6 +1196,23 @@ void BaseInterpreterImpl::interpret() { BREAK; } CASE(LOOP) : { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: check if block type is a type index + WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip); + if (BlockType == WASMType::ERROR_TYPE) { + // Type index - read as signed LEB128 from current position + int32_t TypeIndex; + Ip = utils::readLEBNumber(Ip, IpEnd, TypeIndex); + if (TypeIndex >= 0 && Mod->isValidType(TypeIndex)) { + const TypeEntry *Type = Mod->getDeclaredType(TypeIndex); + // For loops, the result type is the input parameters + Frame->blockPush(ControlStackPtr, Ip, ValStackPtr, + Type->NumParamCells, LABEL_LOOP, Type, + Type->NumParams, Type->NumParamCells); + BREAK; + } + } +#endif uint32_t CellNum = getWASMTypeCellNumFromOpcode(*Ip++); Frame->blockPush(ControlStackPtr, Ip, ValStackPtr, CellNum, LABEL_LOOP); BREAK; @@ -1217,6 +1251,35 @@ void BaseInterpreterImpl::interpret() { BREAK; } CASE(IF) : { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: check if block type is a type index + WASMType BlockType = getWASMBlockTypeFromOpcode(*Ip); + if (BlockType == WASMType::ERROR_TYPE) { + // Type index - read as signed LEB128 from current position + int32_t TypeIndex; + Ip = utils::readLEBNumber(Ip, IpEnd, TypeIndex); + if (TypeIndex >= 0 && Mod->isValidType(TypeIndex)) { + const TypeEntry *Type = Mod->getDeclaredType(TypeIndex); + Cond = Frame->valuePop(ValStackPtr); + findBlockAddr(Ip, IpEnd, ElseAddr, EndAddr); + if (Cond) { + Frame->blockPush(ControlStackPtr, EndAddr, ValStackPtr, + Type->NumReturnCells, LABEL_IF, Type, + Type->NumReturns, Type->NumReturnCells); + } else { + if (ElseAddr == nullptr) { + Ip = EndAddr + 1; + } else { + Frame->blockPush(ControlStackPtr, EndAddr, ValStackPtr, + Type->NumReturnCells, LABEL_IF, Type, + Type->NumReturns, Type->NumReturnCells); + Ip = ElseAddr + 1; + } + } + BREAK; + } + } +#endif uint32_t CellNum = getWASMTypeCellNumFromOpcode(*Ip++); Cond = Frame->valuePop(ValStackPtr); diff --git a/src/action/interpreter.h b/src/action/interpreter.h index 17d3da218..063153975 100644 --- a/src/action/interpreter.h +++ b/src/action/interpreter.h @@ -15,6 +15,9 @@ namespace runtime { struct FunctionInstance; class Instance; class Runtime; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +struct TypeEntry; +#endif } // namespace runtime namespace action { @@ -24,6 +27,12 @@ struct BlockInfo { uint32_t *ValueStackPtr; uint32_t CellNum; common::LabelType LabelType; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: store full type information + const runtime::TypeEntry *BlockType; // Full type for multi-value blocks + uint32_t NumResults; // Number of result values + uint32_t TotalResultCells; // Total cells for all results +#endif }; struct InterpFrame { @@ -89,12 +98,23 @@ struct InterpFrame { void blockPush(BlockInfo *&ControlStackPtr, const uint8_t *TargetAddr, uint32_t *ValStackPtr, uint32_t CellNum, - common::LabelType LabelType) { + common::LabelType LabelType +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + , + const runtime::TypeEntry *BlockType = nullptr, + uint32_t NumResults = 0, uint32_t TotalResultCells = 0 +#endif + ) { ZEN_ASSERT(ControlStackPtr <= CtrlBoundary); ControlStackPtr->TargetAddr = TargetAddr; ControlStackPtr->ValueStackPtr = ValStackPtr; ControlStackPtr->CellNum = CellNum; ControlStackPtr->LabelType = LabelType; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + ControlStackPtr->BlockType = BlockType; + ControlStackPtr->NumResults = NumResults; + ControlStackPtr->TotalResultCells = TotalResultCells; +#endif ControlStackPtr++; } @@ -114,10 +134,19 @@ struct InterpFrame { Ip = CurBlock->TargetAddr; if (CurBlock->LabelType != common::LABEL_LOOP) { +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: copy all result values + uint32_t CellsToCopy = CurBlock->TotalResultCells > 0 + ? CurBlock->TotalResultCells + : CurBlock->CellNum; + std::memcpy(ValStackPtr, ValStackPtrOld - CellsToCopy, CellsToCopy << 2); + ValStackPtr += CellsToCopy; +#else uint32_t CellNum = (ControlStackPtr - 1)->CellNum; std::memcpy(ValStackPtr, ValStackPtrOld - CellNum, CellNum << 2); ValStackPtr += CellNum; +#endif } } }; diff --git a/src/action/loader_common.h b/src/action/loader_common.h index c0565c764..23d26ba3f 100644 --- a/src/action/loader_common.h +++ b/src/action/loader_common.h @@ -98,9 +98,32 @@ class LoaderCommon { return readTypeBase(common::getWASMValTypeFromOpcode); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: block type can be a type index + // Returns: positive value for WASMType, negative value for type index + // The caller should check if result < 0 to determine if it's a type index + int32_t readBlockType() { + uint8_t TypeOpcode = common::to_underlying(readByte()); + WASMType Type = common::getWASMBlockTypeFromOpcode(TypeOpcode); + if (Type != WASMType::ERROR_TYPE) { + return static_cast(Type); + } + // Type index (signed LEB128) + // Rewind and read as signed LEB128 + Ptr--; + int32_t TypeIndex = readLEB(); + if (TypeIndex < 0) { + // Valid type index (must be non-negative) + throw getError(ErrorCode::InvalidType); + } + // Return negative value to indicate type index + return -TypeIndex - 1; // -1 means type index 0, -2 means type index 1, etc. + } +#else WASMType readBlockType() { return readTypeBase(common::getWASMBlockTypeFromOpcode); } +#endif WASMType readRefType() { return readTypeBase(common::getWASMRefTypeFromOpcode); diff --git a/src/action/module_loader.cpp b/src/action/module_loader.cpp index 79331d561..fac1d2021 100644 --- a/src/action/module_loader.cpp +++ b/src/action/module_loader.cpp @@ -251,7 +251,7 @@ ModuleLoader::resolveImportFunction(WASMSymbol ModuleName, WASMSymbol FieldName, } for (uint32_t I = 0; I < ExpectedNumReturns; ++I) { - WASMType ExpectedType = ExpectedFuncType.ReturnTypes[I]; + WASMType ExpectedType = ExpectedFuncType.getReturnTypes()[I]; WASMType ActualType = ActualFuncType[I + ActualNumParams]; if (ExpectedType != ActualType) { std::string DetailErrMsg = "return type mismatch (expected "; @@ -434,7 +434,17 @@ void ModuleLoader::loadTypeSection() { } uint32_t NumReturns = readU32(); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + WASMType *ReturnTypes = nullptr; + if (NumReturns > 2) { + // Allocate dynamic array for more than 2 return types + ReturnTypes = Entry->ReturnTypesPtr = Mod.initParamTypes(NumReturns); + } else { + ReturnTypes = Entry->ReturnTypesVec; + } +#else WASMType *ReturnTypes = Entry->ReturnTypes; +#endif if (NumReturns > PresetMaxNumReturns) { throw getError(ErrorCode::TooManyReturns); } diff --git a/src/common/defines.h b/src/common/defines.h index 13da3dc8a..5be94b8e5 100644 --- a/src/common/defines.h +++ b/src/common/defines.h @@ -117,8 +117,14 @@ constexpr size_t PresetMaxSectionSize = 512 * 1024 * 1024; // 512MB constexpr size_t PresetMaxNameLength = UINT16_MAX; constexpr size_t PresetMaxNumParams = UINT16_MAX; // uint16_t constexpr size_t PresetMaxNumParamCells = UINT16_MAX; // uint16_t +// Multi-value support: allow multiple return values (WASI extension) +// MVP limit: PresetMaxNumReturns = 1 +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +constexpr size_t PresetMaxNumReturns = UINT8_MAX; // uint8_t: 8 bits +#else // At most one return value in MVP -constexpr size_t PresetMaxNumReturns = 1; // uint8_t: 2 bits +constexpr size_t PresetMaxNumReturns = 1; // uint8_t: 2 bits +#endif constexpr size_t PresetMaxNumReturnCells = (1u << 6) - 1; // uint8_t: 6 bits constexpr size_t PresetMaxMemoryPages = 1u << 16; // 65536 pages diff --git a/src/common/type.h b/src/common/type.h index 16a6c2351..6926ec026 100644 --- a/src/common/type.h +++ b/src/common/type.h @@ -226,6 +226,9 @@ template static inline constexpr uint32_t getWASMTypeSize() { static inline uint32_t getWASMTypeSize(WASMType Type) { switch (Type) { + case WASMType::VOID: + case WASMType::ERROR_TYPE: + return 0; case WASMType::I8: return 1; case WASMType::I16: diff --git a/src/compiler/wasm_frontend/wasm_mir_compiler.cpp b/src/compiler/wasm_frontend/wasm_mir_compiler.cpp index d0b2c86a5..47ba0118f 100644 --- a/src/compiler/wasm_frontend/wasm_mir_compiler.cpp +++ b/src/compiler/wasm_frontend/wasm_mir_compiler.cpp @@ -77,7 +77,8 @@ void buildAllMIRFuncTypes(WasmFrontendContext &Context, MModule &MMod, MParamTypes[J + 1] = Context.getMIRTypeFromWASMType(ParamTypes[J]); } MType *MRetType = Context.getMIRTypeFromWASMType( - FuncType->NumReturns > 0 ? FuncType->ReturnTypes[0] : WASMType::VOID); + FuncType->NumReturns > 0 ? FuncType->getReturnTypes()[0] + : WASMType::VOID); MMod.addFuncType(MFunctionType::create(Context, *MRetType, MParamTypes)); } } @@ -126,7 +127,22 @@ void FunctionMirBuilder::initFunction( } MBasicBlock *ReturnBB = createBasicBlock(); - enterBlock(CtrlBlockKind::FUNC_ENTRY, RetType, 0, ReturnBB); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Handle multi-value function returns + if (Type.NumReturns > 1) { + std::vector Results; + Results.reserve(Type.NumReturns); + const WASMType *RetTypes = Type.getReturnTypes(); + for (uint32_t I = 0; I < Type.NumReturns; ++I) { + Results.push_back(createTempStackOperand(RetTypes[I])); + } + ControlStack.emplace_back(CtrlBlockKind::FUNC_ENTRY, std::move(Results), + &Type, 0, ReturnBB, nullptr, nullptr); + } else +#endif + { + enterBlock(CtrlBlockKind::FUNC_ENTRY, RetType, 0, ReturnBB); + } loadWASMInstanceAttr(); } @@ -256,8 +272,12 @@ void FunctionMirBuilder::finalizeFunctionBase() { auto ReturnZero = [&]() { Operand Ret; - WASMType WType = - static_cast(Ctx.getWasmFuncType().ReturnTypes[0]); + const auto &Type = Ctx.getWasmFuncType(); + if (Type.NumReturns == 0) { + handleReturn(Ret); + return; + } + WASMType WType = Type.getReturnType(); switch (WType) { case WASMType::I32: Ret = handleConst(0); @@ -266,12 +286,10 @@ void FunctionMirBuilder::finalizeFunctionBase() { Ret = handleConst(0); break; case WASMType::F32: - Ret = handleConst(0); + Ret = handleConst(0.0f); break; case WASMType::F64: - Ret = handleConst(0); - break; - case WASMType::VOID: + Ret = handleConst(0.0); break; default: ZEN_ABORT(); @@ -405,6 +423,67 @@ void FunctionMirBuilder::handleIf(Operand CondOp, WASMType Type, setInsertBlock(ThenBlock); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +void FunctionMirBuilder::handleBlockMultiValue(const TypeEntry *Type, + uint32_t StackSize) { + MBasicBlock *EndBlock = createBasicBlock(); + std::vector Results; + if (Type->NumReturns > 0) { + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(createTempStackOperand(RetTypes[I])); + } + } + ControlStack.emplace_back(CtrlBlockKind::BLOCK, std::move(Results), Type, + StackSize, EndBlock, nullptr, nullptr); +} + +void FunctionMirBuilder::handleLoopMultiValue(const TypeEntry *Type, + uint32_t StackSize) { + MBasicBlock *LoopBlock = createBasicBlock(); + MBasicBlock *EndBlock = createBasicBlock(); + createInstruction(true, Ctx, LoopBlock); + addSuccessor(LoopBlock); + + std::vector Results; + if (Type->NumReturns > 0) { + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(createTempStackOperand(RetTypes[I])); + } + } + ControlStack.emplace_back(CtrlBlockKind::LOOP, std::move(Results), Type, + StackSize, LoopBlock, EndBlock, nullptr); + setInsertBlock(LoopBlock); +} + +void FunctionMirBuilder::handleIfMultiValue(Operand CondOp, + const TypeEntry *Type, + uint32_t StackSize) { + MInstruction *Condition = extractOperand(CondOp); + MBasicBlock *ThenBlock = createBasicBlock(); + MBasicBlock *EndBlock = createBasicBlock(); + auto BranchInst = createInstruction(true, Ctx, Condition, + ThenBlock, EndBlock); + addSuccessor(ThenBlock); + addSuccessor(EndBlock); + + std::vector Results; + if (Type->NumReturns > 0) { + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(createTempStackOperand(RetTypes[I])); + } + } + ControlStack.emplace_back(CtrlBlockKind::IF, std::move(Results), Type, + StackSize, EndBlock, nullptr, BranchInst); + setInsertBlock(ThenBlock); +} +#endif + void FunctionMirBuilder::handleElse(const BlockInfo &Info) { MBasicBlock *EndBlock = Info.getJumpBlock(); if (Info.reachable()) { @@ -535,6 +614,35 @@ void FunctionMirBuilder::handleReturn(Operand Opnd) { createInstruction(true, Type, Ret); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +void FunctionMirBuilder::handleReturnMultiValue( + const std::vector &Opnds) { +#ifdef ZEN_ENABLE_DWASM + const auto &Layout = Ctx.getWasmMod().getLayout(); + MInstruction *StackCost = + getInstanceElement(&Ctx.I32Type, Layout.StackCostOffset); + MInstruction *CurFuncStackCost = createIntConstInstruction( + &Ctx.I32Type, Ctx.getWasmFuncCode().JITStackCost); + MInstruction *NewStackCost = createInstruction( + false, OP_sub, &Ctx.I32Type, StackCost, CurFuncStackCost); + setInstanceElement(&Ctx.I32Type, NewStackCost, Layout.StackCostOffset); +#endif + + // For multi-value returns, we need to handle multiple operands + // Currently we just use the first operand for the return instruction + // Full multi-value return requires extending ReturnInstruction + if (Opnds.empty()) { + createInstruction(true, &Ctx.VoidType, nullptr); + } else { + // Use first operand for now + // TODO: Extend ReturnInstruction to support multiple operands + MInstruction *Ret = extractOperand(Opnds[0]); + MType *Type = Ret ? Ret->getType() : &Ctx.VoidType; + createInstruction(true, Type, Ret); + } +} +#endif + FunctionMirBuilder::Operand FunctionMirBuilder::handleCall( uint32_t FuncIdx, uintptr_t Target, bool IsImport, bool FarCall, const ArgumentInfo &ArgInfo, const std::vector &Args) { @@ -641,6 +749,90 @@ FunctionMirBuilder::Operand FunctionMirBuilder::handleCallIndirect( return handleCallBase(FuncAddr, ArgInfo, Args, true); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +// Multi-value call implementation +std::vector +FunctionMirBuilder::handleCallMultiValue(uint32_t FuncIdx, uintptr_t Target, + bool IsImport, bool FarCall, + const ArgumentInfo &ArgInfo, + const std::vector &Args) { + // First, get the single result from the actual call + // The calling convention only supports single return value + Operand SingleResult = + handleCall(FuncIdx, Target, IsImport, FarCall, ArgInfo, Args); + + // For multi-value returns, we need to create temp operands for each return + // The first return value comes from the actual call result + // Additional return values would require extending the calling convention + std::vector Results; + uint32_t NumReturns = ArgInfo.getNumReturns(); + if (NumReturns > 0) { + Results.reserve(NumReturns); + // First result is from the actual call + Results.push_back(SingleResult); + // For additional return values, create properly defined operands + // Note: This is a limitation - the calling convention only supports + // single return value, so additional values are initialized to 0 + // Full multi-value support requires extending the calling convention + const WASMType *RetTypes = ArgInfo.getReturnTypes(); + for (uint32_t I = 1; I < NumReturns; ++I) { + MType *Mtype = Ctx.getMIRTypeFromWASMType(RetTypes[I]); + MInstruction *ConstVal; + if (RetTypes[I] == WASMType::F32) { + ConstVal = createInstruction( + false, Mtype, *MConstantFloat::get(Ctx, *Mtype, 0.0f)); + } else if (RetTypes[I] == WASMType::F64) { + ConstVal = createInstruction( + false, Mtype, *MConstantFloat::get(Ctx, *Mtype, 0.0)); + } else { + ConstVal = createIntConstInstruction(Mtype, 0); + } + MInstruction *DefinedVal = makeReusableValue(ConstVal, Mtype); + Results.push_back(Operand(DefinedVal, RetTypes[I])); + } + } + return Results; +} + +std::vector +FunctionMirBuilder::handleCallIndirectMultiValue( + uint32_t TypeIdx, Operand IndirectFuncIdx, uint32_t TblIdx, + const ArgumentInfo &ArgInfo, const std::vector &Args) { + // First, get the single result from the actual call + Operand SingleResult = + handleCallIndirect(TypeIdx, IndirectFuncIdx, TblIdx, ArgInfo, Args); + + // For multi-value returns, create properly defined operands + std::vector Results; + uint32_t NumReturns = ArgInfo.getNumReturns(); + if (NumReturns > 0) { + Results.reserve(NumReturns); + // First result is from the actual call + Results.push_back(SingleResult); + // For additional return values, create properly defined operands + // Note: This is a limitation - the calling convention only supports + // single return value, so additional values are initialized to 0 + const WASMType *RetTypes = ArgInfo.getReturnTypes(); + for (uint32_t I = 1; I < NumReturns; ++I) { + MType *Mtype = Ctx.getMIRTypeFromWASMType(RetTypes[I]); + MInstruction *ConstVal; + if (RetTypes[I] == WASMType::F32) { + ConstVal = createInstruction( + false, Mtype, *MConstantFloat::get(Ctx, *Mtype, 0.0f)); + } else if (RetTypes[I] == WASMType::F64) { + ConstVal = createInstruction( + false, Mtype, *MConstantFloat::get(Ctx, *Mtype, 0.0)); + } else { + ConstVal = createIntConstInstruction(Mtype, 0); + } + MInstruction *DefinedVal = makeReusableValue(ConstVal, Mtype); + Results.push_back(Operand(DefinedVal, RetTypes[I])); + } + } + return Results; +} +#endif + void FunctionMirBuilder::checkCallException(bool IsImportOrIndirect) { #ifdef ZEN_ENABLE_CPU_EXCEPTION if (IsImportOrIndirect) { diff --git a/src/compiler/wasm_frontend/wasm_mir_compiler.h b/src/compiler/wasm_frontend/wasm_mir_compiler.h index 866764d09..2b4202513 100644 --- a/src/compiler/wasm_frontend/wasm_mir_compiler.h +++ b/src/compiler/wasm_frontend/wasm_mir_compiler.h @@ -85,8 +85,11 @@ class FunctionMirBuilder final { bool isEmpty() const { return !Instr && !Var && Type == WASMType::VOID; } /* Do nothing, only used to match WASMByteVisitor */ - constexpr bool isReg() { return false; } - constexpr bool isTempReg() { return true; } + constexpr bool isReg() const { return false; } + constexpr bool isMem() const { return false; } + constexpr bool isTempReg() const { return true; } + constexpr int getKind() const { return 0; } + constexpr uint8_t getRawOpKind() const { return 0; } private: MInstruction *Instr = nullptr; @@ -103,7 +106,26 @@ class FunctionMirBuilder final { MBasicBlock *JumpBlock, MBasicBlock *NextBlock, BrIfInstruction *BranchInst) : Kind(Kind), Result(Result), StackSize(StackSize), - JumpBlock(JumpBlock), NextBlock(NextBlock), BranchInstr(BranchInst) {} + JumpBlock(JumpBlock), NextBlock(NextBlock), BranchInstr(BranchInst) +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + , + IsMultiValue(false) +#endif + { + } + +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Constructor for multi-value blocks + BlockInfo(CtrlBlockKind Kind, std::vector Results, + const TypeEntry *BlockType, uint32_t StackSize, + MBasicBlock *JumpBlock, MBasicBlock *NextBlock, + BrIfInstruction *BranchInst) + : Kind(Kind), Results(std::move(Results)), BlockType(BlockType), + StackSize(StackSize), JumpBlock(JumpBlock), NextBlock(NextBlock), + BranchInstr(BranchInst), IsMultiValue(true) { + Result = this->Results.empty() ? Operand() : this->Results[0]; + } +#endif CtrlBlockKind getKind() const { return Kind; } @@ -111,6 +133,16 @@ class FunctionMirBuilder final { WASMType getType() const { return Result.getType(); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + bool isMultiValue() const { return IsMultiValue; } + uint32_t getNumResults() const { + return IsMultiValue ? Results.size() + : (Result.getType() == WASMType::VOID ? 0 : 1); + } + const std::vector &getResults() const { return Results; } + const TypeEntry *getBlockType() const { return BlockType; } +#endif + uint32_t getStackSize() const { return StackSize; } void setReachable(bool V) { Reachable = V; } @@ -129,18 +161,38 @@ class FunctionMirBuilder final { private: CtrlBlockKind Kind; Operand Result; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + std::vector Results; + const TypeEntry *BlockType = nullptr; +#endif uint32_t StackSize; MBasicBlock *JumpBlock = nullptr; MBasicBlock *NextBlock = nullptr; BrIfInstruction *BranchInstr = nullptr; bool Reachable = true; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + bool IsMultiValue = false; +#endif }; class ArgumentInfo { public: ArgumentInfo(const TypeEntry *Type) { ZEN_ASSERT(Type); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: store all return types + NumReturns = Type->NumReturns; + if (NumReturns > 0) { + ReturnTypes.resize(NumReturns); + std::memcpy(ReturnTypes.data(), Type->getReturnTypes(), + NumReturns * sizeof(WASMType)); + RetType = ReturnTypes[0]; // Primary return type for backward compat + } else { + RetType = WASMType::VOID; + } +#else RetType = Type->getReturnType(); +#endif uint32_t NumParams = Type->NumParams; // Reserve 1 slot for instance ArgTypes.resize(NumParams + 1); @@ -151,9 +203,20 @@ class FunctionMirBuilder final { WASMType getReturnType() const { return RetType; } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + uint32_t getNumReturns() const { return NumReturns; } + const WASMType *getReturnTypes() const { + return NumReturns > 0 ? ReturnTypes.data() : nullptr; + } +#endif + private: std::vector ArgTypes; WASMType RetType; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + uint32_t NumReturns = 0; + std::vector ReturnTypes; +#endif }; bool compile(CompilerContext *Context); @@ -180,10 +243,23 @@ class FunctionMirBuilder final { void handleBlock(WASMType Type, uint32_t Estack); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleBlockMultiValue(const TypeEntry *Type, uint32_t Estack); +#endif + void handleLoop(WASMType Type, uint32_t Estack); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleLoopMultiValue(const TypeEntry *Type, uint32_t Estack); +#endif + void handleIf(Operand CondOp, WASMType Type, uint32_t Estack); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleIfMultiValue(Operand CondOp, const TypeEntry *Type, + uint32_t Estack); +#endif + void handleElse(const BlockInfo &Info); void handleEnd(const BlockInfo &Info); @@ -197,13 +273,32 @@ class FunctionMirBuilder final { void handleReturn(Operand Opnd); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleReturnMultiValue(const std::vector &Opnds); +#endif + Operand handleCall(uint32_t FuncIdx, uintptr_t TarGet, bool IsImport, bool FarCall, const ArgumentInfo &ArgInfo, const std::vector &Args); + +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + std::vector handleCallMultiValue(uint32_t FuncIdx, uintptr_t TarGet, + bool IsImport, bool FarCall, + const ArgumentInfo &ArgInfo, + const std::vector &Args); +#endif + Operand handleCallIndirect(uint32_t TypeIdx, Operand IndirectFuncIdx, uint32_t TblIdx, const ArgumentInfo &ArgInfo, const std::vector &Args); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + std::vector + handleCallIndirectMultiValue(uint32_t TypeIdx, Operand IndirectFuncIdx, + uint32_t TblIdx, const ArgumentInfo &ArgInfo, + const std::vector &Args); +#endif + // ==================== Parametric Instruction Handlers ==================== Operand handleSelect(Operand CondOp, Operand LHSOp, Operand RHSOp); diff --git a/src/runtime/module.cpp b/src/runtime/module.cpp index b5a5955aa..75568d751 100644 --- a/src/runtime/module.cpp +++ b/src/runtime/module.cpp @@ -112,7 +112,7 @@ bool TypeEntry::isEqual(TypeEntry *Type1, TypeEntry *Type2) { } if (std::memcmp(Type1->getParamTypes(), Type2->getParamTypes(), sizeof(WASMType) * Type1->NumParams) || - std::memcmp(Type1->ReturnTypes, Type2->ReturnTypes, + std::memcmp(Type1->getReturnTypes(), Type2->getReturnTypes(), sizeof(WASMType) * Type1->NumReturns)) { return false; } @@ -291,6 +291,12 @@ void Module::destroyTypeTable() { if (TypeTable[I].NumParams > (__WORDSIZE / 8) && TypeTable[I].ParamTypes) { deallocate(TypeTable[I].ParamTypes); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Free dynamically allocated return types array for multi-value + if (TypeTable[I].NumReturns > 2 && TypeTable[I].ReturnTypesPtr) { + deallocate(TypeTable[I].ReturnTypesPtr); + } +#endif } deallocate(TypeTable); } diff --git a/src/runtime/module.h b/src/runtime/module.h index fead2c0aa..126a7ed00 100644 --- a/src/runtime/module.h +++ b/src/runtime/module.h @@ -126,9 +126,20 @@ class HostModule final : public BaseModule { struct TypeEntry final { uint16_t NumParams; uint16_t NumParamCells; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support: allow more return values + uint8_t NumReturns; + uint8_t NumReturnCells; + // Dynamic array for return types when multi-value is enabled + union { + WASMType *ReturnTypesPtr; + WASMType ReturnTypesVec[2]; // Inline storage for small number of returns + }; +#else uint8_t NumReturns : 2; uint8_t NumReturnCells : 6; WASMType ReturnTypes[2]; +#endif union { WASMType *ParamTypes; WASMType ParamTypesVec[__WORDSIZE / 8]; @@ -142,9 +153,29 @@ struct TypeEntry final { return ParamTypesVec; } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + const WASMType *getReturnTypes() const { + if (NumReturns <= 2) { + return ReturnTypesVec; + } + return ReturnTypesPtr; + } +#else + const WASMType *getReturnTypes() const { return ReturnTypes; } +#endif + WASMType getReturnType() const { + // When multi-value is enabled, functions may have multiple return values. + // This method returns the first return type (or VOID if none). + // The assertion only applies when multi-value is disabled. +#ifndef ZEN_ENABLE_WASI_MULTI_VALUE ZEN_ASSERT(NumReturns <= 1); +#endif +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + return NumReturns > 0 ? getReturnTypes()[0] : WASMType::VOID; +#else return NumReturns > 0 ? ReturnTypes[0] : WASMType::VOID; +#endif } static bool isEqual(TypeEntry *Type1, TypeEntry *Type2); diff --git a/src/runtime/runtime.cpp b/src/runtime/runtime.cpp index 0f2a116cb..d70c39e1f 100644 --- a/src/runtime/runtime.cpp +++ b/src/runtime/runtime.cpp @@ -373,7 +373,7 @@ static bool checkMainFuncType(TypeEntry *Type) { !(ParamTypes[0] == WASMType::I32 && ParamTypes[1] == WASMType::I32)) { return false; } - if (Type->NumReturns && Type->ReturnTypes[0] != WASMType::I32) { + if (Type->NumReturns && Type->getReturnTypes()[0] != WASMType::I32) { return false; } return true; diff --git a/src/singlepass/common/codegen.h b/src/singlepass/common/codegen.h index 24a27a906..a93c0ce7e 100644 --- a/src/singlepass/common/codegen.h +++ b/src/singlepass/common/codegen.h @@ -46,7 +46,20 @@ class ArgumentInfo { ArgumentInfo(TypeEntry *Type) { ZEN_ASSERT(Type); + // When multi-value is enabled, functions may have multiple return values. + // ArgumentInfo currently only supports single return value - the first + // return type is used. Full multi-value support requires additional work. +#ifndef ZEN_ENABLE_WASI_MULTI_VALUE ZEN_ASSERT(Type->NumReturns <= 1); +#else + // Store multi-value return info + NumRet = Type->NumReturns; + if (NumRet > 0) { + RetTypes.resize(NumRet); + std::memcpy(RetTypes.data(), Type->getReturnTypes(), + NumRet * sizeof(WASMType)); + } +#endif uint32_t ArgNum = Type->NumParams; RetType = Type->getReturnType(); uint32_t GpNum = 0; @@ -92,6 +105,13 @@ class ArgumentInfo { WASMType getReturnType() const { return RetType; } uint32_t getStackSize() const { return StackSize; } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + uint32_t getNumReturns() const { return NumRet; } + const WASMType *getReturnTypes() const { + return NumRet > 0 ? RetTypes.data() : nullptr; + } +#endif + typedef typename std::vector::const_reverse_iterator ConstReverseIterator; typedef typename std::vector::const_iterator ConstIterator; @@ -133,6 +153,10 @@ class ArgumentInfo { uint8_t NumFpRegs; uint16_t StackSize; WASMType RetType : 8; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + uint8_t NumRet = 0; + std::vector RetTypes; +#endif }; constexpr uint32_t InvalidLabelId = asmjit::Globals::kInvalidId; @@ -212,7 +236,24 @@ class OnePassCodeGen { public: BlockInfo(CtrlBlockKind Kind, Operand Result, uint32_t Label, uint32_t StackSize) - : Kind(Kind), Result(Result), Label(Label), StackSize(StackSize) {} + : Kind(Kind), Result(Result), Label(Label), StackSize(StackSize) +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + , + IsMultiValue(false) +#endif + { + } + +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Constructor for multi-value blocks + BlockInfo(CtrlBlockKind Kind, std::vector Results, + const TypeEntry *BlockType, uint32_t Label, uint32_t StackSize) + : Kind(Kind), Results(std::move(Results)), BlockType(BlockType), + Label(Label), StackSize(StackSize), IsMultiValue(true) { + // Set primary result for backward compatibility + Result = this->Results.empty() ? Operand() : this->Results[0]; + } +#endif // Get block kind CtrlBlockKind getKind() const { return Kind; } @@ -222,6 +263,17 @@ class OnePassCodeGen { // Get block WASM type WASMType getType() const { return Result.getType(); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Multi-value support + bool isMultiValue() const { return IsMultiValue; } + uint32_t getNumResults() const { + return IsMultiValue ? Results.size() + : (Result.getType() == WASMType::VOID ? 0 : 1); + } + const std::vector &getResults() const { return Results; } + const TypeEntry *getBlockType() const { return BlockType; } +#endif + // Get label associated with the block uint32_t getLabel() const { return Label; } @@ -245,11 +297,18 @@ class OnePassCodeGen { private: CtrlBlockKind Kind; Operand Result; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + std::vector Results; + const TypeEntry *BlockType = nullptr; +#endif uint32_t Label; uint32_t StackSize; bool HasElseLabel = false; bool Reachable = true; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + bool IsMultiValue = false; +#endif }; OnePassCodeGen(asmjit::CodeHolder *Code, OnePassDataLayout &Layout, @@ -279,13 +338,28 @@ class OnePassCodeGen { ZEN_ASSERT(Stack.size() == 0); - WASMType RetType = Type->getReturnType(); - // Use stack operand instead of register operand, as return values of - // function/block have relatively long lifetime and may hold registers - // for too long. - auto Res = - (RetType == WASMType::VOID) ? Operand() : getTempStackOperand(RetType); - Stack.emplace_back(CtrlBlockKind::FUNC_ENTRY, Res, createLabel(), 0); +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + // Handle multi-value function returns + if (Type->NumReturns > 1) { + std::vector Results; + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(getTempStackOperand(RetTypes[I])); + } + Stack.emplace_back(CtrlBlockKind::FUNC_ENTRY, std::move(Results), Type, + createLabel(), 0); + } else +#endif + { + WASMType RetType = Type->getReturnType(); + // Use stack operand instead of register operand, as return values of + // function/block have relatively long lifetime and may hold registers + // for too long. + auto Res = (RetType == WASMType::VOID) ? Operand() + : getTempStackOperand(RetType); + Stack.emplace_back(CtrlBlockKind::FUNC_ENTRY, Res, createLabel(), 0); + } } // finalize after handle the function @@ -394,6 +468,22 @@ class OnePassCodeGen { Stack.push_back(BlockInfo(CtrlBlockKind::BLOCK, Res, Label, Estack)); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleBlockMultiValue(const TypeEntry *Type, uint32_t Estack) { + uint32_t Label = createLabel(); + std::vector Results; + if (Type->NumReturns > 0) { + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(getTempStackOperand(RetTypes[I])); + } + } + Stack.push_back(BlockInfo(CtrlBlockKind::BLOCK, std::move(Results), Type, + Label, Estack)); + } +#endif + void handleLoop(WASMType Type, uint32_t Estack) { uint32_t Label = createLabel(); auto Res = (Type == WASMType::VOID) ? Operand() : getTempStackOperand(Type); @@ -401,6 +491,24 @@ class OnePassCodeGen { bindLabel(Label); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleLoopMultiValue(const TypeEntry *Type, uint32_t Estack) { + uint32_t Label = createLabel(); + // For loops, results are the return values (for fallthrough) + std::vector Results; + if (Type->NumReturns > 0) { + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(getTempStackOperand(RetTypes[I])); + } + } + Stack.push_back(BlockInfo(CtrlBlockKind::LOOP, std::move(Results), Type, + Label, Estack)); + bindLabel(Label); + } +#endif + void handleIf(Operand Op, WASMType Type, uint32_t Estack) { uint32_t Label = createLabel(); uint32_t ElseLabel = createLabel(); @@ -410,6 +518,25 @@ class OnePassCodeGen { self().branchFalse(Op, ElseLabel); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleIfMultiValue(Operand Op, const TypeEntry *Type, uint32_t Estack) { + uint32_t Label = createLabel(); + uint32_t ElseLabel = createLabel(); + ZEN_ASSERT(ElseLabel == Label + 1); + std::vector Results; + if (Type->NumReturns > 0) { + Results.reserve(Type->NumReturns); + const WASMType *RetTypes = Type->getReturnTypes(); + for (uint32_t I = 0; I < Type->NumReturns; ++I) { + Results.push_back(getTempStackOperand(RetTypes[I])); + } + } + Stack.push_back( + BlockInfo(CtrlBlockKind::IF, std::move(Results), Type, Label, Estack)); + self().branchFalse(Op, ElseLabel); + } +#endif + // else block in if-block // if (!cond) goto else_label; // ... @@ -482,6 +609,12 @@ class OnePassCodeGen { void handleReturn(Operand Opnd) { self().handleReturnImpl(Opnd); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + void handleReturnMultiValue(const std::vector &Opnds) { + self().handleReturnMultiValueImpl(Opnds); + } +#endif + Operand handleCall(uint32_t FuncIdx, uintptr_t Target, bool IsImport, bool FarCall, const ArgumentInfo &ArgInfo, const std::vector &Arg) { @@ -489,12 +622,32 @@ class OnePassCodeGen { Arg); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + std::vector handleCallMultiValue(uint32_t FuncIdx, uintptr_t Target, + bool IsImport, bool FarCall, + const ArgumentInfo &ArgInfo, + const std::vector &Arg) { + return self().handleCallMultiValueImpl(FuncIdx, Target, IsImport, FarCall, + ArgInfo, Arg); + } +#endif + Operand handleCallIndirect(uint32_t TypeIdx, Operand Callee, uint32_t TblIdx, const ArgumentInfo &ArgInfo, const std::vector &Arg) { return self().handleCallIndirectImpl(TypeIdx, Callee, TblIdx, ArgInfo, Arg); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + std::vector + handleCallIndirectMultiValue(uint32_t TypeIdx, Operand Callee, + uint32_t TblIdx, const ArgumentInfo &ArgInfo, + const std::vector &Arg) { + return self().handleCallIndirectMultiValueImpl(TypeIdx, Callee, TblIdx, + ArgInfo, Arg); + } +#endif + // ==================== Parametric Instruction Handlers ==================== Operand handleSelect(Operand Cond, Operand LHS, Operand RHS) { diff --git a/src/singlepass/common/datalayout.h b/src/singlepass/common/datalayout.h index c1354881f..1505f358a 100644 --- a/src/singlepass/common/datalayout.h +++ b/src/singlepass/common/datalayout.h @@ -178,7 +178,7 @@ template class OnePassDataLayout : public DataLayout { WASMType getReturnType(uint32_t Index) { ZEN_ASSERT(Index < getNumReturns()); - return static_cast(Ctx->FuncType->ReturnTypes[Index]); + return static_cast(Ctx->FuncType->getReturnTypes()[Index]); } uint32_t getIntPresSavedCount() const { diff --git a/src/singlepass/x64/codegen.h b/src/singlepass/x64/codegen.h index 89e415d0b..a7fa9d52f 100644 --- a/src/singlepass/x64/codegen.h +++ b/src/singlepass/x64/codegen.h @@ -198,8 +198,17 @@ class X64OnePassCodeGenImpl #endif if (Layout.getNumReturns() > 0) { +#ifndef ZEN_ENABLE_WASI_MULTI_VALUE ZEN_ASSERT(Layout.getNumReturns() == 1); ZEN_ASSERT(Layout.getReturnType(0) == Op.getType()); +#else + // For multi-value, Op might be empty if values were already moved to + // registers by handleReturnMultiValueImpl + if (Op.isNone()) { + // Values already in registers, skip the move + } else { + ZEN_ASSERT(Layout.getReturnType(0) == Op.getType()); +#endif switch (Op.getType()) { case WASMType::I32: mov(ABI.getRetRegNum(), Op); @@ -216,1714 +225,1970 @@ class X64OnePassCodeGenImpl default: ZEN_ASSERT(false); } +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE } - for (uint32_t I = 0; I < Layout.getIntPresSavedCount(); ++I) { - const X64::GP Reg = ABI.getPresRegNum(I); - _ mov(X64Reg::getRegRef(Reg), - asmjit::x86::Mem(ABI.getFrameBaseReg(), -(I + 1) * ABI.GpRegWidth)); - } - _ mov(ABI.getStackPointerReg(), ABI.getFrameBaseReg()); - _ pop(ABI.getFrameBaseReg()); - _ ret(); - } // EmitEpilog - - template - void emitTableSize(uint32_t TblIdx, Operand EntryIdx) { - ZEN_ASSERT(EntryIdx.getType() == WASMType::I32); - - ZEN_STATIC_ASSERT(sizeof(TableInstance::CurSize) == sizeof(uint32_t)); - uint32_t SizeOffset = Ctx->Mod->getLayout().TableElemSizeOffset; - asmjit::x86::Mem SizeAddr(ABI.getModuleInstReg(), SizeOffset, - sizeof(SizeOffset)); - // compare entry_idx with sizeReg - if (EntryIdx.isReg()) { - ZEN_ASSERT(EntryIdx.isTempReg()); - _ cmp(SizeAddr, EntryIdx.getRegRef()); - } else if (EntryIdx.isMem()) { - ZEN_ASSERT(EntryIdx.isTempMem()); - auto SizeReg = Layout.getScopedTempReg(); - _ mov(SizeReg, SizeAddr); - _ cmp(SizeReg, EntryIdx.getMem()); - } else if (EntryIdx.isImm()) { - _ cmp(SizeAddr, EntryIdx.getImm()); - } else { - ZEN_ABORT(); - } - _ jbe(getExceptLabel(ErrorCode::UndefinedElement)); - } - - void emitTableGet(uint32_t TblIdx, Operand Elem, X64::GP ResRegNum) { - // place table[tbl_idx] to ScopedTempReg1 - emitTableSize(TblIdx, Elem); - auto InstReg = ABI.getModuleInstReg(); - auto ResReg = X64Reg::getRegRef(ResRegNum); - constexpr uint32_t Shift = 2; - uint32_t BaseOffset = Ctx->Mod->getLayout().TableElemBaseOffset; - // load table[tbl_idx].functions[elem] into register - if (Elem.isReg()) { - // elem is in reg, reuse this reg - _ mov(ResReg, asmjit::x86::ptr(InstReg, Elem.getRegRef(), Shift, - BaseOffset)); - } else if (Elem.isMem()) { - // elem is on stack, load it and save to ScopedTempReg0 - auto ElemReg = Layout.getScopedTempReg(); - _ mov(ElemReg, Elem.getMem()); - _ mov(ResReg, asmjit::x86::ptr(InstReg, ElemReg, Shift, BaseOffset)); - } else if (Elem.isImm()) { - _ mov(ResReg, asmjit::x86::Mem(InstReg, Elem.getImm() * sizeof(uint32_t) + - BaseOffset)); - } +#endif } + for (uint32_t I = 0; I < Layout.getIntPresSavedCount(); ++I) { + const X64::GP Reg = ABI.getPresRegNum(I); + _ mov(X64Reg::getRegRef(Reg), + asmjit::x86::Mem(ABI.getFrameBaseReg(), -(I + 1) * ABI.GpRegWidth)); + } + _ mov(ABI.getStackPointerReg(), ABI.getFrameBaseReg()); + _ pop(ABI.getFrameBaseReg()); + _ ret(); +} // EmitEpilog + +template +void emitTableSize(uint32_t TblIdx, Operand EntryIdx) { + ZEN_ASSERT(EntryIdx.getType() == WASMType::I32); + + ZEN_STATIC_ASSERT(sizeof(TableInstance::CurSize) == sizeof(uint32_t)); + uint32_t SizeOffset = Ctx->Mod->getLayout().TableElemSizeOffset; + asmjit::x86::Mem SizeAddr(ABI.getModuleInstReg(), SizeOffset, + sizeof(SizeOffset)); + // compare entry_idx with sizeReg + if (EntryIdx.isReg()) { + ZEN_ASSERT(EntryIdx.isTempReg()); + _ cmp(SizeAddr, EntryIdx.getRegRef()); + } else if (EntryIdx.isMem()) { + ZEN_ASSERT(EntryIdx.isTempMem()); + auto SizeReg = Layout.getScopedTempReg(); + _ mov(SizeReg, SizeAddr); + _ cmp(SizeReg, EntryIdx.getMem()); + } else if (EntryIdx.isImm()) { + _ cmp(SizeAddr, EntryIdx.getImm()); + } else { + ZEN_ABORT(); + } + _ jbe(getExceptLabel(ErrorCode::UndefinedElement)); +} + +void emitTableGet(uint32_t TblIdx, Operand Elem, X64::GP ResRegNum) { + // place table[tbl_idx] to ScopedTempReg1 + emitTableSize(TblIdx, Elem); + auto InstReg = ABI.getModuleInstReg(); + auto ResReg = X64Reg::getRegRef(ResRegNum); + constexpr uint32_t Shift = 2; + uint32_t BaseOffset = Ctx->Mod->getLayout().TableElemBaseOffset; + // load table[tbl_idx].functions[elem] into register + if (Elem.isReg()) { + // elem is in reg, reuse this reg + _ mov(ResReg, asmjit::x86::ptr(InstReg, Elem.getRegRef(), Shift, + BaseOffset)); + } else if (Elem.isMem()) { + // elem is on stack, load it and save to ScopedTempReg0 + auto ElemReg = Layout.getScopedTempReg(); + _ mov(ElemReg, Elem.getMem()); + _ mov(ResReg, asmjit::x86::ptr(InstReg, ElemReg, Shift, BaseOffset)); + } else if (Elem.isImm()) { + _ mov(ResReg, asmjit::x86::Mem(InstReg, Elem.getImm() * sizeof(uint32_t) + + BaseOffset)); + } +} public: - // - // initialization and finalization - // +// +// initialization and finalization +// - // finalization after compiling a function - void finalizeFunction() { - // update RSP adjustment in prolog with the actual frame size - ZEN_ASSERT(CurFuncState.FrameSizePatchOffset >= 0); - auto CurrOffset = _ offset(); - _ setOffset(CurFuncState.FrameSizePatchOffset); - _ long_().sub(ABI.getStackPointerReg(), Layout.getStackBudget()); - _ setOffset(CurrOffset); - } +// finalization after compiling a function +void finalizeFunction() { + // update RSP adjustment in prolog with the actual frame size + ZEN_ASSERT(CurFuncState.FrameSizePatchOffset >= 0); + auto CurrOffset = _ offset(); + _ setOffset(CurFuncState.FrameSizePatchOffset); + _ long_().sub(ABI.getStackPointerReg(), Layout.getStackBudget()); + _ setOffset(CurrOffset); +} public: - // - // temporary, stack and vm state management - // +// +// temporary, stack and vm state management +// - void callAbsolute(uintptr_t Addr) { _ call(Addr); } +void callAbsolute(uintptr_t Addr) { _ call(Addr); } - void setException() { _ or_(ABI.getGlobalDataBaseReg(), 1); } +void setException() { _ or_(ABI.getGlobalDataBaseReg(), 1); } - void checkCallException(bool IsImport) { +void checkCallException(bool IsImport) { #ifdef ZEN_ENABLE_CPU_EXCEPTION - if (IsImport) { - if (CurFuncState.ExceptionExitLabel == InvalidLabelId) { - CurFuncState.ExceptionExitLabel = createLabel(); - } - auto Inst = ABI.getModuleInstReg(); - asmjit::x86::Mem ExceptAddr(Inst, ExceptionOffset, 4); - _ cmp(ExceptAddr, 0); - jne(CurFuncState.ExceptionExitLabel); - } -#else + if (IsImport) { if (CurFuncState.ExceptionExitLabel == InvalidLabelId) { CurFuncState.ExceptionExitLabel = createLabel(); } - - if (!IsImport) { - // has exception, reuse r14 - _ test(ABI.getGlobalDataBaseReg(), 1); - jne(CurFuncState.ExceptionExitLabel); - } else { - auto Inst = ABI.getModuleInstReg(); - asmjit::x86::Mem ExceptAddr(Inst, ExceptionOffset, 4); - _ cmp(ExceptAddr, 0); - - jne(CurFuncState.ExceptionExitLabel); - } -#endif // ZEN_ENABLE_CPU_EXCEPTION + auto Inst = ABI.getModuleInstReg(); + asmjit::x86::Mem ExceptAddr(Inst, ExceptionOffset, 4); + _ cmp(ExceptAddr, 0); + jne(CurFuncState.ExceptionExitLabel); } - - void checkCallIndirectException() { checkCallException(true); } - - template - void checkMemoryOverflow(Operand Base, uint32_t Offset) { - if (Ctx->UseSoftMemCheck) { - constexpr uint32_t Size = getWASMTypeSize(); - Offset += Size; - // check (offset + size) overflow - if (Offset < Size) { - _ jmp(getExceptLabel(ErrorCode::OutOfBoundsMemory)); +#else + if (CurFuncState.ExceptionExitLabel == InvalidLabelId) { + CurFuncState.ExceptionExitLabel = createLabel(); } - auto BaseRegNum = Layout.getScopedTemp(); - auto BaseReg = X64Reg::getRegRef(BaseRegNum); - mov(BaseRegNum, Base); - _ add(BaseReg, Offset); - _ jc(getExceptLabel(ErrorCode::OutOfBoundsMemory)); - _ cmp(BaseReg, X64Reg::getRegRef(ABI.getMemorySize())); - _ ja(getExceptLabel(ErrorCode::OutOfBoundsMemory)); - } - } - -public: - // - // templated method to handle operations - // - - // in alphabetical order - // binary operator - template - Operand handleBinaryOpImpl(Operand LHS, Operand RHS) { - constexpr X64::Type X64Type = getX64TypeFromWASMType(); - - auto ResReg = toReg(LHS); - - BinaryOperatorImpl::emit( - ASM, X64Reg::getRegRef(ResReg), RHS); - - auto Ret = getTempOperand(Type); - mov(Ret, ResReg); - return Ret; - } - - // TODO: avoid redundant mov - template - Operand handleBitCountOpImpl(Operand Op) { - constexpr auto X64Type = getX64TypeFromWASMType(); - - auto Ret = getTempOperand(Type); - auto RegNum = Ret.isReg() - ? Ret.getReg() - : static_cast( - Layout.getScopedTemp()); - - mov(RegNum, Op); - UnaryOperatorImpl::emit(ASM, - X64Reg::getRegRef(RegNum)); - - if (!Ret.isReg()) { - mov(Ret, - Operand(Type, RegNum, Operand::FLAG_NONE)); - } - return Ret; - } - - // compare operator - template - Operand handleCompareOpImpl(Operand LHS, Operand RHS) { - constexpr X64::Type X64Type = getX64TypeFromWASMType(); - ZEN_ASSERT(LHS.getType() == Type); - - // make comparison - bool Exchanged = false; - if (Opr == CompareOperator::CO_EQZ) { - ZEN_ASSERT(RHS.getType() == WASMType::VOID); - ZEN_ASSERT(RHS.getKind() == OK_None); - test(LHS); - } else { - cmp(LHS, RHS, Exchanged); - } - - // allocate result register - typename X64TypeAttr::RegNum RegNum; - bool HasTempReg = Layout.hasAvailTempReg(RegNum); - if (!HasTempReg) { - RegNum = Layout.getScopedTemp(); - } else { - Layout.clearAvailReg(RegNum); - } - // setcc to resulta - if (!Exchanged) { - setcc(RegNum); - } else { - constexpr CompareOperator ExchangedOpr = - getExchangedCompareOperator(); - setcc(RegNum); - } - // make a sign-extension - _ movsx(X64Reg::getRegRef(RegNum), - X64Reg::getRegRef(RegNum)); - - // handle NaN operands - if (Type == WASMType::F32 || Type == WASMType::F64) { - auto TmpReg = Layout.getScopedTempReg(); - if (Opr == CompareOperator::CO_NE) { - _ mov(TmpReg, 1); + if (!IsImport) { + // has exception, reuse r14 + _ test(ABI.getGlobalDataBaseReg(), 1); + jne(CurFuncState.ExceptionExitLabel); } else { - _ mov(TmpReg, 0); - } - _ cmovp(X64Reg::getRegRef(RegNum), TmpReg); - } - - if (HasTempReg) { - return Operand(WASMType::I32, RegNum, Operand::FLAG_TEMP_REG); - } - - // store to stack - Operand Ret = getTempStackOperand(WASMType::I32); - ASM.mov(Ret.getMem(), - X64Reg::getRegRef(RegNum)); - return Ret; - } + auto Inst = ABI.getModuleInstReg(); + asmjit::x86::Mem ExceptAddr(Inst, ExceptionOffset, 4); + _ cmp(ExceptAddr, 0); - // constant - template - Operand handleConstImpl(typename WASMTypeAttr::Type Val) { - if (Ty == WASMType::I32) { - return X64InstOperand(WASMType::I32, Val); - } - if (Ty == WASMType::I64) { - if (Val >= INT32_MIN && Val <= INT32_MAX) { - return X64InstOperand(WASMType::I64, (int32_t)Val); - } - typename X64TypeAttr::RegNum RegNum; - bool HasTempReg = Layout.hasAvailTempReg(RegNum); - if (!HasTempReg) { - RegNum = Layout.getScopedTemp(); - } else { - Layout.clearAvailReg(RegNum); - } - _ movabs(X64Reg::getRegRef(RegNum), Val); - if (HasTempReg) { - return Operand(WASMType::I64, RegNum, Operand::FLAG_TEMP_REG); + jne(CurFuncState.ExceptionExitLabel); } - // store to stack - Operand Ret = getTempStackOperand(WASMType::I64); - ASM.mov(Ret.getMem(), - X64Reg::getRegRef(RegNum)); - return Ret; - } - // allocate memory on stack and fill stack/return Mem on stack - Operand Ret = getTempStackOperand(Ty); - ZEN_ASSERT(Ret.isMem() && Ret.getBase() == ABI.getFrameBase()); - int32_t Offset = Ret.getOffset(); - if (sizeof(Val) == 4) { - int32_t I32; - memcpy(&I32, &Val, sizeof(int32_t)); - _ mov(asmjit::x86::Mem(ABI.getFrameBaseReg(), Offset, 4), I32); - } else if (sizeof(Val) == 8) { - int64_t I64; - memcpy(&I64, &Val, sizeof(int64_t)); - _ mov(asmjit::x86::Mem(ABI.getFrameBaseReg(), Offset, 4), (int32_t)I64); - _ mov(asmjit::x86::Mem(ABI.getFrameBaseReg(), Offset + 4, 4), - (int32_t)(I64 >> 32)); - } else { - ZEN_ASSERT_TODO(); - } - return Ret; - } +#endif // ZEN_ENABLE_CPU_EXCEPTION +} - // convert from SrcType to DestType (between integer and float-point) - // TODO: error-handling and conversion to/from unsigned i64 - template - Operand handleConvertImpl(Operand Op) { - if (SrcType == WASMType::I64 && !Sext) { - return convertFromU64(Op); - } +void checkCallIndirectException() { checkCallException(true); } - constexpr auto X64DestType = getX64TypeFromWASMType(); - constexpr auto X64SrcType = getX64TypeFromWASMType(); - - auto Ret = getTempOperand(DestType); - auto RetReg = Ret.isReg() - ? Ret.getRegRef() - : Layout.getScopedTempReg(); - if (!Op.isReg()) { - auto RegNum = Layout.getScopedTemp(); - mov(RegNum, Op); - Op = Operand(SrcType, RegNum, Operand::FLAG_NONE); +template +void checkMemoryOverflow(Operand Base, uint32_t Offset) { + if (Ctx->UseSoftMemCheck) { + constexpr uint32_t Size = getWASMTypeSize(); + Offset += Size; + // check (offset + size) overflow + if (Offset < Size) { + _ jmp(getExceptLabel(ErrorCode::OutOfBoundsMemory)); } - ConvertOpImpl::emit( - ASM, RetReg, Op.getRegRef()); - - if (!Ret.isReg()) { - ASM.mov(Ret.getMem(), RetReg); - } - return Ret; + auto BaseRegNum = Layout.getScopedTemp(); + auto BaseReg = X64Reg::getRegRef(BaseRegNum); + mov(BaseRegNum, Base); + _ add(BaseReg, Offset); + _ jc(getExceptLabel(ErrorCode::OutOfBoundsMemory)); + _ cmp(BaseReg, X64Reg::getRegRef(ABI.getMemorySize())); + _ ja(getExceptLabel(ErrorCode::OutOfBoundsMemory)); } +} - template Operand convertFromU64(Operand Op) { - ZEN_STATIC_ASSERT(isWASMTypeFloat()); - constexpr auto X64DestType = getX64TypeFromWASMType(); - - if (!Op.isReg()) { - auto RegNum = Layout.getScopedTemp(); - mov(RegNum, Op); - Op = Operand(WASMType::I64, RegNum, Operand::FLAG_NONE); - } - auto OpReg = Op.getRegRef(); - - auto TmpReg = Layout.getScopedTempReg(); - _ mov(TmpReg, OpReg); - _ shr(TmpReg, 1); +public: +// +// templated method to handle operations +// - auto TmpReg2 = Layout.getScopedTempReg(); - _ mov(TmpReg2, OpReg); - _ and_(TmpReg2, 0x1); - _ or_(TmpReg, TmpReg2); +// in alphabetical order +// binary operator +template +Operand handleBinaryOpImpl(Operand LHS, Operand RHS) { + constexpr X64::Type X64Type = getX64TypeFromWASMType(); - auto ResReg = Layout.getScopedTempReg(); - auto ResRegNum = Layout.getScopedTemp(); - ConvertOpImpl::emit(ASM, ResReg, TmpReg); - ASM.add(ResReg, ResReg); + auto ResReg = toReg(LHS); - auto Label = _ newLabel(); - _ test(OpReg, OpReg); - _ js(Label); + BinaryOperatorImpl::emit( + ASM, X64Reg::getRegRef(ResReg), RHS); - ConvertOpImpl::emit(ASM, ResReg, OpReg); - _ bind(Label); + auto Ret = getTempOperand(Type); + mov(Ret, ResReg); + return Ret; +} - auto Ret = getTempOperand(DestType); - mov( - Ret, Operand(DestType, ResRegNum, Operand::FLAG_NONE)); - return Ret; - } +// TODO: avoid redundant mov +template +Operand handleBitCountOpImpl(Operand Op) { + constexpr auto X64Type = getX64TypeFromWASMType(); - // float div - template - Operand handleFDivOpImpl(Operand LHS, Operand RHS) { - ZEN_ASSERT(LHS.getType() == Type); - ZEN_ASSERT(RHS.getType() == Type); - ZEN_ASSERT(Type == WASMType::F32 || Type == WASMType::F64); + auto Ret = getTempOperand(Type); + auto RegNum = Ret.isReg() + ? Ret.getReg() + : static_cast( + Layout.getScopedTemp()); - constexpr X64::Type X64Type = getX64TypeFromWASMType(); + mov(RegNum, Op); + UnaryOperatorImpl::emit(ASM, + X64Reg::getRegRef(RegNum)); - typedef typename X64TypeAttr::RegNum RegNum; + if (!Ret.isReg()) { + mov(Ret, + Operand(Type, RegNum, Operand::FLAG_NONE)); + } + return Ret; +} + +// compare operator +template +Operand handleCompareOpImpl(Operand LHS, Operand RHS) { + constexpr X64::Type X64Type = getX64TypeFromWASMType(); + ZEN_ASSERT(LHS.getType() == Type); + + // make comparison + bool Exchanged = false; + if (Opr == CompareOperator::CO_EQZ) { + ZEN_ASSERT(RHS.getType() == WASMType::VOID); + ZEN_ASSERT(RHS.getKind() == OK_None); + test(LHS); + } else { + cmp(LHS, RHS, Exchanged); + } - bool LHSIsReg = true; - if (!LHS.isReg()) { - LHSIsReg = false; - RegNum LHSReg = Layout.getScopedTemp(); - mov(LHSReg, LHS); - LHS = Operand(Type, LHSReg, Operand::FLAG_NONE); + // allocate result register + typename X64TypeAttr::RegNum RegNum; + bool HasTempReg = Layout.hasAvailTempReg(RegNum); + if (!HasTempReg) { + RegNum = Layout.getScopedTemp(); + } else { + Layout.clearAvailReg(RegNum); + } + // setcc to resulta + if (!Exchanged) { + setcc(RegNum); + } else { + constexpr CompareOperator ExchangedOpr = getExchangedCompareOperator(); + setcc(RegNum); + } + // make a sign-extension + _ movsx(X64Reg::getRegRef(RegNum), + X64Reg::getRegRef(RegNum)); + + // handle NaN operands + if (Type == WASMType::F32 || Type == WASMType::F64) { + auto TmpReg = Layout.getScopedTempReg(); + if (Opr == CompareOperator::CO_NE) { + _ mov(TmpReg, 1); } else { - Layout.clearAvailReg((RegNum)LHS.getReg()); - } - - if (RHS.isImm()) { - RegNum RHSReg = Layout.getScopedTemp(); - mov(RHSReg, RHS); - RHS = Operand(Type, RHSReg, Operand::FLAG_NONE); - } - - BinaryOperatorImpl::emit(ASM, LHS, RHS); - - if (LHSIsReg) { - return LHS; + _ mov(TmpReg, 0); } - - Operand Ret = getTempOperand(Type); - mov(Ret, LHS); - return Ret; + _ cmovp(X64Reg::getRegRef(RegNum), TmpReg); } - template - Operand handleFloatCopysignImpl(Operand LHS, Operand RHS) { - constexpr auto X64Type = getX64TypeFromWASMType(); - auto LHSRegNum = toReg(LHS); - auto LHSReg = X64Reg::getRegRef(LHSRegNum); - auto RHSRegNum = toReg(RHS); - auto RHSReg = X64Reg::getRegRef(RHSRegNum); - - constexpr auto X64IntType = - getX64TypeFromWASMType::IntType>(); - auto ImmReg = Layout.getScopedTempReg(); - auto MaskReg = Layout.getScopedTempReg(); - auto SignMask = FloatAttr::SignMask; - - _ mov(MaskReg, ~SignMask); - ASM.fmov(ImmReg, MaskReg); - ASM.and_(LHSReg, ImmReg); - - _ mov(MaskReg, SignMask); - ASM.fmov(ImmReg, MaskReg); - ASM.and_(RHSReg, ImmReg); - - ASM.or_(LHSReg, RHSReg); - - auto Ret = getTempOperand(Type); - mov(Ret, LHSRegNum); - return Ret; + if (HasTempReg) { + return Operand(WASMType::I32, RegNum, Operand::FLAG_TEMP_REG); } - template - Operand handleFloatMinMaxImpl(Operand LHS, Operand RHS) { - ZEN_STATIC_ASSERT(isWASMTypeFloat()); - constexpr auto X64Type = getX64TypeFromWASMType(); - - auto TmpReg = Layout.getScopedTempReg(); - auto TmpRegNum = Layout.getScopedTemp(); - auto TmpReg2 = Layout.getScopedTempReg(); - auto TmpRegNum2 = Layout.getScopedTemp(); - - mov(TmpRegNum, LHS); - BinaryOperatorImpl::emit(ASM, TmpReg, RHS); + // store to stack + Operand Ret = getTempStackOperand(WASMType::I32); + ASM.mov(Ret.getMem(), + X64Reg::getRegRef(RegNum)); + return Ret; +} - bool Exchanged = false; - cmp(LHS, RHS, Exchanged); - auto HandleNaN = _ newLabel(); - auto Finish = _ newLabel(); - _ jp(HandleNaN); - _ jne(Finish); - - constexpr auto X64IntType = - getX64TypeFromWASMType::IntType>(); - auto IntReg = Layout.getScopedTempReg(); - auto IntReg2 = Layout.getScopedTempReg(); - - // handle 0.0 vs -0.0 - mov(TmpRegNum2, LHS); - _ mov(IntReg, Opr == BinaryOperator::BO_MIN ? FloatAttr::NegZero : 0); - ASM.fmov(IntReg2, TmpReg2); - _ cmp(IntReg, IntReg2); - _ jne(Finish); - mov(TmpRegNum, LHS); - _ jmp(Finish); - - _ bind(HandleNaN); - auto CanonicalNaN = FloatAttr::CanonicalNan; - _ mov(IntReg, CanonicalNaN); - ASM.fmov(TmpReg, IntReg); - - _ bind(Finish); - auto Ret = getTempOperand(Type); - mov(Ret, - Operand(Type, TmpRegNum, Operand::FLAG_NONE)); - return Ret; +// constant +template +Operand handleConstImpl(typename WASMTypeAttr::Type Val) { + if (Ty == WASMType::I32) { + return X64InstOperand(WASMType::I32, Val); } - - // integer div - template - Operand handleIDivOpImpl(Operand LHS, Operand RHS) { - ZEN_ASSERT(LHS.getType() == Type); - ZEN_ASSERT(RHS.getType() == Type); - ZEN_ASSERT(Type == WASMType::I32 || Type == WASMType::I64); - - constexpr X64::Type X64Type = getX64TypeFromWASMType(); - - constexpr bool IsUnsigned = - (Opr == BinaryOperator::BO_DIV_U || Opr == BinaryOperator::BO_REM_U); - constexpr bool IsRem = - (Opr == BinaryOperator::BO_REM_U || Opr == BinaryOperator::BO_REM_S); - - uint32_t NormalPathLabel = 0; - uint32_t EndLabel = 0; - - Operand Ret = getTempOperand(Type); - bool Exchanged = false; - - // rem_s - if (!IsUnsigned) { - NormalPathLabel = createLabel(); - EndLabel = createLabel(); - - Operand CmpOpnd; - if (X64Type == X64::I32) { - CmpOpnd = Operand(Type, 0x80000000U); - } else { - auto RegNum = Layout.getScopedTemp(); - _ movabs(X64Reg::getRegRef(RegNum), 0x8000000000000000ULL); - CmpOpnd = Operand(Type, RegNum, Operand::FLAG_NONE); - } - - cmp(LHS, CmpOpnd, Exchanged); - jne(NormalPathLabel); - - if (X64Type == X64::I32) { - CmpOpnd = Operand(Type, 0xffffffffU); - } else { - auto RegNum = Layout.getScopedTemp(); - _ movabs(X64Reg::getRegRef(RegNum), 0xffffffffffffffffULL); - CmpOpnd = Operand(Type, RegNum, Operand::FLAG_NONE); - } - - cmp(RHS, CmpOpnd, Exchanged); - jne(NormalPathLabel); - - if (IsRem) { - mov(Ret, Operand(Type, 0)); - branch(EndLabel); - } else { - _ jmp(getExceptLabel(ErrorCode::IntegerOverflow)); - } - - bindLabel(NormalPathLabel); - } - -#ifndef ZEN_ENABLE_CPU_EXCEPTION - cmp(RHS, Operand(Type, 0), Exchanged); - _ je(getExceptLabel(ErrorCode::IntegerDivByZero)); -#endif // ZEN_ENABLE_CPU_EXCEPTION - - mov(X64::RAX, LHS); - if (IsUnsigned) { - auto RDXReg = X64Reg::getRegRef(X64::GP::RDX); - ASM.xor_(RDXReg, RDXReg); - } else if (X64Type == X64::I32) { - ASM.cdq(); - } else if (X64Type == X64::I64) { - ASM.cqo(); + if (Ty == WASMType::I64) { + if (Val >= INT32_MIN && Val <= INT32_MAX) { + return X64InstOperand(WASMType::I64, (int32_t)Val); } - - if (!RHS.isReg()) { - mov(X64::RCX, RHS); - RHS = Operand(Type, X64::RCX, Operand::FLAG_NONE); - } - - BinaryOperatorImpl::emit(ASM, RHS, RHS); - - if (IsRem) { - mov(Ret, - Operand(Type, X64::RDX, Operand::FLAG_NONE)); + typename X64TypeAttr::RegNum RegNum; + bool HasTempReg = Layout.hasAvailTempReg(RegNum); + if (!HasTempReg) { + RegNum = Layout.getScopedTemp(); } else { - mov(Ret, - Operand(Type, X64::RAX, Operand::FLAG_NONE)); + Layout.clearAvailReg(RegNum); } - - // rem_s - if (!IsUnsigned && IsRem) { - bindLabel(EndLabel); + _ movabs(X64Reg::getRegRef(RegNum), Val); + if (HasTempReg) { + return Operand(WASMType::I64, RegNum, Operand::FLAG_TEMP_REG); } - + // store to stack + Operand Ret = getTempStackOperand(WASMType::I64); + ASM.mov(Ret.getMem(), + X64Reg::getRegRef(RegNum)); return Ret; } - - template - Operand handleFloatToIntImpl(Operand Op) { - // tag dispatch - return handleFloatToIntImpl( - Op, std::integral_constant()); - } - - // extend from stype to dtype in same type kind (integer or floating-point) - template - Operand handleIntExtendImpl(Operand Op) { - constexpr auto X64DestType = getX64TypeFromWASMType(); - constexpr auto X64SrcType = getX64TypeFromWASMType(); - - auto Ret = getTempOperand(DestType); - auto RegNum = Layout.getScopedTemp(); - auto RetReg = Ret.isReg() ? Ret.getRegRef() - : X64Reg::getRegRef(RegNum); - - using ExtendOp = ExtendOperatorImpl; - if (Op.isImm()) { - auto RegNum2 = Layout.getScopedTemp(); - auto TmpReg = X64Reg::getRegRef(RegNum2); - _ mov(TmpReg, Op.getImm()); - ExtendOp::emit(ASM, RetReg, TmpReg); - } else if (Op.isReg()) { - ExtendOp::emit(ASM, RetReg, Op.getRegRef()); + // allocate memory on stack and fill stack/return Mem on stack + Operand Ret = getTempStackOperand(Ty); + ZEN_ASSERT(Ret.isMem() && Ret.getBase() == ABI.getFrameBase()); + int32_t Offset = Ret.getOffset(); + if (sizeof(Val) == 4) { + int32_t I32; + memcpy(&I32, &Val, sizeof(int32_t)); + _ mov(asmjit::x86::Mem(ABI.getFrameBaseReg(), Offset, 4), I32); + } else if (sizeof(Val) == 8) { + int64_t I64; + memcpy(&I64, &Val, sizeof(int64_t)); + _ mov(asmjit::x86::Mem(ABI.getFrameBaseReg(), Offset, 4), (int32_t)I64); + _ mov(asmjit::x86::Mem(ABI.getFrameBaseReg(), Offset + 4, 4), + (int32_t)(I64 >> 32)); + } else { + ZEN_ASSERT_TODO(); + } + return Ret; +} + +// convert from SrcType to DestType (between integer and float-point) +// TODO: error-handling and conversion to/from unsigned i64 +template +Operand handleConvertImpl(Operand Op) { + if (SrcType == WASMType::I64 && !Sext) { + return convertFromU64(Op); + } + + constexpr auto X64DestType = getX64TypeFromWASMType(); + constexpr auto X64SrcType = getX64TypeFromWASMType(); + + auto Ret = getTempOperand(DestType); + auto RetReg = Ret.isReg() + ? Ret.getRegRef() + : Layout.getScopedTempReg(); + if (!Op.isReg()) { + auto RegNum = Layout.getScopedTemp(); + mov(RegNum, Op); + Op = Operand(SrcType, RegNum, Operand::FLAG_NONE); + } + + ConvertOpImpl::emit( + ASM, RetReg, Op.getRegRef()); + + if (!Ret.isReg()) { + ASM.mov(Ret.getMem(), RetReg); + } + return Ret; +} + +template Operand convertFromU64(Operand Op) { + ZEN_STATIC_ASSERT(isWASMTypeFloat()); + constexpr auto X64DestType = getX64TypeFromWASMType(); + + if (!Op.isReg()) { + auto RegNum = Layout.getScopedTemp(); + mov(RegNum, Op); + Op = Operand(WASMType::I64, RegNum, Operand::FLAG_NONE); + } + auto OpReg = Op.getRegRef(); + + auto TmpReg = Layout.getScopedTempReg(); + _ mov(TmpReg, OpReg); + _ shr(TmpReg, 1); + + auto TmpReg2 = Layout.getScopedTempReg(); + _ mov(TmpReg2, OpReg); + _ and_(TmpReg2, 0x1); + _ or_(TmpReg, TmpReg2); + + auto ResReg = Layout.getScopedTempReg(); + auto ResRegNum = Layout.getScopedTemp(); + ConvertOpImpl::emit(ASM, ResReg, TmpReg); + ASM.add(ResReg, ResReg); + + auto Label = _ newLabel(); + _ test(OpReg, OpReg); + _ js(Label); + + ConvertOpImpl::emit(ASM, ResReg, OpReg); + _ bind(Label); + + auto Ret = getTempOperand(DestType); + mov( + Ret, Operand(DestType, ResRegNum, Operand::FLAG_NONE)); + return Ret; +} + +// float div +template +Operand handleFDivOpImpl(Operand LHS, Operand RHS) { + ZEN_ASSERT(LHS.getType() == Type); + ZEN_ASSERT(RHS.getType() == Type); + ZEN_ASSERT(Type == WASMType::F32 || Type == WASMType::F64); + + constexpr X64::Type X64Type = getX64TypeFromWASMType(); + + typedef typename X64TypeAttr::RegNum RegNum; + + bool LHSIsReg = true; + if (!LHS.isReg()) { + LHSIsReg = false; + RegNum LHSReg = Layout.getScopedTemp(); + mov(LHSReg, LHS); + LHS = Operand(Type, LHSReg, Operand::FLAG_NONE); + } else { + Layout.clearAvailReg((RegNum)LHS.getReg()); + } + + if (RHS.isImm()) { + RegNum RHSReg = Layout.getScopedTemp(); + mov(RHSReg, RHS); + RHS = Operand(Type, RHSReg, Operand::FLAG_NONE); + } + + BinaryOperatorImpl::emit(ASM, LHS, RHS); + + if (LHSIsReg) { + return LHS; + } + + Operand Ret = getTempOperand(Type); + mov(Ret, LHS); + return Ret; +} + +template +Operand handleFloatCopysignImpl(Operand LHS, Operand RHS) { + constexpr auto X64Type = getX64TypeFromWASMType(); + auto LHSRegNum = toReg(LHS); + auto LHSReg = X64Reg::getRegRef(LHSRegNum); + auto RHSRegNum = toReg(RHS); + auto RHSReg = X64Reg::getRegRef(RHSRegNum); + + constexpr auto X64IntType = + getX64TypeFromWASMType::IntType>(); + auto ImmReg = Layout.getScopedTempReg(); + auto MaskReg = Layout.getScopedTempReg(); + auto SignMask = FloatAttr::SignMask; + + _ mov(MaskReg, ~SignMask); + ASM.fmov(ImmReg, MaskReg); + ASM.and_(LHSReg, ImmReg); + + _ mov(MaskReg, SignMask); + ASM.fmov(ImmReg, MaskReg); + ASM.and_(RHSReg, ImmReg); + + ASM.or_(LHSReg, RHSReg); + + auto Ret = getTempOperand(Type); + mov(Ret, LHSRegNum); + return Ret; +} + +template +Operand handleFloatMinMaxImpl(Operand LHS, Operand RHS) { + ZEN_STATIC_ASSERT(isWASMTypeFloat()); + constexpr auto X64Type = getX64TypeFromWASMType(); + + auto TmpReg = Layout.getScopedTempReg(); + auto TmpRegNum = Layout.getScopedTemp(); + auto TmpReg2 = Layout.getScopedTempReg(); + auto TmpRegNum2 = Layout.getScopedTemp(); + + mov(TmpRegNum, LHS); + BinaryOperatorImpl::emit(ASM, TmpReg, RHS); + + bool Exchanged = false; + cmp(LHS, RHS, Exchanged); + auto HandleNaN = _ newLabel(); + auto Finish = _ newLabel(); + _ jp(HandleNaN); + _ jne(Finish); + + constexpr auto X64IntType = + getX64TypeFromWASMType::IntType>(); + auto IntReg = Layout.getScopedTempReg(); + auto IntReg2 = Layout.getScopedTempReg(); + + // handle 0.0 vs -0.0 + mov(TmpRegNum2, LHS); + _ mov(IntReg, Opr == BinaryOperator::BO_MIN ? FloatAttr::NegZero : 0); + ASM.fmov(IntReg2, TmpReg2); + _ cmp(IntReg, IntReg2); + _ jne(Finish); + mov(TmpRegNum, LHS); + _ jmp(Finish); + + _ bind(HandleNaN); + auto CanonicalNaN = FloatAttr::CanonicalNan; + _ mov(IntReg, CanonicalNaN); + ASM.fmov(TmpReg, IntReg); + + _ bind(Finish); + auto Ret = getTempOperand(Type); + mov(Ret, + Operand(Type, TmpRegNum, Operand::FLAG_NONE)); + return Ret; +} + +// integer div +template +Operand handleIDivOpImpl(Operand LHS, Operand RHS) { + ZEN_ASSERT(LHS.getType() == Type); + ZEN_ASSERT(RHS.getType() == Type); + ZEN_ASSERT(Type == WASMType::I32 || Type == WASMType::I64); + + constexpr X64::Type X64Type = getX64TypeFromWASMType(); + + constexpr bool IsUnsigned = + (Opr == BinaryOperator::BO_DIV_U || Opr == BinaryOperator::BO_REM_U); + constexpr bool IsRem = + (Opr == BinaryOperator::BO_REM_U || Opr == BinaryOperator::BO_REM_S); + + uint32_t NormalPathLabel = 0; + uint32_t EndLabel = 0; + + Operand Ret = getTempOperand(Type); + bool Exchanged = false; + + // rem_s + if (!IsUnsigned) { + NormalPathLabel = createLabel(); + EndLabel = createLabel(); + + Operand CmpOpnd; + if (X64Type == X64::I32) { + CmpOpnd = Operand(Type, 0x80000000U); } else { - ExtendOp::emit(ASM, RetReg, Op.getMem()); + auto RegNum = Layout.getScopedTemp(); + _ movabs(X64Reg::getRegRef(RegNum), 0x8000000000000000ULL); + CmpOpnd = Operand(Type, RegNum, Operand::FLAG_NONE); } - if (Ret.isMem()) { - _ mov(Ret.getMem(), RetReg); - } - return Ret; - } - - // fused compare and branch - template - void handleFusedCompareBranchImpl(Operand CmpLHS, Operand CmpRHS, - uint32_t Label) { - constexpr X64::Type X64CondType = getX64TypeFromWASMType(); - ZEN_ASSERT(CmpLHS.getType() == CondType); + cmp(LHS, CmpOpnd, Exchanged); + jne(NormalPathLabel); - // make comparison - bool Exchanged = false; - if (Opr == CompareOperator::CO_EQZ) { - ZEN_ASSERT(CmpRHS.getType() == WASMType::VOID); - ZEN_ASSERT(CmpRHS.getKind() == OK_None); - test(CmpLHS); + if (X64Type == X64::I32) { + CmpOpnd = Operand(Type, 0xffffffffU); } else { - cmp(CmpLHS, CmpRHS, Exchanged); + auto RegNum = Layout.getScopedTemp(); + _ movabs(X64Reg::getRegRef(RegNum), 0xffffffffffffffffULL); + CmpOpnd = Operand(Type, RegNum, Operand::FLAG_NONE); } - if (!Exchanged) { - jmpcc(Label); - } else { - constexpr CompareOperator ExchangedOpr = - getExchangedCompareOperator(); - jmpcc(Label); - } - } + cmp(RHS, CmpOpnd, Exchanged); + jne(NormalPathLabel); - // fused compare and select - template - Operand handleFusedCompareSelectImpl(Operand CmpLHS, Operand CmpRHS, - Operand LHS, Operand RHS) { - constexpr X64::Type X64CondType = getX64TypeFromWASMType(); - ZEN_ASSERT(CmpLHS.getType() == CondType); - - // make comparison - bool Exchanged = false; - if (Opr == CompareOperator::CO_EQZ) { - ZEN_ASSERT(CmpRHS.getType() == WASMType::VOID); - ZEN_ASSERT(CmpRHS.getKind() == OK_None); - test(CmpLHS); + if (IsRem) { + mov(Ret, Operand(Type, 0)); + branch(EndLabel); } else { - cmp(CmpLHS, CmpRHS, Exchanged); + _ jmp(getExceptLabel(ErrorCode::IntegerOverflow)); } - ZEN_ASSERT(LHS.getType() == RHS.getType()); - switch (LHS.getType()) { - // TODO: use cmov for integer type - case WASMType::I32: - return fusedCompareSelectWithIf(LHS, RHS, Exchanged); - case WASMType::I64: - return fusedCompareSelectWithIf(LHS, RHS, Exchanged); - case WASMType::F32: - return fusedCompareSelectWithIf(LHS, RHS, Exchanged); - case WASMType::F64: - return fusedCompareSelectWithIf(LHS, RHS, Exchanged); - default: - ZEN_ABORT(); - } + bindLabel(NormalPathLabel); } - // load value from memory - template - void loadRegFromMem(X64::RegNum Val, asmjit::x86::Mem Mem) { - LoadOperatorImpl::emit(ASM, Val, Mem); - } +#ifndef ZEN_ENABLE_CPU_EXCEPTION + cmp(RHS, Operand(Type, 0), Exchanged); + _ je(getExceptLabel(ErrorCode::IntegerDivByZero)); +#endif // ZEN_ENABLE_CPU_EXCEPTION - // store value to memory - template - void storeRegToMem(X64::RegNum Val, asmjit::x86::Mem Mem) { - ASM.mov(Mem, X64Reg::getRegRef(Val)); + mov(X64::RAX, LHS); + if (IsUnsigned) { + auto RDXReg = X64Reg::getRegRef(X64::GP::RDX); + ASM.xor_(RDXReg, RDXReg); + } else if (X64Type == X64::I32) { + ASM.cdq(); + } else if (X64Type == X64::I64) { + ASM.cqo(); } - // store value to memory - template - void storeImmToMem(uint32_t Val, asmjit::x86::Mem Mem) { - ASM.mov(Mem, Val); + if (!RHS.isReg()) { + mov(X64::RCX, RHS); + RHS = Operand(Type, X64::RCX, Operand::FLAG_NONE); } - // load from memory in SrcType and return in DestType - template - Operand handleLoadImpl(Operand Base, uint32_t Offset, uint32_t Align) { - constexpr X64::Type X64DestType = getX64TypeFromWASMType(); - constexpr X64::Type X64SrcType = getX64TypeFromWASMType(); - constexpr X64::Type AddrType = - getX64TypeFromWASMType(); - ZEN_ASSERT(Base.getType() == X64OnePassABI::WASMAddrType); - - checkMemoryOverflow(Base, Offset); + BinaryOperatorImpl::emit(ASM, RHS, RHS); - typename X64TypeAttr::RegNum BaseReg = - X64::RAX; // the initial value only used to suppress compiler error - - asmjit::x86::Mem Addr; - if (Base.isReg()) { - BaseReg = (typename X64TypeAttr::RegNum)Base.getReg(); - } else if (Base.isMem()) { - BaseReg = Layout.getScopedTemp(); - ASM.mov(X64Reg::getRegRef(BaseReg), - Base.getMem()); - } else if (Base.isImm()) { - uint64_t Offset64 = (uint64_t)Offset; - Offset64 += (uint32_t)Base.getImm(); - if (Offset64 > INT32_MAX) { - Offset = INT32_MAX; // invalid addr - } else { - Offset = (uint32_t)Offset64; - } + if (IsRem) { + mov(Ret, + Operand(Type, X64::RDX, Operand::FLAG_NONE)); + } else { + mov(Ret, + Operand(Type, X64::RAX, Operand::FLAG_NONE)); + } + + // rem_s + if (!IsUnsigned && IsRem) { + bindLabel(EndLabel); + } + + return Ret; +} + +template +Operand handleFloatToIntImpl(Operand Op) { + // tag dispatch + return handleFloatToIntImpl( + Op, std::integral_constant()); +} + +// extend from stype to dtype in same type kind (integer or floating-point) +template +Operand handleIntExtendImpl(Operand Op) { + constexpr auto X64DestType = getX64TypeFromWASMType(); + constexpr auto X64SrcType = getX64TypeFromWASMType(); + + auto Ret = getTempOperand(DestType); + auto RegNum = Layout.getScopedTemp(); + auto RetReg = Ret.isReg() ? Ret.getRegRef() + : X64Reg::getRegRef(RegNum); + + using ExtendOp = ExtendOperatorImpl; + if (Op.isImm()) { + auto RegNum2 = Layout.getScopedTemp(); + auto TmpReg = X64Reg::getRegRef(RegNum2); + _ mov(TmpReg, Op.getImm()); + ExtendOp::emit(ASM, RetReg, TmpReg); + } else if (Op.isReg()) { + ExtendOp::emit(ASM, RetReg, Op.getRegRef()); + } else { + ExtendOp::emit(ASM, RetReg, Op.getMem()); + } + + if (Ret.isMem()) { + _ mov(Ret.getMem(), RetReg); + } + return Ret; +} + +// fused compare and branch +template +void handleFusedCompareBranchImpl(Operand CmpLHS, Operand CmpRHS, + uint32_t Label) { + constexpr X64::Type X64CondType = getX64TypeFromWASMType(); + ZEN_ASSERT(CmpLHS.getType() == CondType); + + // make comparison + bool Exchanged = false; + if (Opr == CompareOperator::CO_EQZ) { + ZEN_ASSERT(CmpRHS.getType() == WASMType::VOID); + ZEN_ASSERT(CmpRHS.getKind() == OK_None); + test(CmpLHS); + } else { + cmp(CmpLHS, CmpRHS, Exchanged); + } + + if (!Exchanged) { + jmpcc(Label); + } else { + constexpr CompareOperator ExchangedOpr = getExchangedCompareOperator(); + jmpcc(Label); + } +} + +// fused compare and select +template +Operand handleFusedCompareSelectImpl(Operand CmpLHS, Operand CmpRHS, + Operand LHS, Operand RHS) { + constexpr X64::Type X64CondType = getX64TypeFromWASMType(); + ZEN_ASSERT(CmpLHS.getType() == CondType); + + // make comparison + bool Exchanged = false; + if (Opr == CompareOperator::CO_EQZ) { + ZEN_ASSERT(CmpRHS.getType() == WASMType::VOID); + ZEN_ASSERT(CmpRHS.getKind() == OK_None); + test(CmpLHS); + } else { + cmp(CmpLHS, CmpRHS, Exchanged); + } + + ZEN_ASSERT(LHS.getType() == RHS.getType()); + switch (LHS.getType()) { + // TODO: use cmov for integer type + case WASMType::I32: + return fusedCompareSelectWithIf(LHS, RHS, Exchanged); + case WASMType::I64: + return fusedCompareSelectWithIf(LHS, RHS, Exchanged); + case WASMType::F32: + return fusedCompareSelectWithIf(LHS, RHS, Exchanged); + case WASMType::F64: + return fusedCompareSelectWithIf(LHS, RHS, Exchanged); + default: + ZEN_ABORT(); + } +} + +// load value from memory +template +void loadRegFromMem(X64::RegNum Val, asmjit::x86::Mem Mem) { + LoadOperatorImpl::emit(ASM, Val, Mem); +} + +// store value to memory +template +void storeRegToMem(X64::RegNum Val, asmjit::x86::Mem Mem) { + ASM.mov(Mem, X64Reg::getRegRef(Val)); +} + +// store value to memory +template +void storeImmToMem(uint32_t Val, asmjit::x86::Mem Mem) { + ASM.mov(Mem, Val); +} + +// load from memory in SrcType and return in DestType +template +Operand handleLoadImpl(Operand Base, uint32_t Offset, uint32_t Align) { + constexpr X64::Type X64DestType = getX64TypeFromWASMType(); + constexpr X64::Type X64SrcType = getX64TypeFromWASMType(); + constexpr X64::Type AddrType = + getX64TypeFromWASMType(); + ZEN_ASSERT(Base.getType() == X64OnePassABI::WASMAddrType); + + checkMemoryOverflow(Base, Offset); + + typename X64TypeAttr::RegNum BaseReg = + X64::RAX; // the initial value only used to suppress compiler error + + asmjit::x86::Mem Addr; + if (Base.isReg()) { + BaseReg = (typename X64TypeAttr::RegNum)Base.getReg(); + } else if (Base.isMem()) { + BaseReg = Layout.getScopedTemp(); + ASM.mov(X64Reg::getRegRef(BaseReg), + Base.getMem()); + } else if (Base.isImm()) { + uint64_t Offset64 = (uint64_t)Offset; + Offset64 += (uint32_t)Base.getImm(); + if (Offset64 > INT32_MAX) { + Offset = INT32_MAX; // invalid addr } else { - ZEN_ABORT(); - } - - typename X64TypeAttr::RegNum ValReg; - bool HasTempReg = Layout.hasAvailTempReg(ValReg); - if (!HasTempReg) { - ValReg = Layout.getScopedTemp(); - } - - Addr = Base.isImm() - ? asmjit::x86::Mem(ABI.getMemoryBaseReg(), Offset, - getWASMTypeSize()) - : asmjit::x86::Mem(ABI.getMemoryBaseReg(), - X64Reg::getRegRef(BaseReg), 0, - Offset, getWASMTypeSize()); - -#ifdef ZEN_ENABLE_CPU_EXCEPTION - if (!Base.isImm() && (Offset >= INT32_MAX)) { - // when offset >= INT32_MAX, then will cause inst like mov edi, dword - // ptr[r13+edi-1]. - auto MemAddrReg = Layout.getScopedTemp(); - _ mov(X64Reg::getRegRef(MemAddrReg), Offset); - _ add(X64Reg::getRegRef(MemAddrReg), - X64Reg::getRegRef(BaseReg)); - _ add(X64Reg::getRegRef(MemAddrReg), ABI.getMemoryBaseReg()); - Addr = asmjit::x86::Mem(X64Reg::getRegRef(MemAddrReg), 0, - getWASMTypeSize(SrcType)); + Offset = (uint32_t)Offset64; } -#endif // ZEN_ENABLE_CPU_EXCEPTION - - LoadOperatorImpl::emit( - ASM, X64Reg::getRegRef(ValReg), Addr); - if (HasTempReg) { - Layout.clearAvailReg(ValReg); - return Operand(DestType, ValReg, Operand::FLAG_TEMP_REG); - } - Operand Ret = getTempStackOperand(DestType); - ASM.mov(Ret.getMem(), - X64Reg::getRegRef(ValReg)); - return Ret; + } else { + ZEN_ABORT(); } - // shift - template - Operand handleShiftOpImpl(Operand LHS, Operand RHS) { - ZEN_ASSERT(LHS.getType() == Type); - ZEN_ASSERT(RHS.getType() == Type); - constexpr X64::Type X64Type = getX64TypeFromWASMType(); - - auto ResReg = toReg(LHS); - - if (RHS.isMem() || RHS.isReg()) { - mov(X64::RCX, RHS); - RHS = Operand(Type, X64::RCX, Operand::FLAG_NONE); - } - - BinaryOperatorImpl::emit( - ASM, X64Reg::getRegRef(ResReg), RHS); - - auto Ret = getTempOperand(Type); - mov(Ret, ResReg); - - return Ret; + typename X64TypeAttr::RegNum ValReg; + bool HasTempReg = Layout.hasAvailTempReg(ValReg); + if (!HasTempReg) { + ValReg = Layout.getScopedTemp(); } - // store value to memory in Type - template - void handleStoreImpl(Operand Value, Operand Base, uint32_t Offset, - uint32_t Align) { - constexpr X64::Type X64Type = getX64TypeFromWASMType(); - constexpr X64::Type AddrType = - getX64TypeFromWASMType(); - ZEN_ASSERT(Base.getType() == X64OnePassABI::WASMAddrType); - - checkMemoryOverflow(Base, Offset); - - X64::RegNum RegNum = 0; - if (Base.isReg()) { - RegNum = Base.getReg(); - } else if (Base.isMem()) { - RegNum = Layout.getScopedTemp(); - ASM.mov(X64Reg::getRegRef(RegNum), - Base.getMem()); - } else if (Base.isImm()) { - uint64_t Offset64 = (uint64_t)Offset; - Offset64 += (uint32_t)Base.getImm(); - if (Offset64 > INT32_MAX) { - Offset = INT32_MAX; // invalid addr - } else { - Offset = (uint32_t)Offset64; - } - } else { - ZEN_ABORT(); - } - - // Addr = memoryBase + (in64) offset, so when offset < 0, - // the result i32 Addr works like add (2**32 + offset) - asmjit::x86::Mem Addr = - Base.isImm() ? asmjit::x86::Mem(ABI.getMemoryBaseReg(), Offset, - getWASMTypeSize()) - : asmjit::x86::Mem(ABI.getMemoryBaseReg(), - X64Reg::getRegRef(RegNum), 0, - Offset, getWASMTypeSize()); + Addr = Base.isImm() ? asmjit::x86::Mem(ABI.getMemoryBaseReg(), Offset, + getWASMTypeSize()) + : asmjit::x86::Mem(ABI.getMemoryBaseReg(), + X64Reg::getRegRef(BaseReg), + 0, Offset, getWASMTypeSize()); - mov(Addr, Value); - } - - Operand handleIntTruncImpl(Operand Op) { - auto Src = toReg(Op); - auto Dest = getTempOperand(WASMType::I32); - mov(Dest, Src); - return Dest; +#ifdef ZEN_ENABLE_CPU_EXCEPTION + if (!Base.isImm() && (Offset >= INT32_MAX)) { + // when offset >= INT32_MAX, then will cause inst like mov edi, dword + // ptr[r13+edi-1]. + auto MemAddrReg = Layout.getScopedTemp(); + _ mov(X64Reg::getRegRef(MemAddrReg), Offset); + _ add(X64Reg::getRegRef(MemAddrReg), + X64Reg::getRegRef(BaseReg)); + _ add(X64Reg::getRegRef(MemAddrReg), ABI.getMemoryBaseReg()); + Addr = asmjit::x86::Mem(X64Reg::getRegRef(MemAddrReg), 0, + getWASMTypeSize(SrcType)); } +#endif // ZEN_ENABLE_CPU_EXCEPTION - // floating-point unary operators - template - Operand handleUnaryOpImpl(Operand Op) { - ZEN_STATIC_ASSERT(Type == WASMType::F32 || Type == WASMType::F64); - switch (Opr) { - case UnaryOperator::UO_ABS: - return floatAbs(Op); - case UnaryOperator::UO_NEG: - return floatNeg(Op); - case UnaryOperator::UO_SQRT: - return floatSqrt(Op); - case UnaryOperator::UO_CEIL: - case UnaryOperator::UO_FLOOR: - case UnaryOperator::UO_NEAREST: - case UnaryOperator::UO_TRUNC: - return floatRound(Op); - default: - ZEN_ABORT(); - } - } + LoadOperatorImpl::emit( + ASM, X64Reg::getRegRef(ValReg), Addr); + if (HasTempReg) { + Layout.clearAvailReg(ValReg); + return Operand(DestType, ValReg, Operand::FLAG_TEMP_REG); + } + Operand Ret = getTempStackOperand(DestType); + ASM.mov(Ret.getMem(), + X64Reg::getRegRef(ValReg)); + return Ret; +} + +// shift +template +Operand handleShiftOpImpl(Operand LHS, Operand RHS) { + ZEN_ASSERT(LHS.getType() == Type); + ZEN_ASSERT(RHS.getType() == Type); + constexpr X64::Type X64Type = getX64TypeFromWASMType(); + + auto ResReg = toReg(LHS); + + if (RHS.isMem() || RHS.isReg()) { + mov(X64::RCX, RHS); + RHS = Operand(Type, X64::RCX, Operand::FLAG_NONE); + } + + BinaryOperatorImpl::emit( + ASM, X64Reg::getRegRef(ResReg), RHS); + + auto Ret = getTempOperand(Type); + mov(Ret, ResReg); + + return Ret; +} + +// store value to memory in Type +template +void handleStoreImpl(Operand Value, Operand Base, uint32_t Offset, + uint32_t Align) { + constexpr X64::Type X64Type = getX64TypeFromWASMType(); + constexpr X64::Type AddrType = + getX64TypeFromWASMType(); + ZEN_ASSERT(Base.getType() == X64OnePassABI::WASMAddrType); + + checkMemoryOverflow(Base, Offset); + + X64::RegNum RegNum = 0; + if (Base.isReg()) { + RegNum = Base.getReg(); + } else if (Base.isMem()) { + RegNum = Layout.getScopedTemp(); + ASM.mov(X64Reg::getRegRef(RegNum), + Base.getMem()); + } else if (Base.isImm()) { + uint64_t Offset64 = (uint64_t)Offset; + Offset64 += (uint32_t)Base.getImm(); + if (Offset64 > INT32_MAX) { + Offset = INT32_MAX; // invalid addr + } else { + Offset = (uint32_t)Offset64; + } + } else { + ZEN_ABORT(); + } + + // Addr = memoryBase + (in64) offset, so when offset < 0, + // the result i32 Addr works like add (2**32 + offset) + asmjit::x86::Mem Addr = + Base.isImm() ? asmjit::x86::Mem(ABI.getMemoryBaseReg(), Offset, + getWASMTypeSize()) + : asmjit::x86::Mem(ABI.getMemoryBaseReg(), + X64Reg::getRegRef(RegNum), 0, + Offset, getWASMTypeSize()); + + mov(Addr, Value); +} + +Operand handleIntTruncImpl(Operand Op) { + auto Src = toReg(Op); + auto Dest = getTempOperand(WASMType::I32); + mov(Dest, Src); + return Dest; +} + +// floating-point unary operators +template +Operand handleUnaryOpImpl(Operand Op) { + ZEN_STATIC_ASSERT(Type == WASMType::F32 || Type == WASMType::F64); + switch (Opr) { + case UnaryOperator::UO_ABS: + return floatAbs(Op); + case UnaryOperator::UO_NEG: + return floatNeg(Op); + case UnaryOperator::UO_SQRT: + return floatSqrt(Op); + case UnaryOperator::UO_CEIL: + case UnaryOperator::UO_FLOOR: + case UnaryOperator::UO_NEAREST: + case UnaryOperator::UO_TRUNC: + return floatRound(Op); + default: + ZEN_ABORT(); + } +} public: - // - // branch, call and return instructions - // +// +// branch, call and return instructions +// - // in alphabetical order - // branch to given label - void branch(uint32_t LabelIdx) { - asmjit::Label L(LabelIdx); +// in alphabetical order +// branch to given label +void branch(uint32_t LabelIdx) { + asmjit::Label L(LabelIdx); + _ jmp(L); +} + +void branchLTU(uint32_t LabelIdx) { _ jb(asmjit::Label(LabelIdx)); } + +// branch to label if cond is false +void branchFalse(Operand Cond, uint32_t LabelIdx) { + ZEN_ASSERT(Cond.getType() == WASMType::I32 || + Cond.getType() == WASMType::I64); + asmjit::Label L(LabelIdx); + if (!Cond.isImm()) { + test(Cond); + _ je(L); + } else if (!Cond.getImm()) { _ jmp(L); } +} - void branchLTU(uint32_t LabelIdx) { _ jb(asmjit::Label(LabelIdx)); } - - // branch to label if cond is false - void branchFalse(Operand Cond, uint32_t LabelIdx) { - ZEN_ASSERT(Cond.getType() == WASMType::I32 || - Cond.getType() == WASMType::I64); - asmjit::Label L(LabelIdx); - if (!Cond.isImm()) { - test(Cond); - _ je(L); - } else if (!Cond.getImm()) { - _ jmp(L); - } - } - - // branch to label if cond is true - void branchTrue(Operand Cond, uint32_t LabelIdx) { - ZEN_ASSERT(Cond.getType() == WASMType::I32 || - Cond.getType() == WASMType::I64); - asmjit::Label L(LabelIdx); - if (!Cond.isImm()) { - test(Cond); - _ jne(L); - } else if (Cond.getImm()) { - _ jmp(L); - } - } - - // branch to table index - void handleBranchTableImpl(Operand Index, - const std::vector &LabelIdxs) { - ZEN_ASSERT(Index.getType() == WASMType::I32); - ZEN_ASSERT(LabelIdxs.size() >= 1); - uint32_t Bound = LabelIdxs.size() - 1; // last item is default - // compare index with bound - if (Index.isImm()) { - uint32_t IndexImm = - ((uint32_t)Index.getImm() < Bound) ? Index.getImm() : Bound; - asmjit::Label L(LabelIdxs[IndexImm]); - _ jmp(L); - return; - } - - // load index into register if necessary - auto IndexReg = Index.isReg() - ? Index.getRegRef() - : Layout.getScopedTempReg(); - if (!Index.isReg()) { - _ mov(IndexReg, Index.getMem()); - } - // compare index with bound - _ cmp(IndexReg, Bound); - // jump to default label if index >= bound - _ jae(asmjit::Label(LabelIdxs[Bound])); - - // for small tables, generate if (index == i) goto i; - switch (Bound) { - case 4: - _ cmp(IndexReg, 3); - _ je(asmjit::Label(LabelIdxs[3])); - // fall through - case 3: - _ cmp(IndexReg, 2); - _ je(asmjit::Label(LabelIdxs[2])); - // fall through - case 2: - _ cmp(IndexReg, 1); - _ je(asmjit::Label(LabelIdxs[1])); - // fall through - case 1: - _ cmp(IndexReg, 0); - _ je(asmjit::Label(LabelIdxs[0])); - return; - default: - break; - } - - // jump to entry in jump table - uint32_t Table = createLabel(); - auto JmpReg = Layout.getScopedTempReg(); - _ lea(JmpReg, asmjit::x86::ptr(asmjit::Label(Table))); - _ jmp( - asmjit::x86::Mem(JmpReg, IndexReg, sizeof(uintptr_t) == 4 ? 2 : 3, 0)); - emitJumpTable(Table, LabelIdxs); +// branch to label if cond is true +void branchTrue(Operand Cond, uint32_t LabelIdx) { + ZEN_ASSERT(Cond.getType() == WASMType::I32 || + Cond.getType() == WASMType::I64); + asmjit::Label L(LabelIdx); + if (!Cond.isImm()) { + test(Cond); + _ jne(L); + } else if (Cond.getImm()) { + _ jmp(L); } - - // call - - Operand handleCallImpl(uint32_t FuncIdx, uintptr_t Target, bool IsImport, - bool FarCall, const ArgumentInfo &ArgInfo, - const std::vector &Args) { - return emitCall( - ArgInfo, Args, [this] { saveGasVal(); }, - [&]() { +} + +// branch to table index +void handleBranchTableImpl(Operand Index, + const std::vector &LabelIdxs) { + ZEN_ASSERT(Index.getType() == WASMType::I32); + ZEN_ASSERT(LabelIdxs.size() >= 1); + uint32_t Bound = LabelIdxs.size() - 1; // last item is default + // compare index with bound + if (Index.isImm()) { + uint32_t IndexImm = + ((uint32_t)Index.getImm() < Bound) ? Index.getImm() : Bound; + asmjit::Label L(LabelIdxs[IndexImm]); + _ jmp(L); + return; + } + + // load index into register if necessary + auto IndexReg = Index.isReg() + ? Index.getRegRef() + : Layout.getScopedTempReg(); + if (!Index.isReg()) { + _ mov(IndexReg, Index.getMem()); + } + // compare index with bound + _ cmp(IndexReg, Bound); + // jump to default label if index >= bound + _ jae(asmjit::Label(LabelIdxs[Bound])); + + // for small tables, generate if (index == i) goto i; + switch (Bound) { + case 4: + _ cmp(IndexReg, 3); + _ je(asmjit::Label(LabelIdxs[3])); + // fall through + case 3: + _ cmp(IndexReg, 2); + _ je(asmjit::Label(LabelIdxs[2])); + // fall through + case 2: + _ cmp(IndexReg, 1); + _ je(asmjit::Label(LabelIdxs[1])); + // fall through + case 1: + _ cmp(IndexReg, 0); + _ je(asmjit::Label(LabelIdxs[0])); + return; + default: + break; + } + + // jump to entry in jump table + uint32_t Table = createLabel(); + auto JmpReg = Layout.getScopedTempReg(); + _ lea(JmpReg, asmjit::x86::ptr(asmjit::Label(Table))); + _ jmp(asmjit::x86::Mem(JmpReg, IndexReg, sizeof(uintptr_t) == 4 ? 2 : 3, 0)); + emitJumpTable(Table, LabelIdxs); +} + +// call + +Operand handleCallImpl(uint32_t FuncIdx, uintptr_t Target, bool IsImport, + bool FarCall, const ArgumentInfo &ArgInfo, + const std::vector &Args) { + return emitCall( + ArgInfo, Args, [this] { saveGasVal(); }, + [&]() { #ifdef ZEN_ENABLE_DWASM - // if is_import, update WasmInstance::is_host_api - if (IsImport) { - auto InHostAPIFlagAddr = asmjit::x86::ptr( - ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); - _ mov(InHostAPIFlagAddr, 1); - } + // if is_import, update WasmInstance::is_host_api + if (IsImport) { + auto InHostAPIFlagAddr = asmjit::x86::ptr( + ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); + _ mov(InHostAPIFlagAddr, 1); + } #endif - // generate call, emit call or record relocation for patching - if (Target) { - _ call(Target); - } else { - size_t Offset = _ offset(); - _ dw(0); - _ dd(0); // reserve 6 bytes - ZEN_ASSERT(_ offset() - Offset == 6); - Patcher.addCallEntry(Offset, _ offset() - Offset, FuncIdx); - } - }, - [this, IsImport]() { - loadGasVal(); - checkCallException(IsImport); + // generate call, emit call or record relocation for patching + if (Target) { + _ call(Target); + } else { + size_t Offset = _ offset(); + _ dw(0); + _ dd(0); // reserve 6 bytes + ZEN_ASSERT(_ offset() - Offset == 6); + Patcher.addCallEntry(Offset, _ offset() - Offset, FuncIdx); + } + }, + [this, IsImport]() { + loadGasVal(); + checkCallException(IsImport); #ifdef ZEN_ENABLE_DWASM - // if is_import, update WasmInstance::is_host_api - if (IsImport) { - auto InHostAPIFlagAddr = asmjit::x86::ptr( - ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); - _ mov(InHostAPIFlagAddr, 0); - } + // if is_import, update WasmInstance::is_host_api + if (IsImport) { + auto InHostAPIFlagAddr = asmjit::x86::ptr( + ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); + _ mov(InHostAPIFlagAddr, 0); + } #endif - }); - } - - // call indirect - Operand handleCallIndirectImpl(uint32_t TypeIdx, Operand Callee, - uint32_t TblIdx, const ArgumentInfo &ArgInfo, - const std::vector &Arg) { - uint32_t NumHostAPIs = Ctx->Mod->getNumImportFunctions(); - return emitCall( - ArgInfo, Arg, - // prepare call, check and load callee address into %rax (return - // reg) - [this, NumHostAPIs, TypeIdx, Callee, TblIdx]() { - saveGasVal(); - - auto FuncIdxReg = Layout.getScopedTemp(); - - auto FuncIdx = X64Reg::getRegRef(FuncIdxReg); - - emitTableGet(TblIdx, Callee, FuncIdxReg); - - auto InstReg = ABI.getModuleInstReg(); - - _ cmp(FuncIdx, -1); - _ je(getExceptLabel(ErrorCode::UninitializedElement)); - - constexpr uint32_t Shift0 = 2; - auto IndexesBaseOffset = - Ctx->Mod->getLayout().FuncTypeIndexesBaseOffset; - asmjit::x86::Mem TypeIdxAddr(InstReg, FuncIdx, Shift0, - IndexesBaseOffset, sizeof(TypeIdx)); - - _ cmp(TypeIdxAddr, TypeIdx); - _ jne(getExceptLabel(ErrorCode::IndirectCallTypeMismatch)); + }); +} +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +// Multi-value call implementation +std::vector +handleCallMultiValueImpl(uint32_t FuncIdx, uintptr_t Target, bool IsImport, + bool FarCall, const ArgumentInfo &ArgInfo, + const std::vector &Args) { + // For multi-value returns, we use a similar approach to single value + // but return multiple operands + uint32_t NumReturns = ArgInfo.getNumReturns(); + std::vector Results; + + // Call the function + Operand PrimaryResult = emitCall( + ArgInfo, Args, [this] { saveGasVal(); }, + [&]() { #ifdef ZEN_ENABLE_DWASM - // check func_idx < import_funcs_count (is_import) - // if is_import, update WasmInstance::is_host_api - auto UpdateFlagLabel = createLabel(); - auto EndUpdateFlagLabel = createLabel(); - - _ cmp(FuncIdx, NumHostAPIs); - branchLTU(UpdateFlagLabel); - branch(EndUpdateFlagLabel); - - bindLabel(UpdateFlagLabel); + if (IsImport) { auto InHostAPIFlagAddr = asmjit::x86::ptr( ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); _ mov(InHostAPIFlagAddr, 1); - - bindLabel(EndUpdateFlagLabel); + } #endif - - auto FuncPtr = ABI.getCallTargetReg(); - constexpr uint32_t Shift = sizeof(void *) == 4 ? 2 : 3; - asmjit::x86::Mem FuncPtrAddr( - InstReg, FuncIdx, Shift, - Ctx->Mod->getLayout().FuncPtrsBaseOffset); - - _ mov(FuncPtr, FuncPtrAddr); - }, - // generate call - [&]() { _ call(ABI.getCallTargetReg()); }, - [this]() { - loadGasVal(); - checkCallIndirectException(); + if (Target) { + _ call(Target); + } else { + size_t Offset = _ offset(); + _ dw(0); + _ dd(0); + ZEN_ASSERT(_ offset() - Offset == 6); + Patcher.addCallEntry(Offset, _ offset() - Offset, FuncIdx); + } + }, + [this, IsImport, &Results, NumReturns, &ArgInfo]() { + loadGasVal(); + checkCallException(IsImport); #ifdef ZEN_ENABLE_DWASM - // because func_idx reg not available in post_call - // so just update the flag directly now(have performance cost) + if (IsImport) { auto InHostAPIFlagAddr = asmjit::x86::ptr( ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); - _ mov(InHostAPIFlagAddr, 0); // update flag back + _ mov(InHostAPIFlagAddr, 0); + } #endif - }); - } - // branch to label if ZF is set - void je(uint32_t LabelIdx) { - asmjit::Label L(LabelIdx); - _ je(L); - } + // Collect multiple return values + if (NumReturns > 0) { + const WASMType *RetTypes = ArgInfo.getReturnTypes(); + Results.reserve(NumReturns); - // branch to label if ZF is 1 - void jne(uint32_t LabelIdx) { - asmjit::Label L(LabelIdx); - _ jne(L); - } - - // return - void handleReturnImpl(Operand Op) { emitEpilog(Op); } + // First result is in return register + if (NumReturns >= 1) { + Results.push_back(getReturnRegOperand(RetTypes[0])); + } + // Second integer result is in RDX, second FP in XMM1 + if (NumReturns >= 2) { + if (RetTypes[1] == WASMType::I32 || RetTypes[1] == WASMType::I64) { + Results.push_back(Operand(RetTypes[1], + ABI.template getParamRegNum(), + Operand::FLAG_NONE)); + } else { + Results.push_back(Operand(RetTypes[1], + ABI.template getParamRegNum(), + Operand::FLAG_NONE)); + } + } + // Additional results would need stack-based return + // For now, we handle up to 2 results + for (uint32_t I = 2; I < NumReturns; ++I) { + // TODO: Handle additional results from stack + Results.push_back(getTempStackOperand(RetTypes[I])); + } + } + }); - // unreachable - void handleUnreachableImpl() { - _ jmp(getExceptLabel(ErrorCode::Unreachable)); + // If we have results from the callback, use those; otherwise use primary + if (Results.empty() && NumReturns > 0) { + Results.push_back(PrimaryResult); } -public: - // - // non-templated method to handle other individual opcode - // + return Results; +} +#endif - // in alphabetical order - - // memory grow - Operand handleMemoryGrowImpl(Operand Op) { - static TypeEntry SigBuf = { - .NumParams = 1, - .NumParamCells = 1, - .NumReturns = 1, - .NumReturnCells = 1, - .ReturnTypes = {WASMType::I32}, - { - .ParamTypesVec = {WASMType::I32}, - }, - .SmallestTypeIdx = uint32_t(-1), - }; - - X64ArgumentInfo ArgInfo(&SigBuf); - std::vector Args({Op}); - return emitCall( - ArgInfo, Args, - []() { - // prepare call, no nothing - }, - [this]() { - // generate call, emit call to wasm_enlarge_memory_wrapper - _ call(uintptr_t(Instance::growInstanceMemoryOnJIT)); - asmjit::Label CallFail = _ newLabel(); - _ cmp(ABI.getRetReg(), 0); - _ jl(CallFail); // less than 0, jump to call fail - // call success, update r13 for mem base, r12 for mem size - auto InstReg = ABI.getModuleInstReg(); - _ mov(ABI.getMemorySizeReg(), - asmjit::x86::Mem(InstReg, - Ctx->Mod->getLayout().MemorySizeOffset)); - _ mov(ABI.getMemoryBaseReg(), - asmjit::x86::Mem(InstReg, - Ctx->Mod->getLayout().MemoryBaseOffset)); - _ bind(CallFail); - }, - [] {}); - } - - // memory size - Operand handleMemorySizeImpl() { - Operand Ret = getTempOperand(WASMType::I32); - const auto &RetReg = - Ret.isReg() ? Ret.getRegRef() - : Layout.getScopedTempReg(); - // Mov r12 to retReg and shift 16 (64KB) - _ mov(RetReg, X64Reg::getRegRef(ABI.getMemorySize())); - _ shr(RetReg, 16); - if (Ret.isMem()) { - // mov retReg to return memory - _ mov(Ret.getMem(), RetReg); - } - return Ret; - } +// call indirect +Operand handleCallIndirectImpl(uint32_t TypeIdx, Operand Callee, + uint32_t TblIdx, const ArgumentInfo &ArgInfo, + const std::vector &Arg) { + uint32_t NumHostAPIs = Ctx->Mod->getNumImportFunctions(); + return emitCall( + ArgInfo, Arg, + // prepare call, check and load callee address into %rax (return + // reg) + [this, NumHostAPIs, TypeIdx, Callee, TblIdx]() { + saveGasVal(); - // select - Operand handleSelectImpl(Operand Cond, Operand LHS, Operand RHS) { - ZEN_ASSERT(LHS.getType() == RHS.getType()); - ZEN_ASSERT(Cond.getType() == WASMType::I32 || - Cond.getType() == WASMType::I64); - switch (LHS.getType()) { - case WASMType::I32: - return selectWithCMov(Cond, LHS, RHS); - case WASMType::I64: - return selectWithCMov(Cond, LHS, RHS); - case WASMType::F32: - return selectWithIf(Cond, LHS, RHS); - case WASMType::F64: - return selectWithIf(Cond, LHS, RHS); - default: - ZEN_ABORT(); - } - } + auto FuncIdxReg = Layout.getScopedTemp(); -private: - // select, return value in type - // test cond - // mov rhs, res - // cmovne lhs, rhs - template - Operand selectWithCMov(Operand Cond, Operand LHS, Operand RHS) { - // handle condition - test(Cond); + auto FuncIdx = X64Reg::getRegRef(FuncIdxReg); - constexpr X64::Type X64Type = getX64TypeFromWASMType(); - typename X64TypeAttr::RegNum ResReg; - bool Exchanged = false; - if (LHS.isReg() && LHS.isTempReg()) { - // reuse lhs as return value - ResReg = (typename X64TypeAttr::RegNum)LHS.getReg(); - Layout.clearAvailReg(ResReg); - } else if (RHS.isReg() && RHS.isTempReg()) { - // reuse rhs as return value - ResReg = (typename X64TypeAttr::RegNum)RHS.getReg(); - Layout.clearAvailReg(ResReg); - Exchanged = true; - } else if (LHS.isImm()) { - // need a scoped temp for result, load lhs to temp at first - ResReg = Layout.getScopedTemp(); - mov(ResReg, LHS); - } else { - // need a scoped temp for result, load rhs to temp at first - ResReg = Layout.getScopedTemp(); - mov(ResReg, RHS); - Exchanged = true; - } + emitTableGet(TblIdx, Callee, FuncIdxReg); - // cmov rhs to lhsReg - Exchanged ? cmovne(ResReg, LHS) - : cmove(ResReg, RHS); + auto InstReg = ABI.getModuleInstReg(); - if (ResReg != Layout.getScopedTemp()) { - return Exchanged ? RHS : LHS; - } + _ cmp(FuncIdx, -1); + _ je(getExceptLabel(ErrorCode::UninitializedElement)); - // store lhsReg to return operand - typename X64TypeAttr::RegNum RetReg; - Operand Ret; - if (Layout.hasAvailTempReg(RetReg)) { - Ret = Operand(Type, RetReg, Operand::FLAG_TEMP_REG); - Layout.clearAvailReg(RetReg); - ASM.mov(X64Reg::getRegRef(RetReg), - X64Reg::getRegRef(ResReg)); - } else { - Ret = getTempStackOperand(Type); - ASM.mov(Ret.getMem(), - X64Reg::getRegRef(ResReg)); - } - return Ret; - } + constexpr uint32_t Shift0 = 2; + auto IndexesBaseOffset = + Ctx->Mod->getLayout().FuncTypeIndexesBaseOffset; + asmjit::x86::Mem TypeIdxAddr(InstReg, FuncIdx, Shift0, + IndexesBaseOffset, sizeof(TypeIdx)); - template - Operand selectWithIf(Operand Cond, Operand LHS, Operand RHS) { - auto Ret = getTempOperand(Type); - constexpr auto X64Type = getX64TypeFromWASMType(); - auto RegNum = Layout.getScopedTemp(); + _ cmp(TypeIdxAddr, TypeIdx); + _ jne(getExceptLabel(ErrorCode::IndirectCallTypeMismatch)); - auto Label = createLabel(); - mov(RegNum, LHS); - test(Cond); - jne(Label); - mov(RegNum, RHS); - bindLabel(Label); +#ifdef ZEN_ENABLE_DWASM + // check func_idx < import_funcs_count (is_import) + // if is_import, update WasmInstance::is_host_api + auto UpdateFlagLabel = createLabel(); + auto EndUpdateFlagLabel = createLabel(); - ZEN_ASSERT(!Ret.isImm()); - mov(Ret, - Operand(Type, RegNum, Operand::FLAG_NONE)); - return Ret; - } + _ cmp(FuncIdx, NumHostAPIs); + branchLTU(UpdateFlagLabel); + branch(EndUpdateFlagLabel); - template - Operand fusedCompareSelectWithIf(Operand LHS, Operand RHS, bool Exchanged) { - auto Ret = getTempOperand(Type); - constexpr auto X64Type = getX64TypeFromWASMType(); - auto RegNum = Layout.getScopedTemp(); + bindLabel(UpdateFlagLabel); + auto InHostAPIFlagAddr = asmjit::x86::ptr( + ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); + _ mov(InHostAPIFlagAddr, 1); - auto Label = createLabel(); - mov(RegNum, LHS); + bindLabel(EndUpdateFlagLabel); +#endif - if (Exchanged) { - constexpr auto ExchangedOpr = getExchangedCompareOperator(); - jmpcc(Label); - } else { - jmpcc(Label); - } + auto FuncPtr = ABI.getCallTargetReg(); + constexpr uint32_t Shift = sizeof(void *) == 4 ? 2 : 3; + asmjit::x86::Mem FuncPtrAddr(InstReg, FuncIdx, Shift, + Ctx->Mod->getLayout().FuncPtrsBaseOffset); - mov(RegNum, RHS); - bindLabel(Label); + _ mov(FuncPtr, FuncPtrAddr); + }, + // generate call + [&]() { _ call(ABI.getCallTargetReg()); }, + [this]() { + loadGasVal(); + checkCallIndirectException(); - ZEN_ASSERT(!Ret.isImm()); - mov(Ret, - Operand(Type, RegNum, Operand::FLAG_NONE)); - return Ret; - } +#ifdef ZEN_ENABLE_DWASM + // because func_idx reg not available in post_call + // so just update the flag directly now(have performance cost) + auto InHostAPIFlagAddr = asmjit::x86::ptr( + ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); + _ mov(InHostAPIFlagAddr, 0); // update flag back +#endif + }); +} -private: - // - // helper functions, move to op_assembler_x64.h? - // +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +// Multi-value call indirect implementation +std::vector +handleCallIndirectMultiValueImpl(uint32_t TypeIdx, Operand Callee, + uint32_t TblIdx, const ArgumentInfo &ArgInfo, + const std::vector &Arg) { + uint32_t NumReturns = ArgInfo.getNumReturns(); + std::vector Results; - // conditional move value from rhs (reg, mem, imm to lhs (reg only) - template - void cmove(X64::RegNum LHS, Operand RHS) { - typedef typename X64TypeAttr::Type RegType; - const RegType &LHSReg = X64Reg::getRegRef(LHS); - if (RHS.isReg()) { - _ cmove(LHSReg, RHS.getRegRef()); - } else if (RHS.isMem()) { - _ cmove(LHSReg, RHS.getMem()); - } else if (RHS.isImm()) { - auto Tmp = Layout.getScopedTempReg(); - ASM.mov(Tmp, RHS.getImm()); - _ cmove(LHSReg, Tmp); - } else { - ZEN_ABORT(); - } - } + uint32_t NumHostAPIs = Ctx->Mod->getNumImportFunctions(); + Operand PrimaryResult = emitCall( + ArgInfo, Arg, + [this, NumHostAPIs, TypeIdx, Callee, TblIdx]() { + saveGasVal(); - // conditional move value from rhs (reg, mem, imm to lhs (reg only) - template - void cmovne(X64::RegNum LHS, Operand RHS) { - typedef typename X64TypeAttr::Type RegType; - const RegType &LHSReg = X64Reg::getRegRef(LHS); - if (RHS.isReg()) { - _ cmovne(LHSReg, RHS.getRegRef()); - } else if (RHS.isMem()) { - _ cmovne(LHSReg, RHS.getMem()); - } else if (RHS.isImm()) { - auto Tmp = Layout.getScopedTempReg(); - ASM.mov(Tmp, RHS.getImm()); - _ cmovne(LHSReg, Tmp); - } else { - ZEN_ABORT(); - } - } + auto FuncIdxReg = Layout.getScopedTemp(); + auto FuncIdx = X64Reg::getRegRef(FuncIdxReg); - // get an operand in register, using a scoped temp if necessary - template X64::RegNum toReg(Operand Op) { - if (Op.isReg()) { - return Op.getReg(); - } - auto TmpReg = Layout.getScopedTemp(); - mov(TmpReg, Op); - return TmpReg; - } - - // compare value - template - void cmp(Operand LHS, Operand RHS, bool &Exchanged) { - // floating-point constants are stored on stack - ZEN_ASSERT(Ty == X64::I32 || Ty == X64::I64 || - (!LHS.isImm() && !RHS.isImm())); - - // in case the caller forgets to initialize this parameter - Exchanged = false; - - if (LHS.isReg()) { - if (RHS.isReg()) { - ASM.cmp(LHS.getRegRef(), RHS.getRegRef()); - } else if (RHS.isMem()) { - ASM.cmp(LHS.getRegRef(), RHS.getMem()); - } else { - ASM.cmp(LHS.getRegRef(), RHS.getImm()); - } - } else if (LHS.isMem()) { - if (RHS.isReg()) { - Exchanged = true; - ASM.cmp(RHS.getRegRef(), LHS.getMem()); - } else if (RHS.isMem()) { - auto Reg = Layout.getScopedTempReg(); - ASM.mov(Reg, LHS.getMem()); - ASM.cmp(Reg, RHS.getMem()); - } else { - ASM.cmp(LHS.getMem(), RHS.getImm()); - } - } else { - if (RHS.isReg()) { - Exchanged = true; - ASM.cmp(RHS.getRegRef(), LHS.getImm()); - } else if (RHS.isMem()) { - Exchanged = true; - ASM.cmp(RHS.getMem(), LHS.getImm()); - } else { - auto Reg = Layout.getScopedTempReg(); - ASM.mov(Reg, LHS.getImm()); - ASM.cmp(Reg, RHS.getImm()); - } - } - } + emitTableGet(TblIdx, Callee, FuncIdxReg); - // test single value with 0 - template void test(Operand Op) { - if (Op.isReg()) { - auto Reg = Op.getRegRef(); - ASM.test(Reg, Reg); - } else if (Op.isMem()) { - auto Reg = Layout.getScopedTempReg(); - ASM.mov(Reg, Op.getMem()); - ASM.test(Reg, Reg); - } else { - auto Reg = Layout.getScopedTempReg(); - ASM.mov(Reg, Op.getImm()); - ASM.test(Reg, Reg); - } - } + auto InstReg = ABI.getModuleInstReg(); - // test single value with 0 - template void test(Operand Op) { - if (Op.getType() == WASMType::I32) { - test(Op); - } else if (Op.getType() == WASMType::I64) { - test(Op); - } else { - ZEN_ABORT(); - } - } + _ cmp(FuncIdx, -1); + _ je(getExceptLabel(ErrorCode::UninitializedElement)); - // Jmpcc - template void jmpcc(uint32_t LabelIdx) { - constexpr JmpccOperator JmpccOpr = getJmpccOperator(); - JmpccOperatorImpl::emit(ASM, LabelIdx); - } + constexpr uint32_t Shift0 = 2; + auto IndexesBaseOffset = + Ctx->Mod->getLayout().FuncTypeIndexesBaseOffset; + asmjit::x86::Mem TypeIdxAddr(InstReg, FuncIdx, Shift0, + IndexesBaseOffset, sizeof(TypeIdx)); - // Setcc - template void setcc(X64::RegNum RegNum) { - constexpr SetccOperator SetccOpr = getSetccOperator(); - SetccOperatorImpl::emit(ASM, - X64Reg::getRegRef(RegNum)); - } + _ cmp(TypeIdxAddr, TypeIdx); + _ jne(getExceptLabel(ErrorCode::IndirectCallTypeMismatch)); - template Operand floatNeg(Operand Op) { - constexpr auto X64Type = getX64TypeFromWASMType(); - constexpr auto X64IntType = - getX64TypeFromWASMType::IntType>(); +#ifdef ZEN_ENABLE_DWASM + auto UpdateFlagLabel = createLabel(); + auto EndUpdateFlagLabel = createLabel(); - auto Ret = getTempOperand(Type); - auto RegNum = Ret.isReg() - ? Ret.getReg() - : static_cast( - Layout.getScopedTemp()); - mov(RegNum, Op); + _ cmp(FuncIdx, NumHostAPIs); + branchLTU(UpdateFlagLabel); + branch(EndUpdateFlagLabel); - auto ImmReg = Layout.getScopedTempReg(); - auto ImmReg2 = Layout.getScopedTempReg(); + bindLabel(UpdateFlagLabel); + auto InHostAPIFlagAddr = asmjit::x86::ptr( + ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); + _ mov(InHostAPIFlagAddr, 1); - auto SignMask = FloatAttr::SignMask; - _ mov(ImmReg2, SignMask); - ASM.fmov(ImmReg, ImmReg2); - ASM.xor_(X64Reg::getRegRef(RegNum), ImmReg); + bindLabel(EndUpdateFlagLabel); +#endif - if (!Ret.isReg()) { - mov(Ret, - Operand(Type, RegNum, Operand::FLAG_NONE)); - } - return Ret; - } + auto FuncPtr = ABI.getCallTargetReg(); + constexpr uint32_t Shift = sizeof(void *) == 4 ? 2 : 3; + asmjit::x86::Mem FuncPtrAddr(InstReg, FuncIdx, Shift, + Ctx->Mod->getLayout().FuncPtrsBaseOffset); - template Operand floatAbs(Operand Op) { - constexpr auto X64Type = getX64TypeFromWASMType(); - constexpr auto X64IntType = - getX64TypeFromWASMType::IntType>(); + _ mov(FuncPtr, FuncPtrAddr); + }, + [&]() { _ call(ABI.getCallTargetReg()); }, + [this, &Results, NumReturns, &ArgInfo]() { + loadGasVal(); + checkCallIndirectException(); - auto TmpReg = Layout.getScopedTempReg(); - auto TmpRegNum = Layout.getScopedTemp(); - auto TmpIntReg = Layout.getScopedTempReg(); +#ifdef ZEN_ENABLE_DWASM + auto InHostAPIFlagAddr = asmjit::x86::ptr( + ABI.getModuleInstReg(), InHostApiOffset, InHostApiSize); + _ mov(InHostAPIFlagAddr, 0); +#endif - auto Mask = ~FloatAttr::SignMask; - _ mov(TmpIntReg, Mask); - ASM.fmov(TmpReg, TmpIntReg); + // Collect multiple return values + if (NumReturns > 0) { + const WASMType *RetTypes = ArgInfo.getReturnTypes(); + Results.reserve(NumReturns); - if (Op.isReg()) { - ASM.and_(TmpReg, Op.getRegRef()); - } else if (Op.isMem()) { - auto TmpReg2 = Layout.getScopedTemp(); - mov(TmpReg2, Op); - ASM.and_(TmpReg, X64Reg::getRegRef(TmpReg2)); - } else { - ZEN_ABORT(); - } + if (NumReturns >= 1) { + Results.push_back(getReturnRegOperand(RetTypes[0])); + } + if (NumReturns >= 2) { + if (RetTypes[1] == WASMType::I32 || RetTypes[1] == WASMType::I64) { + Results.push_back(Operand(RetTypes[1], + ABI.template getParamRegNum(), + Operand::FLAG_NONE)); + } else { + Results.push_back(Operand(RetTypes[1], + ABI.template getParamRegNum(), + Operand::FLAG_NONE)); + } + } + for (uint32_t I = 2; I < NumReturns; ++I) { + Results.push_back(getTempStackOperand(RetTypes[I])); + } + } + }); - auto Ret = getTempOperand(Type); - mov(Ret, - Operand(Type, TmpRegNum, Operand::FLAG_NONE)); - return Ret; + if (Results.empty() && NumReturns > 0) { + Results.push_back(PrimaryResult); } - template Operand floatSqrt(Operand Op) { - constexpr auto X64Type = getX64TypeFromWASMType(); - auto TmpReg = Layout.getScopedTempReg(); - auto TmpRegNum = Layout.getScopedTemp(); - - if (Op.isReg()) { - ASM.sqrt(TmpReg, Op.getRegRef()); - } else if (Op.isMem()) { - auto TmpReg2 = Layout.getScopedTemp(); - mov(TmpReg2, Op); - ASM.sqrt(TmpReg, X64Reg::getRegRef(TmpReg2)); - } else { - ZEN_ABORT(); - } - - auto Ret = getTempOperand(Type); - mov(Ret, - Operand(Type, TmpRegNum, Operand::FLAG_NONE)); - return Ret; - } + return Results; +} +#endif - template Operand floatRound(Operand Op) { - constexpr auto X64Type = getX64TypeFromWASMType(); - auto TmpReg = Layout.getScopedTempReg(); - auto TmpRegNum = Layout.getScopedTemp(); - - uint8_t Mode = [] { - switch (Opr) { - case UnaryOperator::UO_CEIL: - return 2; - case UnaryOperator::UO_FLOOR: - return 1; - case UnaryOperator::UO_NEAREST: - return 0; - case UnaryOperator::UO_TRUNC: - return 3; - default: - ZEN_ABORT(); +// branch to label if ZF is set +void je(uint32_t LabelIdx) { + asmjit::Label L(LabelIdx); + _ je(L); +} + +// branch to label if ZF is 1 +void jne(uint32_t LabelIdx) { + asmjit::Label L(LabelIdx); + _ jne(L); +} + +// return +void handleReturnImpl(Operand Op) { emitEpilog(Op); } + +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE +void handleReturnMultiValueImpl(const std::vector &Ops) { + // For multi-value returns, we need to handle multiple return values + // ABI for multiple returns: + // - First integer result: RAX + // - Second integer result: RDX + // - First FP result: XMM0 + // - Second FP result: XMM1 + // - Additional results: passed via stack (return buffer) + // NOTE: Currently only supports up to 2 results of each type + if (Ops.empty()) { + emitEpilog(Operand()); + return; + } + if (Ops.size() == 1) { + emitEpilog(Ops[0]); + return; + } + // For multiple return values, use registers for first 2 of each type + uint32_t GpRegIdx = 0; + uint32_t FpRegIdx = 0; + for (uint32_t I = 0; I < Ops.size(); ++I) { + const Operand &Op = Ops[I]; + WASMType Type = Op.getType(); + if (Type == WASMType::I32) { + if (GpRegIdx == 0) { + // First I32 result goes to EAX + mov(ABI.template getRetRegNum(), Op); + GpRegIdx++; + } else if (GpRegIdx == 1) { + // Second I32 result goes to EDX + mov(ABI.template getParamRegNum(), Op); + GpRegIdx++; } - }(); - - if (Op.isReg()) { - if (Type == WASMType::F32) { - _ roundss(TmpReg, Op.getRegRef(), Mode); - } else { - _ roundsd(TmpReg, Op.getRegRef(), Mode); + // Additional I32 results are not supported - skip + } else if (Type == WASMType::I64) { + if (GpRegIdx == 0) { + // First I64 result goes to RAX + mov(ABI.template getRetRegNum(), Op); + GpRegIdx++; + } else if (GpRegIdx == 1) { + // Second I64 result goes to RDX + mov(ABI.template getParamRegNum(), Op); + GpRegIdx++; } - } else if (Op.isMem()) { - auto TmpReg2 = Layout.getScopedTemp(); - mov(TmpReg2, Op); - if (Type == WASMType::F32) { - _ roundss(TmpReg, X64Reg::getRegRef(TmpReg2), Mode); - } else { - _ roundsd(TmpReg, X64Reg::getRegRef(TmpReg2), Mode); + // Additional I64 results are not supported - skip + } else if (Type == WASMType::F32) { + if (FpRegIdx == 0) { + // First F32 result goes to XMM0 + mov(ABI.template getRetRegNum(), Op); + FpRegIdx++; + } else if (FpRegIdx == 1) { + // Second F32 result goes to XMM1 + mov(ABI.template getParamRegNum(), Op); + FpRegIdx++; } - } else { - ZEN_ABORT(); + // Additional F32 results are not supported - skip + } else if (Type == WASMType::F64) { + if (FpRegIdx == 0) { + // First F64 result goes to XMM0 + mov(ABI.template getRetRegNum(), Op); + FpRegIdx++; + } else if (FpRegIdx == 1) { + // Second F64 result goes to XMM1 + mov(ABI.template getParamRegNum(), Op); + FpRegIdx++; + } + // Additional F64 results are not supported - skip } - - auto Ret = getTempOperand(Type); - mov(Ret, - Operand(Type, TmpRegNum, Operand::FLAG_NONE)); - return Ret; } + emitEpilog(Operand()); +} +#endif - // truncate float to signed integer - template - Operand handleFloatToIntImpl(Operand Opnd, std::true_type) { - constexpr auto X64DestType = getX64TypeFromWASMType(); - constexpr auto X64SrcType = getX64TypeFromWASMType(); - - auto Ret = getTempOperand(DestType); - auto RetReg = Ret.isReg() - ? Ret.getRegRef() - : Layout.getScopedTempReg(); - if (!Opnd.isReg()) { - auto RegNum = Layout.getScopedTemp(); - mov(RegNum, Opnd); - Opnd = Operand(SrcType, RegNum, Operand::FLAG_NONE); - } - auto OpndReg = Opnd.getRegRef(); - - ConvertOpImpl::emit(ASM, RetReg, OpndReg); - - auto Finish = _ newLabel(); - _ cmp(RetReg, 1); - _ jno(Finish); - - ASM.cmp(OpndReg, OpndReg); - _ jp(getExceptLabel(ErrorCode::InvalidConversionToInteger)); +// unreachable +void handleUnreachableImpl() { _ jmp(getExceptLabel(ErrorCode::Unreachable)); } - constexpr auto X64IntSrcType = - getX64TypeFromWASMType::IntType>(); - auto TmpFReg = Layout.getScopedTempReg(); - auto TmpIReg = Layout.getScopedTempReg(); +public: +// +// non-templated method to handle other individual opcode +// - auto IntMin = FloatAttr::template int_min(); - _ mov(TmpIReg, IntMin); - ASM.fmov(TmpFReg, TmpIReg); +// in alphabetical order + +// memory grow +Operand handleMemoryGrowImpl(Operand Op) { + static TypeEntry SigBuf; + static bool Initialized = false; + if (!Initialized) { + SigBuf.NumParams = 1; + SigBuf.NumParamCells = 1; + SigBuf.NumReturns = 1; + SigBuf.NumReturnCells = 1; +#ifdef ZEN_ENABLE_WASI_MULTI_VALUE + SigBuf.ReturnTypesVec[0] = WASMType::I32; +#else + SigBuf.ReturnTypes[0] = WASMType::I32; +#endif + SigBuf.ParamTypesVec[0] = WASMType::I32; + SigBuf.SmallestTypeIdx = uint32_t(-1); + Initialized = true; + } + + X64ArgumentInfo ArgInfo(&SigBuf); + std::vector Args({Op}); + return emitCall( + ArgInfo, Args, + []() { + // prepare call, no nothing + }, + [this]() { + // generate call, emit call to wasm_enlarge_memory_wrapper + _ call(uintptr_t(Instance::growInstanceMemoryOnJIT)); + asmjit::Label CallFail = _ newLabel(); + _ cmp(ABI.getRetReg(), 0); + _ jl(CallFail); // less than 0, jump to call fail + // call success, update r13 for mem base, r12 for mem size + auto InstReg = ABI.getModuleInstReg(); + _ mov( + ABI.getMemorySizeReg(), + asmjit::x86::Mem(InstReg, Ctx->Mod->getLayout().MemorySizeOffset)); + _ mov( + ABI.getMemoryBaseReg(), + asmjit::x86::Mem(InstReg, Ctx->Mod->getLayout().MemoryBaseOffset)); + _ bind(CallFail); + }, + [] {}); +} + +// memory size +Operand handleMemorySizeImpl() { + Operand Ret = getTempOperand(WASMType::I32); + const auto &RetReg = + Ret.isReg() ? Ret.getRegRef() + : Layout.getScopedTempReg(); + // Mov r12 to retReg and shift 16 (64KB) + _ mov(RetReg, X64Reg::getRegRef(ABI.getMemorySize())); + _ shr(RetReg, 16); + if (Ret.isMem()) { + // mov retReg to return memory + _ mov(Ret.getMem(), RetReg); + } + return Ret; +} + +// select +Operand handleSelectImpl(Operand Cond, Operand LHS, Operand RHS) { + ZEN_ASSERT(LHS.getType() == RHS.getType()); + ZEN_ASSERT(Cond.getType() == WASMType::I32 || + Cond.getType() == WASMType::I64); + switch (LHS.getType()) { + case WASMType::I32: + return selectWithCMov(Cond, LHS, RHS); + case WASMType::I64: + return selectWithCMov(Cond, LHS, RHS); + case WASMType::F32: + return selectWithIf(Cond, LHS, RHS); + case WASMType::F64: + return selectWithIf(Cond, LHS, RHS); + default: + ZEN_ABORT(); + } +} - ASM.cmp(OpndReg, TmpFReg); - _ jbe(getExceptLabel(ErrorCode::IntegerOverflow)); +private: +// select, return value in type +// test cond +// mov rhs, res +// cmovne lhs, rhs +template +Operand selectWithCMov(Operand Cond, Operand LHS, Operand RHS) { + // handle condition + test(Cond); + + constexpr X64::Type X64Type = getX64TypeFromWASMType(); + typename X64TypeAttr::RegNum ResReg; + bool Exchanged = false; + if (LHS.isReg() && LHS.isTempReg()) { + // reuse lhs as return value + ResReg = (typename X64TypeAttr::RegNum)LHS.getReg(); + Layout.clearAvailReg(ResReg); + } else if (RHS.isReg() && RHS.isTempReg()) { + // reuse rhs as return value + ResReg = (typename X64TypeAttr::RegNum)RHS.getReg(); + Layout.clearAvailReg(ResReg); + Exchanged = true; + } else if (LHS.isImm()) { + // need a scoped temp for result, load lhs to temp at first + ResReg = Layout.getScopedTemp(); + mov(ResReg, LHS); + } else { + // need a scoped temp for result, load rhs to temp at first + ResReg = Layout.getScopedTemp(); + mov(ResReg, RHS); + Exchanged = true; + } + + // cmov rhs to lhsReg + Exchanged ? cmovne(ResReg, LHS) + : cmove(ResReg, RHS); + + if (ResReg != Layout.getScopedTemp()) { + return Exchanged ? RHS : LHS; + } + + // store lhsReg to return operand + typename X64TypeAttr::RegNum RetReg; + Operand Ret; + if (Layout.hasAvailTempReg(RetReg)) { + Ret = Operand(Type, RetReg, Operand::FLAG_TEMP_REG); + Layout.clearAvailReg(RetReg); + ASM.mov(X64Reg::getRegRef(RetReg), + X64Reg::getRegRef(ResReg)); + } else { + Ret = getTempStackOperand(Type); + ASM.mov(Ret.getMem(), X64Reg::getRegRef(ResReg)); + } + return Ret; +} + +template +Operand selectWithIf(Operand Cond, Operand LHS, Operand RHS) { + auto Ret = getTempOperand(Type); + constexpr auto X64Type = getX64TypeFromWASMType(); + auto RegNum = Layout.getScopedTemp(); + + auto Label = createLabel(); + mov(RegNum, LHS); + test(Cond); + jne(Label); + mov(RegNum, RHS); + bindLabel(Label); + + ZEN_ASSERT(!Ret.isImm()); + mov(Ret, Operand(Type, RegNum, Operand::FLAG_NONE)); + return Ret; +} + +template +Operand fusedCompareSelectWithIf(Operand LHS, Operand RHS, bool Exchanged) { + auto Ret = getTempOperand(Type); + constexpr auto X64Type = getX64TypeFromWASMType(); + auto RegNum = Layout.getScopedTemp(); + + auto Label = createLabel(); + mov(RegNum, LHS); + + if (Exchanged) { + constexpr auto ExchangedOpr = getExchangedCompareOperator(); + jmpcc(Label); + } else { + jmpcc(Label); + } + + mov(RegNum, RHS); + bindLabel(Label); + + ZEN_ASSERT(!Ret.isImm()); + mov(Ret, Operand(Type, RegNum, Operand::FLAG_NONE)); + return Ret; +} - ASM.xor_(TmpFReg, TmpFReg); - ASM.cmp(TmpFReg, OpndReg); - _ jb(getExceptLabel(ErrorCode::IntegerOverflow)); +private: +// +// helper functions, move to op_assembler_x64.h? +// - _ bind(Finish); - if (Ret.isMem()) { - ASM.mov(Ret.getMem(), RetReg); +// conditional move value from rhs (reg, mem, imm to lhs (reg only) +template +void cmove(X64::RegNum LHS, Operand RHS) { + typedef typename X64TypeAttr::Type RegType; + const RegType &LHSReg = X64Reg::getRegRef(LHS); + if (RHS.isReg()) { + _ cmove(LHSReg, RHS.getRegRef()); + } else if (RHS.isMem()) { + _ cmove(LHSReg, RHS.getMem()); + } else if (RHS.isImm()) { + auto Tmp = Layout.getScopedTempReg(); + ASM.mov(Tmp, RHS.getImm()); + _ cmove(LHSReg, Tmp); + } else { + ZEN_ABORT(); + } +} + +// conditional move value from rhs (reg, mem, imm to lhs (reg only) +template +void cmovne(X64::RegNum LHS, Operand RHS) { + typedef typename X64TypeAttr::Type RegType; + const RegType &LHSReg = X64Reg::getRegRef(LHS); + if (RHS.isReg()) { + _ cmovne(LHSReg, RHS.getRegRef()); + } else if (RHS.isMem()) { + _ cmovne(LHSReg, RHS.getMem()); + } else if (RHS.isImm()) { + auto Tmp = Layout.getScopedTempReg(); + ASM.mov(Tmp, RHS.getImm()); + _ cmovne(LHSReg, Tmp); + } else { + ZEN_ABORT(); + } +} + +// get an operand in register, using a scoped temp if necessary +template X64::RegNum toReg(Operand Op) { + if (Op.isReg()) { + return Op.getReg(); + } + auto TmpReg = Layout.getScopedTemp(); + mov(TmpReg, Op); + return TmpReg; +} + +// compare value +template +void cmp(Operand LHS, Operand RHS, bool &Exchanged) { + // floating-point constants are stored on stack + ZEN_ASSERT(Ty == X64::I32 || Ty == X64::I64 || + (!LHS.isImm() && !RHS.isImm())); + + // in case the caller forgets to initialize this parameter + Exchanged = false; + + if (LHS.isReg()) { + if (RHS.isReg()) { + ASM.cmp(LHS.getRegRef(), RHS.getRegRef()); + } else if (RHS.isMem()) { + ASM.cmp(LHS.getRegRef(), RHS.getMem()); + } else { + ASM.cmp(LHS.getRegRef(), RHS.getImm()); } - return Ret; + } else if (LHS.isMem()) { + if (RHS.isReg()) { + Exchanged = true; + ASM.cmp(RHS.getRegRef(), LHS.getMem()); + } else if (RHS.isMem()) { + auto Reg = Layout.getScopedTempReg(); + ASM.mov(Reg, LHS.getMem()); + ASM.cmp(Reg, RHS.getMem()); + } else { + ASM.cmp(LHS.getMem(), RHS.getImm()); + } + } else { + if (RHS.isReg()) { + Exchanged = true; + ASM.cmp(RHS.getRegRef(), LHS.getImm()); + } else if (RHS.isMem()) { + Exchanged = true; + ASM.cmp(RHS.getMem(), LHS.getImm()); + } else { + auto Reg = Layout.getScopedTempReg(); + ASM.mov(Reg, LHS.getImm()); + ASM.cmp(Reg, RHS.getImm()); + } + } +} + +// test single value with 0 +template void test(Operand Op) { + if (Op.isReg()) { + auto Reg = Op.getRegRef(); + ASM.test(Reg, Reg); + } else if (Op.isMem()) { + auto Reg = Layout.getScopedTempReg(); + ASM.mov(Reg, Op.getMem()); + ASM.test(Reg, Reg); + } else { + auto Reg = Layout.getScopedTempReg(); + ASM.mov(Reg, Op.getImm()); + ASM.test(Reg, Reg); + } +} + +// test single value with 0 +template void test(Operand Op) { + if (Op.getType() == WASMType::I32) { + test(Op); + } else if (Op.getType() == WASMType::I64) { + test(Op); + } else { + ZEN_ABORT(); + } +} + +// Jmpcc +template void jmpcc(uint32_t LabelIdx) { + constexpr JmpccOperator JmpccOpr = getJmpccOperator(); + JmpccOperatorImpl::emit(ASM, LabelIdx); +} + +// Setcc +template void setcc(X64::RegNum RegNum) { + constexpr SetccOperator SetccOpr = getSetccOperator(); + SetccOperatorImpl::emit(ASM, + X64Reg::getRegRef(RegNum)); +} + +template Operand floatNeg(Operand Op) { + constexpr auto X64Type = getX64TypeFromWASMType(); + constexpr auto X64IntType = + getX64TypeFromWASMType::IntType>(); + + auto Ret = getTempOperand(Type); + auto RegNum = Ret.isReg() + ? Ret.getReg() + : static_cast( + Layout.getScopedTemp()); + mov(RegNum, Op); + + auto ImmReg = Layout.getScopedTempReg(); + auto ImmReg2 = Layout.getScopedTempReg(); + + auto SignMask = FloatAttr::SignMask; + _ mov(ImmReg2, SignMask); + ASM.fmov(ImmReg, ImmReg2); + ASM.xor_(X64Reg::getRegRef(RegNum), ImmReg); + + if (!Ret.isReg()) { + mov(Ret, + Operand(Type, RegNum, Operand::FLAG_NONE)); } - - // truncate float to unsigned integer - template - Operand handleFloatToIntImpl(Operand Op, std::false_type) { - constexpr auto X64DestType = getX64TypeFromWASMType(); - constexpr auto X64SrcType = getX64TypeFromWASMType(); - - auto Ret = getTempOperand(DestType); - auto RetReg = Ret.isReg() - ? Ret.getRegRef() - : Layout.getScopedTempReg(); - if (!Op.isReg()) { - auto RegNum = Layout.getScopedTemp(); - mov(RegNum, Op); - Op = Operand(SrcType, RegNum, Operand::FLAG_NONE); + return Ret; +} + +template Operand floatAbs(Operand Op) { + constexpr auto X64Type = getX64TypeFromWASMType(); + constexpr auto X64IntType = + getX64TypeFromWASMType::IntType>(); + + auto TmpReg = Layout.getScopedTempReg(); + auto TmpRegNum = Layout.getScopedTemp(); + auto TmpIntReg = Layout.getScopedTempReg(); + + auto Mask = ~FloatAttr::SignMask; + _ mov(TmpIntReg, Mask); + ASM.fmov(TmpReg, TmpIntReg); + + if (Op.isReg()) { + ASM.and_(TmpReg, Op.getRegRef()); + } else if (Op.isMem()) { + auto TmpReg2 = Layout.getScopedTemp(); + mov(TmpReg2, Op); + ASM.and_(TmpReg, X64Reg::getRegRef(TmpReg2)); + } else { + ZEN_ABORT(); + } + + auto Ret = getTempOperand(Type); + mov(Ret, + Operand(Type, TmpRegNum, Operand::FLAG_NONE)); + return Ret; +} + +template Operand floatSqrt(Operand Op) { + constexpr auto X64Type = getX64TypeFromWASMType(); + auto TmpReg = Layout.getScopedTempReg(); + auto TmpRegNum = Layout.getScopedTemp(); + + if (Op.isReg()) { + ASM.sqrt(TmpReg, Op.getRegRef()); + } else if (Op.isMem()) { + auto TmpReg2 = Layout.getScopedTemp(); + mov(TmpReg2, Op); + ASM.sqrt(TmpReg, X64Reg::getRegRef(TmpReg2)); + } else { + ZEN_ABORT(); + } + + auto Ret = getTempOperand(Type); + mov(Ret, + Operand(Type, TmpRegNum, Operand::FLAG_NONE)); + return Ret; +} + +template Operand floatRound(Operand Op) { + constexpr auto X64Type = getX64TypeFromWASMType(); + auto TmpReg = Layout.getScopedTempReg(); + auto TmpRegNum = Layout.getScopedTemp(); + + uint8_t Mode = [] { + switch (Opr) { + case UnaryOperator::UO_CEIL: + return 2; + case UnaryOperator::UO_FLOOR: + return 1; + case UnaryOperator::UO_NEAREST: + return 0; + case UnaryOperator::UO_TRUNC: + return 3; + default: + ZEN_ABORT(); } - auto OpndReg = Op.getRegRef(); - - constexpr auto X64IntSrcType = - getX64TypeFromWASMType::IntType>(); - auto TmpFReg = Layout.getScopedTempReg(); - auto TmpIReg = Layout.getScopedTempReg(); - - auto IntMax = FloatAttr::template int_max(); - _ mov(TmpIReg, IntMax); - ASM.fmov(TmpFReg, TmpIReg); - - auto AboveIntMax = _ newLabel(); - ASM.cmp(OpndReg, TmpFReg); - _ jae(AboveIntMax); - _ jp(getExceptLabel(ErrorCode::InvalidConversionToInteger)); - - ConvertOpImpl::emit(ASM, RetReg, OpndReg); + }(); - auto Finish = _ newLabel(); - _ cmp(RetReg, 0); - _ jge(Finish); - _ jmp(getExceptLabel(ErrorCode::IntegerOverflow)); - - _ bind(AboveIntMax); - ASM.sub(OpndReg, TmpFReg); - ConvertOpImpl::emit(ASM, RetReg, OpndReg); - - _ cmp(RetReg, 0); - _ jl(getExceptLabel(ErrorCode::IntegerOverflow)); - - auto TmpIReg2 = Layout.getScopedTempReg(); - _ mov(TmpIReg2, 1UL << (getWASMTypeSize() * CHAR_BIT - 1)); - _ add(RetReg, TmpIReg2); - - _ bind(Finish); - if (!Ret.isReg()) { - ASM.mov(Ret.getMem(), RetReg); + if (Op.isReg()) { + if (Type == WASMType::F32) { + _ roundss(TmpReg, Op.getRegRef(), Mode); + } else { + _ roundsd(TmpReg, Op.getRegRef(), Mode); } - return Ret; - } + } else if (Op.isMem()) { + auto TmpReg2 = Layout.getScopedTemp(); + mov(TmpReg2, Op); + if (Type == WASMType::F32) { + _ roundss(TmpReg, X64Reg::getRegRef(TmpReg2), Mode); + } else { + _ roundsd(TmpReg, X64Reg::getRegRef(TmpReg2), Mode); + } + } else { + ZEN_ABORT(); + } - // load gas value from 'module_inst' to register - void loadGasVal() { - auto InstReg = ABI.getModuleInstReg(); - auto GasAddr = asmjit::x86::ptr(InstReg, GasLeftOffset); - _ mov(ABI.getGasReg(), GasAddr); - } + auto Ret = getTempOperand(Type); + mov(Ret, + Operand(Type, TmpRegNum, Operand::FLAG_NONE)); + return Ret; +} + +// truncate float to signed integer +template +Operand handleFloatToIntImpl(Operand Opnd, std::true_type) { + constexpr auto X64DestType = getX64TypeFromWASMType(); + constexpr auto X64SrcType = getX64TypeFromWASMType(); - // save gas value from register to 'module_inst' - void saveGasVal() { - auto InstReg = ABI.getModuleInstReg(); - auto GasAddr = asmjit::x86::ptr(InstReg, GasLeftOffset); - _ mov(GasAddr, ABI.getGasReg()); - } + auto Ret = getTempOperand(DestType); + auto RetReg = Ret.isReg() + ? Ret.getRegRef() + : Layout.getScopedTempReg(); + if (!Opnd.isReg()) { + auto RegNum = Layout.getScopedTemp(); + mov(RegNum, Opnd); + Opnd = Operand(SrcType, RegNum, Operand::FLAG_NONE); + } + auto OpndReg = Opnd.getRegRef(); + + ConvertOpImpl::emit(ASM, RetReg, OpndReg); + + auto Finish = _ newLabel(); + _ cmp(RetReg, 1); + _ jno(Finish); + + ASM.cmp(OpndReg, OpndReg); + _ jp(getExceptLabel(ErrorCode::InvalidConversionToInteger)); + + constexpr auto X64IntSrcType = + getX64TypeFromWASMType::IntType>(); + auto TmpFReg = Layout.getScopedTempReg(); + auto TmpIReg = Layout.getScopedTempReg(); + + auto IntMin = FloatAttr::template int_min(); + _ mov(TmpIReg, IntMin); + ASM.fmov(TmpFReg, TmpIReg); + + ASM.cmp(OpndReg, TmpFReg); + _ jbe(getExceptLabel(ErrorCode::IntegerOverflow)); + + ASM.xor_(TmpFReg, TmpFReg); + ASM.cmp(TmpFReg, OpndReg); + _ jb(getExceptLabel(ErrorCode::IntegerOverflow)); + + _ bind(Finish); + if (Ret.isMem()) { + ASM.mov(Ret.getMem(), RetReg); + } + return Ret; +} + +// truncate float to unsigned integer +template +Operand handleFloatToIntImpl(Operand Op, std::false_type) { + constexpr auto X64DestType = getX64TypeFromWASMType(); + constexpr auto X64SrcType = getX64TypeFromWASMType(); + + auto Ret = getTempOperand(DestType); + auto RetReg = Ret.isReg() + ? Ret.getRegRef() + : Layout.getScopedTempReg(); + if (!Op.isReg()) { + auto RegNum = Layout.getScopedTemp(); + mov(RegNum, Op); + Op = Operand(SrcType, RegNum, Operand::FLAG_NONE); + } + auto OpndReg = Op.getRegRef(); + + constexpr auto X64IntSrcType = + getX64TypeFromWASMType::IntType>(); + auto TmpFReg = Layout.getScopedTempReg(); + auto TmpIReg = Layout.getScopedTempReg(); + + auto IntMax = FloatAttr::template int_max(); + _ mov(TmpIReg, IntMax); + ASM.fmov(TmpFReg, TmpIReg); + + auto AboveIntMax = _ newLabel(); + ASM.cmp(OpndReg, TmpFReg); + _ jae(AboveIntMax); + _ jp(getExceptLabel(ErrorCode::InvalidConversionToInteger)); + + ConvertOpImpl::emit(ASM, RetReg, OpndReg); + + auto Finish = _ newLabel(); + _ cmp(RetReg, 0); + _ jge(Finish); + _ jmp(getExceptLabel(ErrorCode::IntegerOverflow)); + + _ bind(AboveIntMax); + ASM.sub(OpndReg, TmpFReg); + ConvertOpImpl::emit(ASM, RetReg, OpndReg); + + _ cmp(RetReg, 0); + _ jl(getExceptLabel(ErrorCode::IntegerOverflow)); + + auto TmpIReg2 = Layout.getScopedTempReg(); + _ mov(TmpIReg2, 1UL << (getWASMTypeSize() * CHAR_BIT - 1)); + _ add(RetReg, TmpIReg2); + + _ bind(Finish); + if (!Ret.isReg()) { + ASM.mov(Ret.getMem(), RetReg); + } + return Ret; +} + +// load gas value from 'module_inst' to register +void loadGasVal() { + auto InstReg = ABI.getModuleInstReg(); + auto GasAddr = asmjit::x86::ptr(InstReg, GasLeftOffset); + _ mov(ABI.getGasReg(), GasAddr); +} + +// save gas value from register to 'module_inst' +void saveGasVal() { + auto InstReg = ABI.getModuleInstReg(); + auto GasAddr = asmjit::x86::ptr(InstReg, GasLeftOffset); + _ mov(GasAddr, ABI.getGasReg()); +} public: - void subGasVal(Operand Delta) { - Operand GasReg(WASMType::I64, ABI.getGasRegNum(), Operand::FLAG_NONE); - BinaryOperatorImpl::emit(ASM, GasReg, - Delta); - } - - template - Operand checkedArithmetic(Operand LHS, Operand RHS) { - constexpr auto X64Type = getX64TypeFromWASMType(); - auto OverflowLabel = getExceptLabel(ErrorCode::IntegerOverflow); - X64::RegNum LHSRegNum = -1; - if (Opr == BinaryOperator::BO_MUL) { - LHSRegNum = X64::RAX; - mov(LHSRegNum, LHS); - auto RhsRegNum = toReg(RHS); // avoid RAX - auto RhsReg = X64Reg::getRegRef(RhsRegNum); - if (Sign) - _ imul(RhsReg); - else - _ mul(RhsReg); - _ jo(OverflowLabel); - } else { - LHSRegNum = toReg(LHS); - auto LhsReg = X64Reg::getRegRef(LHSRegNum); - BinaryOperatorImpl::emit(ASM, LhsReg, RHS); - if (Sign) - _ jo(OverflowLabel); - else - _ jb(OverflowLabel); - } - // sign/zero extension - constexpr bool IsSmallType = (getWASMTypeSize() < 4); - if (IsSmallType) { - auto Dest = X64Reg::getRegRef(LHSRegNum); - auto Src = X64Reg::getRegRef(LHSRegNum); - if (Sign) - _ movsx(Dest, Src); - else - _ movzx(Dest, Src); - } - constexpr auto ResType = IsSmallType ? WASMType::I32 : Type; - constexpr auto X64ResType = getX64TypeFromWASMType(); - auto Ret = getTempOperand(ResType); - mov(Ret, LHSRegNum); - return Ret; - } - template - Operand checkedI128Arithmetic(Operand LHSLo, Operand LHSHi, Operand RHSLo, - Operand RHSHi) { - auto LHSLoRegNum = toReg(LHSLo); - auto LHSHiRegNum = toReg(LHSHi); - auto LHSLoReg = X64Reg::getRegRef(LHSLoRegNum); - auto LHSHiReg = X64Reg::getRegRef(LHSHiRegNum); - // NOTE: 'ScopedTempReg2' will be reused subsequently - auto RHSLoRegNum = toReg(RHSLo); - auto RHSLoReg = X64Reg::getRegRef(RHSLoRegNum); - if (Opr == BinaryOperator::BO_ADD) - _ add(LHSLoReg, RHSLoReg); - else - _ sub(LHSLoReg, RHSLoReg); - auto RHSHiRegNum = toReg(RHSHi); - auto RHSHiReg = X64Reg::getRegRef(RHSHiRegNum); - if (Opr == BinaryOperator::BO_ADD) - _ adc(LHSHiReg, RHSHiReg); +void subGasVal(Operand Delta) { + Operand GasReg(WASMType::I64, ABI.getGasRegNum(), Operand::FLAG_NONE); + BinaryOperatorImpl::emit(ASM, GasReg, + Delta); +} + +template +Operand checkedArithmetic(Operand LHS, Operand RHS) { + constexpr auto X64Type = getX64TypeFromWASMType(); + auto OverflowLabel = getExceptLabel(ErrorCode::IntegerOverflow); + X64::RegNum LHSRegNum = -1; + if (Opr == BinaryOperator::BO_MUL) { + LHSRegNum = X64::RAX; + mov(LHSRegNum, LHS); + auto RhsRegNum = toReg(RHS); // avoid RAX + auto RhsReg = X64Reg::getRegRef(RhsRegNum); + if (Sign) + _ imul(RhsReg); else - _ sbb(LHSHiReg, RHSHiReg); - auto OverflowLabel = getExceptLabel(ErrorCode::IntegerOverflow); + _ mul(RhsReg); + _ jo(OverflowLabel); + } else { + LHSRegNum = toReg(LHS); + auto LhsReg = X64Reg::getRegRef(LHSRegNum); + BinaryOperatorImpl::emit(ASM, LhsReg, RHS); if (Sign) _ jo(OverflowLabel); else _ jb(OverflowLabel); - auto Ret = getTempOperand(WASMType::I64); - mov(Ret, LHSHiRegNum); - return Ret; } + // sign/zero extension + constexpr bool IsSmallType = (getWASMTypeSize() < 4); + if (IsSmallType) { + auto Dest = X64Reg::getRegRef(LHSRegNum); + auto Src = X64Reg::getRegRef(LHSRegNum); + if (Sign) + _ movsx(Dest, Src); + else + _ movzx(Dest, Src); + } + constexpr auto ResType = IsSmallType ? WASMType::I32 : Type; + constexpr auto X64ResType = getX64TypeFromWASMType(); + auto Ret = getTempOperand(ResType); + mov(Ret, LHSRegNum); + return Ret; +} +template +Operand checkedI128Arithmetic(Operand LHSLo, Operand LHSHi, Operand RHSLo, + Operand RHSHi) { + auto LHSLoRegNum = toReg(LHSLo); + auto LHSHiRegNum = toReg(LHSHi); + auto LHSLoReg = X64Reg::getRegRef(LHSLoRegNum); + auto LHSHiReg = X64Reg::getRegRef(LHSHiRegNum); + // NOTE: 'ScopedTempReg2' will be reused subsequently + auto RHSLoRegNum = toReg(RHSLo); + auto RHSLoReg = X64Reg::getRegRef(RHSLoRegNum); + if (Opr == BinaryOperator::BO_ADD) + _ add(LHSLoReg, RHSLoReg); + else + _ sub(LHSLoReg, RHSLoReg); + auto RHSHiRegNum = toReg(RHSHi); + auto RHSHiReg = X64Reg::getRegRef(RHSHiRegNum); + if (Opr == BinaryOperator::BO_ADD) + _ adc(LHSHiReg, RHSHiReg); + else + _ sbb(LHSHiReg, RHSHiReg); + auto OverflowLabel = getExceptLabel(ErrorCode::IntegerOverflow); + if (Sign) + _ jo(OverflowLabel); + else + _ jb(OverflowLabel); + auto Ret = getTempOperand(WASMType::I64); + mov(Ret, LHSHiRegNum); + return Ret; +} }; // X64OnePassCodeGenImpl // undefine abbr for assembler diff --git a/src/singlepass/x64/operand.h b/src/singlepass/x64/operand.h index 600d4ec20..42aca3f57 100644 --- a/src/singlepass/x64/operand.h +++ b/src/singlepass/x64/operand.h @@ -117,6 +117,7 @@ class X64InstOperand { return getKind() >= OK_BaseIndexScale1 && getKind() <= OK_BaseIndexScale8; } bool isTempReg() const { return (OpKind & FLAG_TEMP_REG); } + uint8_t getRawOpKind() const { return OpKind; } bool isTempMem() const { return (OpKind & FLAG_TEMP_MEM); } template diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 89a1c2f37..f3728f664 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -30,7 +30,14 @@ if(ZEN_ENABLE_SPEC_TEST) if(ZEN_ENABLE_DWASM) set(SPEC_CATEGORIES "dwasm") else() - set(SPEC_CATEGORIES "spec/test/core" "proposals") + # Core spec tests are designed for base WebAssembly without multi-value. + # When multi-value is enabled, some tests that expect "invalid result arity" + # errors will fail because multi-value allows multiple returns. + if(NOT ZEN_ENABLE_WASI_MULTI_VALUE) + set(SPEC_CATEGORIES "spec/test/core" "proposals") + else() + set(SPEC_CATEGORIES "proposals") + endif() if(ZEN_ENABLE_CHECKED_ARITHMETIC) list(APPEND SPEC_CATEGORIES "chain") endif() @@ -45,6 +52,11 @@ if(ZEN_ENABLE_SPEC_TEST) list(APPEND SPEC_CATEGORIES "spec_extra") + # Multi-value tests (only when ZEN_ENABLE_WASI_MULTI_VALUE is enabled) + if(ZEN_ENABLE_WASI_MULTI_VALUE) + list(APPEND SPEC_CATEGORIES "multi_value") + endif() + foreach(SPEC_CATEGORY ${SPEC_CATEGORIES}) process_spec_files("${SPEC_DIR}/${SPEC_CATEGORY}") endforeach() diff --git a/src/utils/wasm.cpp b/src/utils/wasm.cpp index 03d8b2afe..68ad5085c 100644 --- a/src/utils/wasm.cpp +++ b/src/utils/wasm.cpp @@ -36,7 +36,8 @@ const uint8_t *skipCurrentBlock(const uint8_t *Ip, const uint8_t *End) { case LOOP: case IF: ++NestedLevel; - ++Ip; // skip value_type + // Skip blocktype: can be value type (1 byte) or type index (s33 LEB128) + Ip = skipLEBNumber(Ip, End); break; case ELSE: diff --git a/tests/wast/multi_value/basic_test.wast b/tests/wast/multi_value/basic_test.wast new file mode 100644 index 000000000..6cbf76be5 --- /dev/null +++ b/tests/wast/multi_value/basic_test.wast @@ -0,0 +1,14 @@ +;; Basic test module without multi-value +;; This tests basic functionality + +(module + ;; Simple function to test basic functionality + (func $simple_add (param i32 i32) (result i32) + local.get 0 + local.get 1 + i32.add + ) + + ;; Export function for testing + (export "simple_add" (func $simple_add)) +) \ No newline at end of file diff --git a/tests/wast/multi_value/basic_test_main.wast b/tests/wast/multi_value/basic_test_main.wast new file mode 100644 index 000000000..8fb87c48e --- /dev/null +++ b/tests/wast/multi_value/basic_test_main.wast @@ -0,0 +1,19 @@ +;; Basic test module with main function +;; This tests basic functionality + +(module + ;; Simple function to test basic functionality + (func $simple_add (param i32 i32) (result i32) + local.get 0 + local.get 1 + i32.add + ) + + ;; Main function + (func (export "_start") + i32.const 1 + i32.const 2 + call $simple_add + drop + ) +) \ No newline at end of file diff --git a/tests/wast/multi_value/multi_value.wast b/tests/wast/multi_value/multi_value.wast new file mode 100644 index 000000000..8b2057465 --- /dev/null +++ b/tests/wast/multi_value/multi_value.wast @@ -0,0 +1,147 @@ +;; Multi-value proposal test cases +;; This file tests the multi-value extension for WebAssembly +;; Requires ZEN_ENABLE_WASI_MULTI_VALUE to be enabled +;; +;; Multi-value allows: +;; - Blocks to have multiple results +;; - Loops to have input parameters and multiple results +;; - If-else to produce multiple results +;; - Functions to return multiple values + +(module + ;; ============================================================ + ;; Type definitions for multi-value functions + ;; ============================================================ + (type $pair_i32_i32 (func (result i32 i32))) + (type $triple_i32 (func (result i32 i32 i32))) + (type $swap_i32 (func (param i32 i32) (result i32 i32))) + + ;; ============================================================ + ;; Test 1: Block with multiple results + ;; A block that produces two i32 values on the stack + ;; ============================================================ + (func (export "block_pair") (result i32) + (block (result i32 i32) + i32.const 1 + i32.const 2 + ) + i32.add ;; 1 + 2 = 3 + ) + + ;; ============================================================ + ;; Test 2: Function returning multiple values + ;; ============================================================ + (func $get_pair (type $pair_i32_i32) (result i32 i32) + i32.const 10 + i32.const 20 + ) + + (func (export "call_multi_return") (result i32) + call $get_pair + i32.add ;; 10 + 20 = 30 + ) + + ;; ============================================================ + ;; Test 3: Swap function - takes two values, returns them swapped + ;; ============================================================ + (func $swap (type $swap_i32) (param i32 i32) (result i32 i32) + local.get 1 ;; second param + local.get 0 ;; first param + ) + + (func (export "test_swap") (result i32) + i32.const 5 + i32.const 3 + call $swap + i32.sub ;; 5 - 3 = 2 (swapped: second was 3, first was 5) + ) + + ;; ============================================================ + ;; Test 4: If-else with multiple results + ;; ============================================================ + (func (export "if_pair") (param i32) (result i32) + (if (result i32 i32) (local.get 0) + (then + i32.const 1 + i32.const 2 + ) + (else + i32.const 3 + i32.const 4 + ) + ) + i32.add ;; if true: 1+2=3, if false: 3+4=7 + ) + + ;; ============================================================ + ;; Test 5: Nested blocks with multiple results + ;; ============================================================ + (func (export "nested_block") (result i32) + (block (result i32 i32) + (block (result i32 i32) + i32.const 1 + i32.const 2 + ) + ;; stack now has: 1 2 from inner block + i32.add ;; 1 + 2 = 3 + i32.const 4 + ;; stack now has: 3 4 + ) + i32.mul ;; 3 * 4 = 12 + ) + + ;; ============================================================ + ;; Test 6: Branch with multiple values + ;; Br copies the block's result values to the stack + ;; ============================================================ + (func (export "br_multi") (result i32) + (block (result i32 i32) + i32.const 10 + i32.const 20 + br 0 + ;; unreachable + i32.const 0 + i32.const 0 + ) + i32.add ;; 10 + 20 = 30 + ) + + ;; ============================================================ + ;; Test 7: Block with type index (explicit multi-value type) + ;; ============================================================ + (func (export "block_with_type_idx") (result i32) + (block (type $pair_i32_i32) (result i32 i32) + i32.const 100 + i32.const 200 + ) + i32.sub ;; 100 - 200 = -100 + ) + + ;; ============================================================ + ;; Test 8: Three results + ;; ============================================================ + (func $get_triple (type $triple_i32) (result i32 i32 i32) + i32.const 1 + i32.const 2 + i32.const 3 + ) + + (func (export "three_results") (result i32) + call $get_triple + i32.add ;; 2 + 3 = 5 + i32.add ;; 1 + 5 = 6 + ) + + ;; ============================================================ + ;; Main entry point + ;; ============================================================ + (func (export "_start") + ;; Run all tests and verify results + call $get_pair + drop + drop + + i32.const 42 + drop + ) +) \ No newline at end of file