From e9cc2cd61aae0176fc390c8ab8e27bc61b73d0cb Mon Sep 17 00:00:00 2001 From: "shaoding.zw" Date: Wed, 11 Mar 2026 21:00:10 +0800 Subject: [PATCH] opt: optimize EVM-to-dMIR compilation for reduced instruction count and register pressure - Replace select-chain patterns in SHL/SHR/SAR with scratch-based indexed loads, reducing ~120 dMIR instructions to ~20 per shift operation - Add SBB (subtract-with-borrow) dMIR instruction and use SUB/SBB chains for multi-limb subtraction, cutting x86 output from ~36 to ~12 instructions - Optimize unsigned LT/GT comparisons using SUB/SBB chain + ADC(0,0) to capture the borrow flag, replacing multi-limb select chains - Specialize MUL for small constant operands with a 1-limb fast path (~80% reduction) and full constant folding when both operands are known - Optimize SIGNEXTEND with compile-time constant index fast path and scratch-based sign-bit extraction for runtime indices - Add peephole constant folding for PUSH+arithmetic/bitwise sequences (ADD, SUB, AND, OR, XOR) using intx::uint256 - Remove SHL/SHR/SAR from RA-expensive opcode list; update MIR_OPCODE_WEIGHT for SUB (20->12) and LT/GT (12->8) to reflect efficiency gains Made-with: Cursor --- src/compiler/cgir/lowering.h | 3 + src/compiler/evm_frontend/evm_analyzer.h | 7 +- .../evm_frontend/evm_mir_compiler.cpp | 742 +++++++++--------- src/compiler/evm_frontend/evm_mir_compiler.h | 109 ++- src/compiler/mir/instruction.h | 1 + src/compiler/mir/instructions.cpp | 5 + src/compiler/mir/instructions.h | 23 + src/compiler/mir/opcodes.def | 1 + src/compiler/mir/pass/visitor.h | 4 + src/compiler/target/x86/x86lowering.cpp | 34 + src/compiler/target/x86/x86lowering.h | 1 + 11 files changed, 521 insertions(+), 409 deletions(-) diff --git a/src/compiler/cgir/lowering.h b/src/compiler/cgir/lowering.h index decd58b3f..a56a1a177 100644 --- a/src/compiler/cgir/lowering.h +++ b/src/compiler/cgir/lowering.h @@ -194,6 +194,9 @@ template class CgLowering { case MInstruction::ADC: ResultReg = SELF.lowerAdcExpr(llvm::cast(Inst)); break; + case MInstruction::SBB: + ResultReg = SELF.lowerSbbExpr(llvm::cast(Inst)); + break; case MInstruction::CMP: ResultReg = SELF.lowerCmpExpr(llvm::cast(Inst)); break; diff --git a/src/compiler/evm_frontend/evm_analyzer.h b/src/compiler/evm_frontend/evm_analyzer.h index c5af915c1..90cb72be5 100644 --- a/src/compiler/evm_frontend/evm_analyzer.h +++ b/src/compiler/evm_frontend/evm_analyzer.h @@ -30,11 +30,11 @@ namespace COMPILER { // clang-format off static constexpr uint32_t MIR_OPCODE_WEIGHT[256] = { // 0x00 STOP ADD MUL SUB DIV SDIV MOD SMOD - 5, 12, 80, 20, 5, 5, 5, 5, + 5, 12, 80, 12, 5, 5, 5, 5, // 0x08 ADDMOD MULMOD EXP SIGNEXT (0x0c-0x0f undefined) 5, 5, 5, 20, 2, 2, 2, 2, // 0x10 LT GT SLT SGT EQ ISZERO AND OR - 12, 12, 12, 12, 12, 8, 8, 8, + 8, 8, 12, 12, 12, 8, 8, 8, // 0x18 XOR NOT BYTE SHL SHR SAR CLZ (0x1f) 8, 8, 8, 15, 15, 15, 8, 2, // 0x20 KECCAK256 (0x21-0x2f undefined) @@ -79,9 +79,6 @@ inline bool isRAExpensiveOpcode(uint8_t Op) { switch (Op) { case 0x02: // MUL — ~50-60 MIR, heavy partial-product fan-out case 0x0b: // SIGNEXTEND — ~21 Selects, two dependency chain loops - case 0x1b: // SHL — ~92 Selects, nested J,K loops - case 0x1c: // SHR — ~96 Selects, nested J,K loops - case 0x1d: // SAR — ~52 Selects, sign-extended variant return true; default: return false; diff --git a/src/compiler/evm_frontend/evm_mir_compiler.cpp b/src/compiler/evm_frontend/evm_mir_compiler.cpp index 40f1e4ffb..71487fbec 100644 --- a/src/compiler/evm_frontend/evm_mir_compiler.cpp +++ b/src/compiler/evm_frontend/evm_mir_compiler.cpp @@ -9,6 +9,7 @@ #include "runtime/evm_instance.h" #include "utils/hash_utils.h" #include +#include #ifdef ZEN_ENABLE_EVM_GAS_REGISTER #include "compiler/llvm-prebuild/Target/X86/X86Subtarget.h" @@ -1300,7 +1301,79 @@ MInstruction *EVMMirBuilder::createEvmUmul128Hi(MInstruction *MulInst) { typename EVMMirBuilder::Operand EVMMirBuilder::handleMul(Operand MultiplicandOp, Operand MultiplierOp) { - // Optimized schoolbook multiplication for U256 (4x64-bit limbs) + // Full constant fold: both operands are compile-time constants. + if (MultiplicandOp.isConstant() && MultiplierOp.isConstant()) { + intx::uint256 L = u256FromLimbs(MultiplicandOp.getConstValue()); + intx::uint256 R = u256FromLimbs(MultiplierOp.getConstValue()); + return createU256ConstOperand(L * R); + } + + // Fast path: if one operand is a compile-time constant that fits in 1 limb, + // only 4 partial products are needed instead of the full 10. + auto tryOneLimbMul = [&](Operand &WideOp, const U256Value &ConstVal) + -> std::optional { + if (ConstVal[1] != 0 || ConstVal[2] != 0 || ConstVal[3] != 0) + return std::nullopt; + if (ConstVal[0] == 0) { + MType *I64Type = &Ctx.I64Type; + MInstruction *Zero = createIntConstInstruction(I64Type, 0); + return Operand(U256Inst{Zero, Zero, Zero, Zero}, EVMType::UINT256); + } + if (ConstVal[0] == 1) + return WideOp; + + U256Inst A = extractU256Operand(WideOp); + MType *I64Type = &Ctx.I64Type; + MInstruction *Zero = createIntConstInstruction(I64Type, 0); + MInstruction *BLimb = createIntConstInstruction(I64Type, ConstVal[0]); + + // R[k] = sum of (A[i] * BLimb) contributions where i <= k + MInstruction *PLo[EVM_ELEMENTS_COUNT] = {}; + MInstruction *PHi[EVM_ELEMENTS_COUNT] = {}; + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + PLo[I] = createEvmUmul128(A[I], BLimb); + if (I < EVM_ELEMENTS_COUNT - 1) + PHi[I] = createEvmUmul128Hi(PLo[I]); + } + MInstruction *R0 = PLo[0]; + MInstruction *R1 = + createInstruction(false, OP_add, I64Type, PHi[0], + PLo[1]); + MInstruction *C1 = + createInstruction(false, I64Type, Zero, Zero, Zero); + R1 = protectUnsafeValue(R1, I64Type); + C1 = protectUnsafeValue(C1, I64Type); + + MInstruction *R2 = + createInstruction(false, OP_add, I64Type, PHi[1], + PLo[2]); + MInstruction *C2a = + createInstruction(false, I64Type, Zero, Zero, Zero); + R2 = protectUnsafeValue(R2, I64Type); + C2a = protectUnsafeValue(C2a, I64Type); + R2 = createInstruction(false, OP_add, I64Type, R2, C1); + R2 = protectUnsafeValue(R2, I64Type); + + MInstruction *R3 = + createInstruction(false, OP_add, I64Type, PHi[2], + PLo[3]); + R3 = protectUnsafeValue( + createInstruction(false, OP_add, I64Type, R3, C2a), + I64Type); + + return Operand(U256Inst{R0, R1, R2, R3}, EVMType::UINT256); + }; + + if (MultiplicandOp.isConstant()) { + if (auto R = tryOneLimbMul(MultiplierOp, MultiplicandOp.getConstValue())) + return *R; + } + if (MultiplierOp.isConstant()) { + if (auto R = tryOneLimbMul(MultiplicandOp, MultiplierOp.getConstValue())) + return *R; + } + + // Full schoolbook multiplication for U256 (4x64-bit limbs) // U256 layout: [0]=lo64, [1]=mid-lo, [2]=mid-hi, [3]=hi64 // // For 256-bit truncated result, we need products where i+j < 4: @@ -1786,66 +1859,82 @@ EVMMirBuilder::handleCompareGT_LT(const U256Inst &LHS, const U256Inst &RHS, U256Inst Result = {}; MType *MirI64Type = EVMFrontendContext::getMIRTypeFromEVMType(EVMType::UINT64); - - // Compare from most significant to least significant component - // If components are equal, continue to next - MInstruction *FinalResult = nullptr; MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); - MInstruction *One = createIntConstInstruction(ResultType, 1); - - CmpInstruction::Predicate SignedPredicate; - CmpInstruction::Predicate UnsignedPredicate; - bool IsSigned = false; - if (Operator == CompareOperator::CO_LT) { - SignedPredicate = CmpInstruction::Predicate::ICMP_ULT; - UnsignedPredicate = CmpInstruction::Predicate::ICMP_ULT; - } else if (Operator == CompareOperator::CO_LT_S) { - SignedPredicate = CmpInstruction::Predicate::ICMP_SLT; - UnsignedPredicate = CmpInstruction::Predicate::ICMP_ULT; - IsSigned = true; - } else if (Operator == CompareOperator::CO_GT) { - SignedPredicate = CmpInstruction::Predicate::ICMP_UGT; - UnsignedPredicate = CmpInstruction::Predicate::ICMP_UGT; - } else if (Operator == CompareOperator::CO_GT_S) { - SignedPredicate = CmpInstruction::Predicate::ICMP_SGT; - UnsignedPredicate = CmpInstruction::Predicate::ICMP_UGT; - IsSigned = true; - } else { - ZEN_ASSERT_TODO(); - } - auto EQPredicate = CmpInstruction::Predicate::ICMP_EQ; - // Track if all higher components are equal - MInstruction *AllEqual = nullptr; + bool IsSigned = (Operator == CompareOperator::CO_LT_S || + Operator == CompareOperator::CO_GT_S); - for (int I = EVM_ELEMENTS_COUNT - 1; I >= 0; --I) { - ZEN_ASSERT(LHS[I] && RHS[I]); + if (!IsSigned) { + // Unsigned LT/GT: use SUB/SBB borrow chain + ADC to capture CF. + // A < B iff (A - B) produces a borrow; for GT swap operands. + const U256Inst &SubLHS = + (Operator == CompareOperator::CO_LT) ? LHS : RHS; + const U256Inst &SubRHS = + (Operator == CompareOperator::CO_LT) ? RHS : LHS; + + U256Inst MatLHS, MatRHS; + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + MatLHS[I] = protectUnsafeValue(SubLHS[I], MirI64Type); + MatRHS[I] = protectUnsafeValue(SubRHS[I], MirI64Type); + } - // For signed 256-bit comparison, only the most significant component - // carries the sign bit; lower components are magnitude-only and must - // use unsigned comparison. - auto Pred = (IsSigned && I == EVM_ELEMENTS_COUNT - 1) ? SignedPredicate - : UnsignedPredicate; - MInstruction *CompResult = createInstruction( - false, Pred, ResultType, LHS[I], RHS[I]); - MInstruction *EqResult = createInstruction( - false, EQPredicate, ResultType, LHS[I], RHS[I]); - - if (FinalResult == nullptr) { - FinalResult = CompResult; - AllEqual = EqResult; + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + MInstruction *Diff; + if (I == 0) { + Diff = createInstruction( + false, OP_sub, MirI64Type, MatLHS[I], MatRHS[I]); + } else { + Diff = createInstruction( + false, MirI64Type, MatLHS[I], MatRHS[I], Zero); + } + protectUnsafeValue(Diff, MirI64Type); + } + + // ADC(0, 0) captures the final borrow as 0 or 1. + MInstruction *BorrowResult = + createInstruction(false, MirI64Type, Zero, Zero, Zero); + Result[0] = protectUnsafeValue(BorrowResult, MirI64Type); + } else { + // Signed SLT/SGT: select-chain comparison (sign bit in highest limb). + CmpInstruction::Predicate SignedPredicate; + CmpInstruction::Predicate UnsignedPredicate; + if (Operator == CompareOperator::CO_LT_S) { + SignedPredicate = CmpInstruction::Predicate::ICMP_SLT; + UnsignedPredicate = CmpInstruction::Predicate::ICMP_ULT; } else { - // FinalResult = EqResult_prev ? CompResult : FinalResult - FinalResult = createInstruction( - false, ResultType, AllEqual, CompResult, FinalResult); - // Update AllEqual: AllEqual = AllEqual_prev && EqResult - AllEqual = createInstruction(false, OP_and, ResultType, - AllEqual, EqResult); + SignedPredicate = CmpInstruction::Predicate::ICMP_SGT; + UnsignedPredicate = CmpInstruction::Predicate::ICMP_UGT; } + auto EQPredicate = CmpInstruction::Predicate::ICMP_EQ; + + MInstruction *FinalResult = nullptr; + MInstruction *AllEqual = nullptr; + + for (int I = EVM_ELEMENTS_COUNT - 1; I >= 0; --I) { + ZEN_ASSERT(LHS[I] && RHS[I]); + + auto Pred = (I == EVM_ELEMENTS_COUNT - 1) ? SignedPredicate + : UnsignedPredicate; + MInstruction *CompResult = createInstruction( + false, Pred, ResultType, LHS[I], RHS[I]); + MInstruction *EqResult = createInstruction( + false, EQPredicate, ResultType, LHS[I], RHS[I]); + + if (FinalResult == nullptr) { + FinalResult = CompResult; + AllEqual = EqResult; + } else { + FinalResult = createInstruction( + false, ResultType, AllEqual, CompResult, FinalResult); + AllEqual = createInstruction( + false, OP_and, ResultType, AllEqual, EqResult); + } + } + + ZEN_ASSERT(FinalResult); + Result[0] = protectUnsafeValue(FinalResult, MirI64Type); } - ZEN_ASSERT(FinalResult); - Result[0] = protectUnsafeValue(FinalResult, MirI64Type); for (size_t I = 1; I < EVM_ELEMENTS_COUNT; ++I) { Result[I] = Zero; } @@ -1884,14 +1973,8 @@ EVMMirBuilder::handleLeftShift(const U256Inst &Value, MInstruction *ShiftAmount, U256Inst Result = {}; MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); - MInstruction *One = createIntConstInstruction(MirI64Type, 1); MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64); - // EVM SHL operation: result = value << shift - // DMIR implementation maps 256-bit shift to 4x64-bit components - // shift_mod = shift % 64 (shift amount within 64-bit range) - // shift_comp = shift / 64 (which component index shift from) - // remaining_bits = 64 - shift_mod (remaining bits for carry calculation) MInstruction *ShiftMod64 = createInstruction( false, OP_urem, MirI64Type, ShiftAmount, Const64); MInstruction *ComponentShift = createInstruction( @@ -1899,111 +1982,62 @@ EVMMirBuilder::handleLeftShift(const U256Inst &Value, MInstruction *ShiftAmount, MInstruction *RemainingBits = createInstruction( false, OP_sub, MirI64Type, Const64, ShiftMod64); - MInstruction *MaxIndex = - createIntConstInstruction(MirI64Type, EVM_ELEMENTS_COUNT); + // Store padded array to scratch: [0, 0, 0, 0, Value[0..3]] + // This allows indexed load without bounds-checking select chains. + // SrcIdx = I - ComponentShift + 4 maps to [1..7], all valid. + // PrevIdx = I - ComponentShift + 3 maps to [0..6], all valid. + const int32_t ScratchBase = + zen::runtime::EVMInstance::getHostArgScratchOffset(); + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + setInstanceElement(MirI64Type, Zero, + ScratchBase + static_cast(I * sizeof(uint64_t))); + } + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + setInstanceElement( + MirI64Type, Value[I], + ScratchBase + + static_cast((EVM_ELEMENTS_COUNT + I) * sizeof(uint64_t))); + } + + MInstruction *HasBitShift = createInstruction( + false, CmpInstruction::Predicate::ICMP_NE, &Ctx.I64Type, ShiftMod64, + Zero); + MInstruction *IsValidShift = createInstruction( + false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, RemainingBits, + Const64); + MInstruction *UseCarry = createInstruction( + false, OP_and, MirI64Type, IsValidShift, HasBitShift); - // Process each 64-bit component from low to high - // Example: For shift=72 (1*64 + 8), component_shift=1, shift_mod=8 - // Component 0 gets bits from component -1 (invalid, use 0) - // Component 1 gets bits from component 0 shifted left by 8 - // Component 2 gets bits from component 1 shifted left by 8 - // Component 3 gets bits from component 2 shifted left by 8 for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { - MInstruction *CurrentIdx = createIntConstInstruction(MirI64Type, I); - - // Calculate source component index: current index - component shift - MInstruction *SrcIdx = createInstruction( - false, OP_sub, MirI64Type, CurrentIdx, ComponentShift); - - // Validate source index bounds - // if (0 <= src_idx < EVM_ELEMENTS_COUNT) use Value[src_idx] else 0 - MInstruction *IsValidLow = createInstruction( - false, CmpInstruction::Predicate::ICMP_UGE, &Ctx.I64Type, SrcIdx, Zero); - MInstruction *IsValidHigh = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, SrcIdx, - MaxIndex); - MInstruction *IsInBounds = createInstruction( - false, OP_and, MirI64Type, IsValidLow, IsValidHigh); - - // Select source value from the appropriate component - // src_value = (src_idx == J) ? Value[J] : 0 for all J - MInstruction *SrcValue = Zero; - for (size_t J = 0; J < EVM_ELEMENTS_COUNT; ++J) { - MInstruction *TargetIdx = createIntConstInstruction(MirI64Type, J); - MInstruction *IsMatch = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, SrcIdx, - TargetIdx); - SrcValue = createInstruction( - false, MirI64Type, IsMatch, Value[J], SrcValue); - } - SrcValue = createInstruction(false, MirI64Type, - IsInBounds, SrcValue, Zero); - - // Calculate previous component index for carry bits - // prev_idx = src_idx - 1 - MInstruction *PrevIdx = createInstruction( - false, OP_sub, MirI64Type, SrcIdx, One); - - // Validate previous component bounds - // if (0 <= prev_idx < EVM_ELEMENTS_COUNT) use Value[prev_idx] else 0 - MInstruction *IsValidPrevLow = createInstruction( - false, CmpInstruction::Predicate::ICMP_UGE, &Ctx.I64Type, PrevIdx, - Zero); - MInstruction *IsValidPrevHigh = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, PrevIdx, - MaxIndex); - MInstruction *IsPrevValid = createInstruction( - false, OP_and, MirI64Type, IsValidPrevLow, IsValidPrevHigh); - - // Only calculate carry when there is actual bit-level shifting (ShiftMod64 - // > 0) - // carry_bits = (prev_idx == K) ? (Value[K] >> remaining_bits) : 0 - MInstruction *HasBitShift = createInstruction( - false, CmpInstruction::Predicate::ICMP_NE, &Ctx.I64Type, ShiftMod64, - Zero); - MInstruction *CarryValue = Zero; - for (size_t K = 0; K < EVM_ELEMENTS_COUNT; ++K) { - MInstruction *TargetIdx = createIntConstInstruction(MirI64Type, K); - MInstruction *IsMatch = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, PrevIdx, - TargetIdx); - MInstruction *PrevValue = createInstruction( - false, MirI64Type, IsMatch, Value[K], Zero); - PrevValue = createInstruction( - false, MirI64Type, IsPrevValid, PrevValue, Zero); - - // Extract carry bits by shifting right the remaining bits - // Avoid undefined behavior when RemainingBits >= 64 - MInstruction *IsValidShift = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, - RemainingBits, Const64); - MInstruction *CarryBits = createInstruction( - false, OP_ushr, MirI64Type, PrevValue, RemainingBits); - // Use carry bits only if shift amount is valid (< 64) AND there is - // bit-level shifting - MInstruction *UseCarry = createInstruction( - false, OP_and, MirI64Type, IsValidShift, HasBitShift); - CarryBits = createInstruction( - false, MirI64Type, UseCarry, CarryBits, Zero); - CarryValue = createInstruction( - false, MirI64Type, IsMatch, CarryBits, CarryValue); - } + // SrcIdx_mapped = I + 4 - ComponentShift (always in [1..7]) + MInstruction *SrcMapped = createInstruction( + false, OP_sub, MirI64Type, + createIntConstInstruction(MirI64Type, I + EVM_ELEMENTS_COUNT), + ComponentShift); + MInstruction *SrcValue = + getInstanceElement(MirI64Type, sizeof(uint64_t), SrcMapped, ScratchBase); + + // PrevIdx_mapped = I + 3 - ComponentShift (always in [0..6]) + MInstruction *PrevMapped = createInstruction( + false, OP_sub, MirI64Type, + createIntConstInstruction(MirI64Type, I + EVM_ELEMENTS_COUNT - 1), + ComponentShift); + MInstruction *PrevValue = getInstanceElement(MirI64Type, sizeof(uint64_t), + PrevMapped, ScratchBase); - // Shift the source value left by the modulo amount - // shifted_value = src_value << shift_mod MInstruction *ShiftedValue = createInstruction( false, OP_shl, MirI64Type, SrcValue, ShiftMod64); - // combined_value = shifted_value | carry_bits + MInstruction *CarryBits = createInstruction( + false, OP_ushr, MirI64Type, PrevValue, RemainingBits); + CarryBits = createInstruction(false, MirI64Type, + UseCarry, CarryBits, Zero); + MInstruction *CombinedValue = createInstruction( - false, OP_or, MirI64Type, ShiftedValue, CarryValue); + false, OP_or, MirI64Type, ShiftedValue, CarryBits); - // Final result selection based on bounds checking and large shift flag - // result[I] = IsLargeShift ? 0 : (IsInBounds ? CombinedValue : 0) MInstruction *FinalValue = createInstruction( - false, MirI64Type, IsLargeShift, Zero, - createInstruction(false, MirI64Type, IsInBounds, - CombinedValue, Zero)); + false, MirI64Type, IsLargeShift, Zero, CombinedValue); Result[I] = protectUnsafeValue(FinalValue, MirI64Type); } @@ -2019,119 +2053,69 @@ EVMMirBuilder::handleLogicalRightShift(const U256Inst &Value, U256Inst Result = {}; MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); - MInstruction *One = createIntConstInstruction(MirI64Type, 1); MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64); - // EVM SHR operation: result = value >> shift (logical right shift) - // DMIR implementation maps 256-bit shift to 4x64-bit components - // shift_mod = shift % 64 (shift amount within 64-bit range) - // shift_comp = shift / 64 (which component index shift from) MInstruction *ShiftMod64 = createInstruction( false, OP_urem, MirI64Type, ShiftAmount, Const64); MInstruction *ComponentShift = createInstruction( false, OP_udiv, MirI64Type, ShiftAmount, Const64); + MInstruction *CarryShiftAmt = createInstruction( + false, OP_sub, MirI64Type, Const64, ShiftMod64); + + // Store padded array to scratch: [Value[0..3], 0, 0, 0, 0] + // SrcIdx = I + ComponentShift maps to [0..6], all valid. + // NextIdx = I + ComponentShift + 1 maps to [1..7], all valid. + const int32_t ScratchBase = + zen::runtime::EVMInstance::getHostArgScratchOffset(); + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + setInstanceElement( + MirI64Type, Value[I], + ScratchBase + static_cast(I * sizeof(uint64_t))); + } + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + setInstanceElement( + MirI64Type, Zero, + ScratchBase + + static_cast((EVM_ELEMENTS_COUNT + I) * sizeof(uint64_t))); + } - MInstruction *MaxIndex = - createIntConstInstruction(MirI64Type, EVM_ELEMENTS_COUNT); + MInstruction *HasBitShift = createInstruction( + false, CmpInstruction::Predicate::ICMP_NE, &Ctx.I64Type, ShiftMod64, + Zero); + MInstruction *IsValidCarryShift = createInstruction( + false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, CarryShiftAmt, + Const64); + MInstruction *UseCarry = createInstruction( + false, OP_and, MirI64Type, IsValidCarryShift, HasBitShift); - // Process each 64-bit component from low to high - // Example: For shift=72 (1*64 + 8), component_shift=1, shift_mod=8 - // Component 0 gets bits from component 1 shifted right by 8 - // Component 1 gets bits from component 2 shifted right by 8 - // Component 2 gets bits from component 3 shifted right by 8 - // Component 3 gets bits from component 4 (invalid, use 0) for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { - MInstruction *CurrentIdx = createIntConstInstruction(MirI64Type, I); - - // Calculate source component index: current index + component shift - MInstruction *SrcIdx = createInstruction( - false, OP_add, MirI64Type, CurrentIdx, ComponentShift); - - // Validate source index bounds - // if (0 <= src_idx < EVM_ELEMENTS_COUNT) use Value[src_idx] else 0 - MInstruction *IsValidLow = createInstruction( - false, CmpInstruction::Predicate::ICMP_UGE, &Ctx.I64Type, SrcIdx, Zero); - MInstruction *IsValidHigh = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, SrcIdx, - MaxIndex); - MInstruction *IsInBounds = createInstruction( - false, OP_and, MirI64Type, IsValidLow, IsValidHigh); - - // Select source value from the appropriate component - // src_value = (src_idx == J) ? Value[J] : 0 for all J - MInstruction *SrcValue = Zero; - for (size_t J = 0; J < EVM_ELEMENTS_COUNT; ++J) { - MInstruction *TargetIdx = createIntConstInstruction(MirI64Type, J); - MInstruction *IsMatch = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, SrcIdx, - TargetIdx); - SrcValue = createInstruction( - false, MirI64Type, IsMatch, Value[J], SrcValue); - } - SrcValue = createInstruction(false, MirI64Type, - IsInBounds, SrcValue, Zero); - - // Calculate next component index for carry bits - // next_idx = src_idx + 1 - MInstruction *NextIdx = createInstruction( - false, OP_add, MirI64Type, SrcIdx, One); - - // Validate next component bounds - // if (0 <= next_idx < EVM_ELEMENTS_COUNT) use Value[next_idx] else 0 - MInstruction *IsValidNextLow = createInstruction( - false, CmpInstruction::Predicate::ICMP_UGE, &Ctx.I64Type, NextIdx, - Zero); - MInstruction *IsValidNextHigh = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, NextIdx, - MaxIndex); - MInstruction *IsNextValid = createInstruction( - false, OP_and, MirI64Type, IsValidNextLow, IsValidNextHigh); - - // Calculate carry bits from the next component - // carry_bits = (next_idx == K) ? (Value[K] << (64 - shift_mod)) : 0 - MInstruction *HasBitShift = createInstruction( - false, CmpInstruction::Predicate::ICMP_NE, &Ctx.I64Type, ShiftMod64, - Zero); - MInstruction *CarryShift = createInstruction( - false, MirI64Type, HasBitShift, - createInstruction(false, OP_sub, MirI64Type, Const64, - ShiftMod64), - Zero); - MInstruction *CarryValue = Zero; - for (size_t K = 0; K < EVM_ELEMENTS_COUNT; ++K) { - MInstruction *TargetIdx = createIntConstInstruction(MirI64Type, K); - MInstruction *IsMatch = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, NextIdx, - TargetIdx); - MInstruction *NextValue = createInstruction( - false, MirI64Type, IsMatch, Value[K], Zero); - NextValue = createInstruction( - false, MirI64Type, IsNextValid, NextValue, Zero); - - // Extract carry bits by shifting left the remaining bits - MInstruction *CarryBits = createInstruction( - false, OP_shl, MirI64Type, NextValue, CarryShift); - CarryBits = createInstruction( - false, MirI64Type, HasBitShift, CarryBits, Zero); - CarryValue = createInstruction( - false, MirI64Type, IsMatch, CarryBits, CarryValue); - } + // SrcIdx_mapped = I + ComponentShift (always in [0..6]) + MInstruction *SrcMapped = createInstruction( + false, OP_add, MirI64Type, + createIntConstInstruction(MirI64Type, I), ComponentShift); + MInstruction *SrcValue = + getInstanceElement(MirI64Type, sizeof(uint64_t), SrcMapped, ScratchBase); + + // NextIdx_mapped = I + ComponentShift + 1 (always in [1..7]) + MInstruction *NextMapped = createInstruction( + false, OP_add, MirI64Type, + createIntConstInstruction(MirI64Type, I + 1), ComponentShift); + MInstruction *NextValue = getInstanceElement(MirI64Type, sizeof(uint64_t), + NextMapped, ScratchBase); - // Shift the source value right by the modulo amount - // shifted_value = src_value >> shift_mod MInstruction *ShiftedValue = createInstruction( false, OP_ushr, MirI64Type, SrcValue, ShiftMod64); - // combined_value = shifted_value | carry_bits + MInstruction *CarryBits = createInstruction( + false, OP_shl, MirI64Type, NextValue, CarryShiftAmt); + CarryBits = createInstruction(false, MirI64Type, + UseCarry, CarryBits, Zero); + MInstruction *CombinedValue = createInstruction( - false, OP_or, MirI64Type, ShiftedValue, CarryValue); + false, OP_or, MirI64Type, ShiftedValue, CarryBits); - // Final result selection based on bounds checking and large shift flag - // result[I] = IsLargeShift ? 0 : (IsInBounds ? CombinedValue : 0) MInstruction *FinalValue = createInstruction( - false, MirI64Type, IsLargeShift, Zero, - createInstruction(false, MirI64Type, IsInBounds, - CombinedValue, Zero)); + false, MirI64Type, IsLargeShift, Zero, CombinedValue); Result[I] = protectUnsafeValue(FinalValue, MirI64Type); } @@ -2146,7 +2130,6 @@ EVMMirBuilder::handleArithmeticRightShift(const U256Inst &Value, EVMFrontendContext::getMIRTypeFromEVMType(EVMType::UINT64); U256Inst Result = {}; - // Arithmetic right shift: sign-extend when shift >= 256 MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); MInstruction *AllOnes = createIntConstInstruction(MirI64Type, ~0ULL); @@ -2155,112 +2138,71 @@ EVMMirBuilder::handleArithmeticRightShift(const U256Inst &Value, MInstruction *Const63 = createIntConstInstruction(MirI64Type, 63); MInstruction *SignBit = createInstruction( false, OP_ushr, MirI64Type, HighComponent, Const63); - - // Sign bit is 1 if negative MInstruction *One = createIntConstInstruction(MirI64Type, 1); MInstruction *IsNegative = createInstruction( false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, SignBit, One); - - // Large shift result: all 1s if negative, all 0s if positive - MInstruction *LargeShiftResult = createInstruction( + MInstruction *SignFill = createInstruction( false, MirI64Type, IsNegative, AllOnes, Zero); - // intra-component shifts = shift % 64 - // shift_comp = shift / 64 (which component index shift from) MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64); MInstruction *ShiftMod64 = createInstruction( false, OP_urem, MirI64Type, ShiftAmount, Const64); MInstruction *ComponentShift = createInstruction( false, OP_udiv, MirI64Type, ShiftAmount, Const64); + MInstruction *CarryShiftAmt = createInstruction( + false, OP_sub, MirI64Type, Const64, ShiftMod64); - MInstruction *MaxIndex = - createIntConstInstruction(MirI64Type, EVM_ELEMENTS_COUNT); - - // Process each component from low to high + // Store padded array: [Value[0..3], SignFill, SignFill, SignFill, SignFill] + // Out-of-bounds indexed loads will read sign-extended fill values. + const int32_t ScratchBase = + zen::runtime::EVMInstance::getHostArgScratchOffset(); for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { - MInstruction *CurrentIdx = createIntConstInstruction(MirI64Type, I); - - MInstruction *SrcIdx = createInstruction( - false, OP_add, MirI64Type, CurrentIdx, ComponentShift); - - // Validate source index bounds - // if (0 <= src_idx < EVM_ELEMENTS_COUNT) use Value[src_idx] else 0 - MInstruction *IsValidLow = createInstruction( - false, CmpInstruction::Predicate::ICMP_UGE, &Ctx.I64Type, SrcIdx, Zero); - MInstruction *IsValidHigh = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, SrcIdx, - MaxIndex); - MInstruction *IsInBounds = createInstruction( - false, OP_and, MirI64Type, IsValidLow, IsValidHigh); - - // Select source value from the component at SrcIdx index - MInstruction *SrcValue = LargeShiftResult; - for (size_t J = 0; J < EVM_ELEMENTS_COUNT; ++J) { - MInstruction *TargetIdx = createIntConstInstruction(MirI64Type, J); - MInstruction *IsMatch = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, SrcIdx, - TargetIdx); - SrcValue = createInstruction( - false, MirI64Type, IsMatch, Value[J], SrcValue); - } - SrcValue = createInstruction( - false, MirI64Type, IsInBounds, SrcValue, LargeShiftResult); - - // Calculate next component index for carry bits - // next_idx = src_idx + 1 - MInstruction *NextIdx = createInstruction( - false, OP_add, MirI64Type, SrcIdx, One); - - // Validate next component bounds - // if (0 <= next_idx < EVM_ELEMENTS_COUNT) use Value[next_idx] else - // sign_extend - MInstruction *IsValidNextLow = createInstruction( - false, CmpInstruction::Predicate::ICMP_UGE, &Ctx.I64Type, NextIdx, - Zero); - MInstruction *IsValidNextHigh = createInstruction( - false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, NextIdx, - MaxIndex); - MInstruction *IsNextValid = createInstruction( - false, OP_and, MirI64Type, IsValidNextLow, IsValidNextHigh); - - // Calculate carry bits from the next component (higher index). - MInstruction *HasShift = createInstruction( - false, CmpInstruction::Predicate::ICMP_NE, &Ctx.I64Type, ShiftMod64, - Zero); - MInstruction *CarryShift = createInstruction( - false, MirI64Type, HasShift, - createInstruction(false, OP_sub, MirI64Type, Const64, - ShiftMod64), - Zero); - MInstruction *NextValue = LargeShiftResult; - for (size_t K = 0; K < EVM_ELEMENTS_COUNT; ++K) { - MInstruction *TargetIdx = createIntConstInstruction(MirI64Type, K); - MInstruction *IsMatch = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, NextIdx, - TargetIdx); - NextValue = createInstruction( - false, MirI64Type, IsMatch, Value[K], NextValue); - } - NextValue = createInstruction( - false, MirI64Type, IsNextValid, NextValue, LargeShiftResult); + setInstanceElement( + MirI64Type, Value[I], + ScratchBase + static_cast(I * sizeof(uint64_t))); + } + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + setInstanceElement( + MirI64Type, SignFill, + ScratchBase + + static_cast((EVM_ELEMENTS_COUNT + I) * sizeof(uint64_t))); + } - // Extract low bits from next component as carry. When next_idx is out of - // bounds, use sign-extension bits from LargeShiftResult. - MInstruction *CarryBits = createInstruction( - false, OP_shl, MirI64Type, NextValue, CarryShift); - MInstruction *CarryValue = createInstruction( - false, MirI64Type, HasShift, CarryBits, Zero); + MInstruction *HasBitShift = createInstruction( + false, CmpInstruction::Predicate::ICMP_NE, &Ctx.I64Type, ShiftMod64, + Zero); + MInstruction *IsValidCarryShift = createInstruction( + false, CmpInstruction::Predicate::ICMP_ULT, &Ctx.I64Type, CarryShiftAmt, + Const64); + MInstruction *UseCarry = createInstruction( + false, OP_and, MirI64Type, IsValidCarryShift, HasBitShift); + + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + MInstruction *SrcMapped = createInstruction( + false, OP_add, MirI64Type, + createIntConstInstruction(MirI64Type, I), ComponentShift); + MInstruction *SrcValue = + getInstanceElement(MirI64Type, sizeof(uint64_t), SrcMapped, ScratchBase); + + MInstruction *NextMapped = createInstruction( + false, OP_add, MirI64Type, + createIntConstInstruction(MirI64Type, I + 1), ComponentShift); + MInstruction *NextValue = getInstanceElement(MirI64Type, sizeof(uint64_t), + NextMapped, ScratchBase); - // Use logical right shift; sign extension is handled via LargeShiftResult. MInstruction *ShiftedValue = createInstruction( false, OP_ushr, MirI64Type, SrcValue, ShiftMod64); + + MInstruction *CarryBits = createInstruction( + false, OP_shl, MirI64Type, NextValue, CarryShiftAmt); + CarryBits = createInstruction(false, MirI64Type, + UseCarry, CarryBits, Zero); + MInstruction *CombinedValue = createInstruction( - false, OP_or, MirI64Type, ShiftedValue, CarryValue); + false, OP_or, MirI64Type, ShiftedValue, CarryBits); MInstruction *FinalValue = createInstruction( - false, MirI64Type, IsLargeShift, LargeShiftResult, - createInstruction(false, MirI64Type, IsInBounds, - CombinedValue, LargeShiftResult)); + false, MirI64Type, IsLargeShift, SignFill, CombinedValue); Result[I] = protectUnsafeValue(FinalValue, MirI64Type); } @@ -2346,39 +2288,84 @@ typename EVMMirBuilder::Operand EVMMirBuilder::handleByte(Operand IndexOp, // SIGNEXTEND(31, 0x1234) = 0x1234 (no extension when index >= 31) typename EVMMirBuilder::Operand EVMMirBuilder::handleSignextend(Operand IndexOp, Operand ValueOp) { - U256Inst IndexComponents = extractU256Operand(IndexOp); U256Inst ValueComponents = extractU256Operand(ValueOp); - - // Check if index >= 31 (no sign extension needed) - MInstruction *NoExtension = isU256GreaterOrEqual(IndexComponents, 31); - MType *MirI64Type = EVMFrontendContext::getMIRTypeFromEVMType(EVMType::UINT64); - // Calculate sign bit position: index * 8 + 7 + // Fast path: compile-time constant index (covers the common PUSH+SIGNEXTEND + // pattern emitted by Solidity). No selects required. + if (IndexOp.isConstant()) { + const auto &IdxVal = IndexOp.getConstValue(); + bool NoExtension = + (IdxVal[1] | IdxVal[2] | IdxVal[3]) != 0 || IdxVal[0] >= 31; + if (NoExtension) + return Operand(ValueComponents, EVMType::UINT256); + + uint64_t ByteIdx = IdxVal[0]; + uint64_t SignBitPos = ByteIdx * 8 + 7; + size_t CompIdx = static_cast(SignBitPos / 64); + uint64_t BitOff = SignBitPos % 64; + uint64_t FullMask = + (BitOff == 63) ? ~0ULL : ((1ULL << (BitOff + 1)) - 1); + uint64_t InvMask = ~FullMask; + + MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); + MInstruction *AllOnes = createIntConstInstruction(MirI64Type, ~0ULL); + MInstruction *One = createIntConstInstruction(MirI64Type, 1); + MInstruction *BitOffInst = createIntConstInstruction(MirI64Type, BitOff); + MInstruction *FullMaskInst = + createIntConstInstruction(MirI64Type, FullMask); + MInstruction *InvMaskInst = + createIntConstInstruction(MirI64Type, InvMask); + + MInstruction *SignBit = createInstruction( + false, OP_ushr, MirI64Type, ValueComponents[CompIdx], BitOffInst); + SignBit = createInstruction(false, OP_and, MirI64Type, + SignBit, One); + MInstruction *HighValue = createInstruction( + false, MirI64Type, SignBit, AllOnes, Zero); + + U256Inst Result = {}; + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + if (I < CompIdx) { + Result[I] = ValueComponents[I]; + } else if (I == CompIdx) { + MInstruction *Kept = createInstruction( + false, OP_and, MirI64Type, ValueComponents[I], FullMaskInst); + MInstruction *Ext = createInstruction( + false, OP_and, MirI64Type, InvMaskInst, HighValue); + Result[I] = protectUnsafeValue( + createInstruction(false, OP_or, MirI64Type, Kept, + Ext), + MirI64Type); + } else { + Result[I] = HighValue; + } + } + return Operand(Result, EVMType::UINT256); + } + + // Runtime path: use scratch-based indexed load for sign bit extraction. + U256Inst IndexComponents = extractU256Operand(IndexOp); + MInstruction *NoExtension = isU256GreaterOrEqual(IndexComponents, 31); + + MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); + MInstruction *One = createIntConstInstruction(MirI64Type, 1); + MInstruction *AllOnes = createIntConstInstruction(MirI64Type, ~0ULL); MInstruction *Const8 = createIntConstInstruction(MirI64Type, 8); + MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64); + MInstruction *ByteBitPos = createInstruction( false, OP_mul, MirI64Type, IndexComponents[0], Const8); MInstruction *Const7 = createIntConstInstruction(MirI64Type, 7); MInstruction *SignBitPos = createInstruction( false, OP_add, MirI64Type, ByteBitPos, Const7); - // ComponentIndex = (index * 8 + 7) / 64 - MInstruction *Const64 = createIntConstInstruction(MirI64Type, 64); MInstruction *ComponentIndex = createInstruction( false, OP_udiv, MirI64Type, SignBitPos, Const64); - // BitOffset = (index * 8 + 7) % 64 MInstruction *BitOffset = createInstruction( false, OP_urem, MirI64Type, SignBitPos, Const64); - // Calculate sign extension mask - // FullMask = (1 << (BitOffset + 1)) - 1 - // InvMask = ~FullMask = FullMask ^ AllOnes - // Note: When BitOffset == 63, MaskBits == 64, and (1 << 64) causes undefined - // behavior on x86-64 (SHL masks shift amount to 6 bits, so 1 << 64 becomes - // 1 << 0 = 1). We need to handle this case specially. - MInstruction *One = createIntConstInstruction(MirI64Type, 1); - MInstruction *AllOnes = createIntConstInstruction(MirI64Type, ~0ULL); MInstruction *MaskBits = createInstruction( false, OP_add, MirI64Type, BitOffset, One); MInstruction *Is64 = createInstruction( @@ -2388,31 +2375,29 @@ EVMMirBuilder::handleSignextend(Operand IndexOp, Operand ValueOp) { false, OP_shl, MirI64Type, One, MaskBits); MInstruction *FullMaskNormal = createInstruction( false, OP_sub, MirI64Type, Mask, One); - // When MaskBits == 64, FullMask should be AllOnes (0xFFFFFFFFFFFFFFFF) MInstruction *FullMask = createInstruction( false, MirI64Type, Is64, AllOnes, FullMaskNormal); MInstruction *InvMask = createInstruction( false, OP_xor, MirI64Type, FullMask, AllOnes); - // Extract sign bit - MInstruction *Zero = createIntConstInstruction(MirI64Type, 0); - MInstruction *SignBit = Zero; - for (int I = 0; I < 4; I++) { - MInstruction *IsComp = createInstruction( - false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, ComponentIndex, - createIntConstInstruction(MirI64Type, I)); - // Shifted = ValueComponents[I] >> BitOffset - MInstruction *Shifted = createInstruction( - false, OP_ushr, MirI64Type, ValueComponents[I], BitOffset); - // Bit = Shifted & 1 - MInstruction *Bit = createInstruction( - false, OP_and, MirI64Type, Shifted, One); - // SignBit = IsComp ? Bit : SignBit - SignBit = createInstruction(false, MirI64Type, IsComp, - Bit, SignBit); + // Extract sign bit via scratch-based indexed load instead of select chain. + const int32_t ScratchBase = + zen::runtime::EVMInstance::getHostArgScratchOffset(); + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + setInstanceElement( + MirI64Type, ValueComponents[I], + ScratchBase + static_cast(I * sizeof(uint64_t))); } + MInstruction *SignComp = + getInstanceElement(MirI64Type, sizeof(uint64_t), ComponentIndex, + ScratchBase); + MInstruction *Shifted = createInstruction( + false, OP_ushr, MirI64Type, SignComp, BitOffset); + MInstruction *SignBit = createInstruction( + false, OP_and, MirI64Type, Shifted, One); + MInstruction *HighValue = createInstruction( + false, MirI64Type, SignBit, AllOnes, Zero); - // Create sign extension for each component U256Inst ResultComponents = {}; for (int I = 0; I < 4; I++) { MInstruction *CompIdx = createIntConstInstruction(MirI64Type, I); @@ -2423,11 +2408,6 @@ EVMMirBuilder::handleSignextend(Operand IndexOp, Operand ValueOp) { false, CmpInstruction::Predicate::ICMP_EQ, &Ctx.I64Type, CompIdx, ComponentIndex); - // For components above sign bit: all 1s if negative, all 0s if positive - MInstruction *HighValue = createInstruction( - false, MirI64Type, SignBit, AllOnes, Zero); - - // For sign component: apply mask and sign extension MInstruction *SignCompValue = createInstruction( false, OP_and, MirI64Type, ValueComponents[I], FullMask); MInstruction *SignExtBits = createInstruction( @@ -2435,13 +2415,11 @@ EVMMirBuilder::handleSignextend(Operand IndexOp, Operand ValueOp) { MInstruction *ExtendedSignComp = createInstruction( false, OP_or, MirI64Type, SignCompValue, SignExtBits); - // Select appropriate value based on position relative to sign bit MInstruction *ComponentResult = createInstruction( false, MirI64Type, IsAbove, HighValue, createInstruction( false, MirI64Type, IsEqual, ExtendedSignComp, ValueComponents[I])); - // If index >= 31, use original value; otherwise use sign-extended value ResultComponents[I] = protectUnsafeValue(createInstruction( false, MirI64Type, NoExtension, diff --git a/src/compiler/evm_frontend/evm_mir_compiler.h b/src/compiler/evm_frontend/evm_mir_compiler.h index 32687f575..9bf95c178 100644 --- a/src/compiler/evm_frontend/evm_mir_compiler.h +++ b/src/compiler/evm_frontend/evm_mir_compiler.h @@ -224,6 +224,33 @@ class EVMMirBuilder final { template Operand handleBinaryArithmetic(const Operand &LHSOp, const Operand &RHSOp) { + // Constant folding: if both operands are compile-time constants, compute + // the result directly and return a constant Operand. + if (LHSOp.isConstant() && RHSOp.isConstant()) { + intx::uint256 L = u256FromLimbs(LHSOp.getConstValue()); + intx::uint256 R = u256FromLimbs(RHSOp.getConstValue()); + intx::uint256 Res; + if constexpr (Operator == BinaryOperator::BO_ADD) + Res = L + R; + else if constexpr (Operator == BinaryOperator::BO_SUB) + Res = L - R; + else + ZEN_ASSERT_TODO(); + return createU256ConstOperand(Res); + } + + // Identity / annihilation shortcuts for single constant operand. + if constexpr (Operator == BinaryOperator::BO_ADD) { + if (LHSOp.isConstant() && isU256Zero(LHSOp.getConstValue())) + return RHSOp; + if (RHSOp.isConstant() && isU256Zero(RHSOp.getConstValue())) + return LHSOp; + } + if constexpr (Operator == BinaryOperator::BO_SUB) { + if (RHSOp.isConstant() && isU256Zero(RHSOp.getConstValue())) + return LHSOp; + } + U256Inst Result = {}; U256Inst LHS = extractU256Operand(LHSOp); U256Inst RHS = extractU256Operand(RHSOp); @@ -266,31 +293,29 @@ class EVMMirBuilder final { } } } else if constexpr (Operator == BinaryOperator::BO_SUB) { + // Borrow placeholder (not consumed by x86 lowering; the hardware CF + // from the preceding SUB/SBB is used directly, mirroring the ADC + // carry chain pattern). MInstruction *Borrow = createIntConstInstruction(MirI64Type, 0); + // Pre-materialize all operand components into variables before the + // SUB/SBB borrow chain. This prevents lazy expression lowering from + // emitting flag-clobbering instructions between the SUB and SBB + // instructions that form the borrow chain. for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { - // Sub: LHS[I] - RHS[I] - Borrow - MInstruction *Diff1 = createInstruction( - false, OP_sub, MirI64Type, LHS[I], RHS[I]); - MInstruction *Diff2 = createInstruction( - false, OP_sub, MirI64Type, Diff1, Borrow); - - Result[I] = protectUnsafeValue(Diff2, MirI64Type); - - // (LHS[I] < RHS[I]) || (Diff1 < Borrow) - if (I < EVM_ELEMENTS_COUNT - 1) { - auto LTPredicate = CmpInstruction::Predicate::ICMP_ULT; - MInstruction *Borrow1 = createInstruction( - false, LTPredicate, &Ctx.I64Type, LHS[I], RHS[I]); - MInstruction *Borrow2 = createInstruction( - false, LTPredicate, &Ctx.I64Type, Diff1, Borrow); - // NOLINTBEGIN(readability-identifier-naming) - MInstruction *Borrow1_64 = zeroExtendToI64(Borrow1); - MInstruction *Borrow2_64 = zeroExtendToI64(Borrow2); - // NOLINTEND(readability-identifier-naming) - - Borrow = createInstruction( - false, OP_or, MirI64Type, Borrow1_64, Borrow2_64); + LHS[I] = protectUnsafeValue(LHS[I], MirI64Type); + RHS[I] = protectUnsafeValue(RHS[I], MirI64Type); + } + + for (size_t I = 0; I < EVM_ELEMENTS_COUNT; ++I) { + if (I == 0) { + MInstruction *LocalResult = createInstruction( + false, OP_sub, MirI64Type, LHS[I], RHS[I]); + Result[I] = protectUnsafeValue(LocalResult, MirI64Type); + } else { + MInstruction *LocalResult = createInstruction( + false, MirI64Type, LHS[I], RHS[I], Borrow); + Result[I] = protectUnsafeValue(LocalResult, MirI64Type); } } } else { @@ -317,6 +342,34 @@ class EVMMirBuilder final { // EVM bitwise opcode: and, or, xor template Operand handleBitwiseOp(const Operand &LHSOp, const Operand &RHSOp) { + if (LHSOp.isConstant() && RHSOp.isConstant()) { + intx::uint256 L = u256FromLimbs(LHSOp.getConstValue()); + intx::uint256 R = u256FromLimbs(RHSOp.getConstValue()); + intx::uint256 Res; + if constexpr (Operator == BinaryOperator::BO_AND) + Res = L & R; + else if constexpr (Operator == BinaryOperator::BO_OR) + Res = L | R; + else if constexpr (Operator == BinaryOperator::BO_XOR) + Res = L ^ R; + else + ZEN_ASSERT_TODO(); + return createU256ConstOperand(Res); + } + // AND with 0 -> 0, OR/XOR with 0 -> identity + if constexpr (Operator == BinaryOperator::BO_AND) { + if ((LHSOp.isConstant() && isU256Zero(LHSOp.getConstValue())) || + (RHSOp.isConstant() && isU256Zero(RHSOp.getConstValue()))) + return createU256ConstOperand(intx::uint256{0}); + } + if constexpr (Operator == BinaryOperator::BO_OR || + Operator == BinaryOperator::BO_XOR) { + if (LHSOp.isConstant() && isU256Zero(LHSOp.getConstValue())) + return RHSOp; + if (RHSOp.isConstant() && isU256Zero(RHSOp.getConstValue())) + return LHSOp; + } + U256Inst Result = {}; U256Inst LHS = extractU256Operand(LHSOp); U256Inst RHS = extractU256Operand(RHSOp); @@ -482,6 +535,18 @@ class EVMMirBuilder final { MInstruction *getInstanceStackPeekInt(int32_t IndexFromTop); void drainGas(); + static intx::uint256 u256FromLimbs(const U256Value &V) { + intx::uint256 R = 0; + for (int I = EVM_ELEMENTS_COUNT - 1; I >= 0; --I) { + R = (R << 64) | intx::uint256{V[static_cast(I)]}; + } + return R; + } + + static bool isU256Zero(const U256Value &V) { + return (V[0] | V[1] | V[2] | V[3]) == 0; + } + // Create a full U256 operand from intx::uint256 value Operand createU256ConstOperand(const intx::uint256 &V); diff --git a/src/compiler/mir/instruction.h b/src/compiler/mir/instruction.h index df470ea93..7b9859e42 100644 --- a/src/compiler/mir/instruction.h +++ b/src/compiler/mir/instruction.h @@ -23,6 +23,7 @@ class MInstruction : public NonCopyable { UNARY, BINARY, ADC, + SBB, CMP, CONVERSION, SELECT, diff --git a/src/compiler/mir/instructions.cpp b/src/compiler/mir/instructions.cpp index 5e4b821a2..9805bb77f 100644 --- a/src/compiler/mir/instructions.cpp +++ b/src/compiler/mir/instructions.cpp @@ -86,6 +86,11 @@ void MInstruction::print(llvm::raw_ostream &OS) const { << getOperand<2>() << ')'; break; } + case SBB: { + OS << "sbb (" << getOperand<0>() << ", " << getOperand<1>() << ", " + << getOperand<2>() << ')'; + break; + } case BR: { auto *br = llvm::cast(this); OS << "br @" << br->getTargetBlock()->getIdx() << '\n'; diff --git a/src/compiler/mir/instructions.h b/src/compiler/mir/instructions.h index 565edd706..1a515d926 100644 --- a/src/compiler/mir/instructions.h +++ b/src/compiler/mir/instructions.h @@ -94,6 +94,29 @@ class AdcInstruction : public FixedOperandInstruction<3> { } }; +class SbbInstruction : public FixedOperandInstruction<3> { +public: + template + static SbbInstruction *create(Arguments &&...Args) { + return FixedOperandInstruction::create( + std::forward(Args)...); + } + + static bool classof(const MInstruction *Inst) { + return Inst->getOpcode() == OP_sbb; + } + +private: + friend class FixedOperandInstruction; + SbbInstruction(MType *Type, MInstruction *Operand1, MInstruction *Operand2, + MInstruction *Borrow) + : FixedOperandInstruction(MInstruction::SBB, OP_sbb, 3, Type) { + setOperand<0>(Operand1); + setOperand<1>(Operand2); + setOperand<2>(Borrow); + } +}; + class UnaryInstruction : public FixedOperandInstruction<1> { public: template diff --git a/src/compiler/mir/opcodes.def b/src/compiler/mir/opcodes.def index 030574692..58f93a063 100644 --- a/src/compiler/mir/opcodes.def +++ b/src/compiler/mir/opcodes.def @@ -58,6 +58,7 @@ OPCODE(dread) // OP_OTHER_EXPR_START OPCODE(const) OPCODE(cmp) OPCODE(adc) +OPCODE(sbb) OPCODE(select) OPCODE(load) OPCODE(wasm_sadd128_overflow) diff --git a/src/compiler/mir/pass/visitor.h b/src/compiler/mir/pass/visitor.h index 1adfa8570..074a05a0d 100644 --- a/src/compiler/mir/pass/visitor.h +++ b/src/compiler/mir/pass/visitor.h @@ -41,6 +41,9 @@ class MVisitor { case MInstruction::ADC: visitAdcInstruction(static_cast(I)); break; + case MInstruction::SBB: + visitSbbInstruction(static_cast(I)); + break; case MInstruction::OVERFLOW_I128_BINARY: visitWasmOverflowI128BinaryInstruction( static_cast(I)); @@ -132,6 +135,7 @@ class MVisitor { virtual void visitBinaryInstruction(BinaryInstruction &I) { VISIT_OPERAND_2 } virtual void visitCmpInstruction(CmpInstruction &I) { VISIT_OPERAND_2 } virtual void visitAdcInstruction(AdcInstruction &I) { VISIT_OPERAND_3 } + virtual void visitSbbInstruction(SbbInstruction &I) { VISIT_OPERAND_3 } virtual void visitSelectInstruction(SelectInstruction &I) { VISIT_OPERAND_3 } virtual void visitDassignInstruction(DassignInstruction &I) { VISIT_OPERAND_1 diff --git a/src/compiler/target/x86/x86lowering.cpp b/src/compiler/target/x86/x86lowering.cpp index a72a516b6..83ba9f62e 100644 --- a/src/compiler/target/x86/x86lowering.cpp +++ b/src/compiler/target/x86/x86lowering.cpp @@ -942,6 +942,40 @@ CgRegister X86CgLowering::lowerAdcExpr(const AdcInstruction &Inst) { } } +CgRegister X86CgLowering::lowerSbbExpr(const SbbInstruction &Inst) { + const MInstruction *LHS = Inst.getOperand<0>(); + const MInstruction *RHS = Inst.getOperand<1>(); + + MVT VT = getMVT(*Inst.getType()); + ZEN_ASSERT(VT.isInteger()); + const TargetRegisterClass *RC = TLI.getRegClassFor(VT); + + CgRegister LHSReg = lowerExpr(*LHS); + CgRegister RHSReg = lowerExpr(*RHS); + + CgRegister DiffReg = fastEmitCopy(RC, LHSReg); + switch (VT.SimpleTy) { + case MVT::i8: + MF->createCgInstruction(*CurBB, TII.get(X86::SBB8rr), DiffReg, RHSReg, + DiffReg); + return DiffReg; + case MVT::i16: + MF->createCgInstruction(*CurBB, TII.get(X86::SBB16rr), DiffReg, RHSReg, + DiffReg); + return DiffReg; + case MVT::i32: + MF->createCgInstruction(*CurBB, TII.get(X86::SBB32rr), DiffReg, RHSReg, + DiffReg); + return DiffReg; + case MVT::i64: + MF->createCgInstruction(*CurBB, TII.get(X86::SBB64rr), DiffReg, RHSReg, + DiffReg); + return DiffReg; + default: + throw getError(ErrorCode::NoMatchedInstruction); + } +} + CgRegister X86CgLowering::lowerEvmUmul128Expr(const EvmUmul128Instruction &Inst) { // 64x64->128 bit multiplication using x86 MUL64r diff --git a/src/compiler/target/x86/x86lowering.h b/src/compiler/target/x86/x86lowering.h index d7c67cb82..fb3bb68f9 100644 --- a/src/compiler/target/x86/x86lowering.h +++ b/src/compiler/target/x86/x86lowering.h @@ -72,6 +72,7 @@ class X86CgLowering : public CgLowering { CgRegister lowerEvmUmul128Expr(const EvmUmul128Instruction &Inst); CgRegister lowerEvmUmul128HiExpr(const EvmUmul128HiInstruction &Inst); CgRegister lowerAdcExpr(const AdcInstruction &Inst); + CgRegister lowerSbbExpr(const SbbInstruction &Inst); // ==================== Memory Instructions ====================