Skip to content

[VPlan] Thread scalar types through VPReplicateRecipe. (NFC)#199379

Open
fhahn wants to merge 2 commits into
llvm:mainfrom
fhahn:vplan-scalar-type-vpreplicaterecipe
Open

[VPlan] Thread scalar types through VPReplicateRecipe. (NFC)#199379
fhahn wants to merge 2 commits into
llvm:mainfrom
fhahn:vplan-scalar-type-vpreplicaterecipe

Conversation

@fhahn
Copy link
Copy Markdown
Contributor

@fhahn fhahn commented May 23, 2026

Update VPReplicateRecipe to populate VPSingleDefValue's scalar
type. For most opcodes, the scalar type is determine from the operands,
via computeScalarTypeForInstruction (from
#199378).
For some opcodes, like Loads and casts, the type must be
provided explicitly.

Depends on #199378 (included in
PR).

@llvmorg-github-actions
Copy link
Copy Markdown

llvmorg-github-actions Bot commented May 23, 2026

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Update VPReplicateRecipe to populate VPSingleDefValue's scalar
type. For most opcodes, the scalar type is determine from the operands,
via computeScalarTypeForInstruction (from
#199378).
For some opcodes, like Loads and casts, the type must be
provided explicitly.

Depends on #199378 (included in
PR).


Patch is 28.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/199379.diff

8 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+14-7)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+25-13)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+3-177)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.h (-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp (+5-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+102-6)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+3-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp (+23)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index adfe28e679a37..0ed9f8b5c1988 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -197,9 +197,10 @@ class VPBuilder {
                               const VPIRFlags &Flags = {},
                               const VPIRMetadata &MD = {},
                               DebugLoc DL = DebugLoc::getUnknown(),
-                              const Twine &Name = "") {
+                              const Twine &Name = "",
+                              Type *ResultTy = nullptr) {
     VPInstruction *NewVPInst = tryInsertInstruction(
-        new VPInstruction(Opcode, Operands, Flags, MD, DL, Name));
+        new VPInstruction(Opcode, Operands, Flags, MD, DL, Name, ResultTy));
     NewVPInst->setUnderlyingValue(Inst);
     return NewVPInst;
   }
@@ -226,15 +227,19 @@ class VPBuilder {
   VPInstruction *createFirstActiveLane(ArrayRef<VPValue *> Masks,
                                        DebugLoc DL = DebugLoc::getUnknown(),
                                        const Twine &Name = "") {
+    VPlan &Plan = getPlan();
+    Type *IndexTy = Plan.getDataLayout().getIndexType(Plan.getContext(), 0);
     return tryInsertInstruction(new VPInstruction(
-        VPInstruction::FirstActiveLane, Masks, {}, {}, DL, Name));
+        VPInstruction::FirstActiveLane, Masks, {}, {}, DL, Name, IndexTy));
   }
 
   VPInstruction *createLastActiveLane(ArrayRef<VPValue *> Masks,
                                       DebugLoc DL = DebugLoc::getUnknown(),
                                       const Twine &Name = "") {
-    return tryInsertInstruction(new VPInstruction(VPInstruction::LastActiveLane,
-                                                  Masks, {}, {}, DL, Name));
+    VPlan &Plan = getPlan();
+    Type *IndexTy = Plan.getDataLayout().getIndexType(Plan.getContext(), 0);
+    return tryInsertInstruction(new VPInstruction(
+        VPInstruction::LastActiveLane, Masks, {}, {}, DL, Name, IndexTy));
   }
 
   VPInstruction *createOverflowingOp(
@@ -359,8 +364,10 @@ class VPBuilder {
 
   VPPhi *createScalarPhi(ArrayRef<VPValue *> IncomingValues,
                          DebugLoc DL = DebugLoc::getUnknown(),
-                         const Twine &Name = "", const VPIRFlags &Flags = {}) {
-    return tryInsertInstruction(new VPPhi(IncomingValues, Flags, DL, Name));
+                         const Twine &Name = "", const VPIRFlags &Flags = {},
+                         Type *ResultTy = nullptr) {
+    return tryInsertInstruction(
+        new VPPhi(IncomingValues, Flags, DL, Name, ResultTy));
   }
 
   VPWidenPHIRecipe *createWidenPhi(ArrayRef<VPValue *> IncomingValues,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 7025f39decaf5..7d85fc172f3a2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -604,6 +604,10 @@ class LLVM_ABI_FOR_TEST VPRecipeBase
 /// types.
 LLVM_ABI Type *getScalarTypeOrInfer(VPValue *V);
 
+/// Compute the scalar result type for an IR \p Opcode given \p Operands.
+LLVM_ABI Type *computeScalarTypeForInstruction(unsigned Opcode,
+                                               ArrayRef<VPValue *> Operands);
+
 /// VPSingleDefRecipe is a base class for recipes that model a sequence of one
 /// or more output IR that define a single result VPValue. Note that
 /// VPSingleDefRecipe must inherit from VPRecipeBase before VPSingleDefValue.
@@ -1393,15 +1397,19 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
 public:
   VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
                 const VPIRFlags &Flags = {}, const VPIRMetadata &MD = {},
-                DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "");
+                DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "",
+                Type *ResultTy = nullptr);
 
   VP_CLASSOF_IMPL(VPRecipeBase::VPInstructionSC)
 
-  VPInstruction *clone() override { return cloneWithOperands(operands()); }
+  VPInstruction *clone() override {
+    return cloneWithOperands(operands(), getScalarType());
+  }
 
-  VPInstruction *cloneWithOperands(ArrayRef<VPValue *> NewOperands) {
+  VPInstruction *cloneWithOperands(ArrayRef<VPValue *> NewOperands,
+                                   Type *ResultTy = nullptr) {
     auto *New = new VPInstruction(Opcode, NewOperands, *this, *this,
-                                  getDebugLoc(), Name);
+                                  getDebugLoc(), Name, ResultTy);
     if (getUnderlyingValue())
       New->setUnderlyingValue(getUnderlyingInstr());
     return New;
@@ -1521,18 +1529,15 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
 /// directly determine the result type. Note that there is no separate recipe ID
 /// for VPInstructionWithType; it shares the same ID as VPInstruction and is
 /// distinguished purely by the opcode.
+/// TODO: Merge with VPInstruction, now that VPRecipeValue provides the type.
 class VPInstructionWithType : public VPInstruction {
-  /// Scalar result type produced by the recipe.
-  Type *ResultTy;
-
 public:
   VPInstructionWithType(unsigned Opcode, ArrayRef<VPValue *> Operands,
                         Type *ResultTy, const VPIRFlags &Flags = {},
                         const VPIRMetadata &Metadata = {},
                         DebugLoc DL = DebugLoc::getUnknown(),
                         const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, Flags, Metadata, DL, Name),
-        ResultTy(ResultTy) {}
+      : VPInstruction(Opcode, Operands, Flags, Metadata, DL, Name, ResultTy) {}
 
   static inline bool classof(const VPRecipeBase *R) {
     // VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1575,7 +1580,7 @@ class VPInstructionWithType : public VPInstruction {
     return 0;
   }
 
-  Type *getResultType() const { return ResultTy; }
+  Type *getResultType() const { return getScalarType(); }
 
 protected:
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -1652,8 +1657,9 @@ class VPPhiAccessors {
 
 struct LLVM_ABI_FOR_TEST VPPhi : public VPInstruction, public VPPhiAccessors {
   VPPhi(ArrayRef<VPValue *> Operands, const VPIRFlags &Flags, DebugLoc DL,
-        const Twine &Name = "")
-      : VPInstruction(Instruction::PHI, Operands, Flags, {}, DL, Name) {}
+        const Twine &Name = "", Type *ResultTy = nullptr)
+      : VPInstruction(Instruction::PHI, Operands, Flags, {}, DL, Name,
+                      ResultTy) {}
 
   static inline bool classof(const VPUser *U) {
     auto *VPI = dyn_cast<VPInstruction>(U);
@@ -3313,7 +3319,8 @@ class LLVM_ABI_FOR_TEST VPReplicateRecipe : public VPRecipeWithIRFlags,
                     bool IsSingleScalar, VPValue *Mask = nullptr,
                     const VPIRFlags &Flags = {}, VPIRMetadata Metadata = {},
                     DebugLoc DL = DebugLoc::getUnknown())
-      : VPRecipeWithIRFlags(VPRecipeBase::VPReplicateSC, Operands, Flags, DL),
+      : VPRecipeWithIRFlags(VPRecipeBase::VPReplicateSC, Operands,
+                            computeScalarType(I, Operands), Flags, DL),
         VPIRMetadata(Metadata), IsSingleScalar(IsSingleScalar),
         IsPredicated(Mask) {
     setUnderlyingValue(I);
@@ -3323,6 +3330,11 @@ class LLVM_ABI_FOR_TEST VPReplicateRecipe : public VPRecipeWithIRFlags,
 
   ~VPReplicateRecipe() override = default;
 
+  /// Compute the scalar result type for a VPReplicateRecipe wrapping \p I with
+  /// \p Operands (excluding any predicate mask).
+  static Type *computeScalarType(const Instruction *I,
+                                 ArrayRef<VPValue *> Operands);
+
   VPReplicateRecipe *clone() override {
     auto *Copy = new VPReplicateRecipe(
         getUnderlyingInstr(), operands(), IsSingleScalar,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 40cce07557ab5..fe1747098fb73 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -35,122 +35,6 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) {
   return ResTy;
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
-  // Set the result type from the first operand, check if the types for all
-  // other operands match and cache them.
-  auto SetResultTyFromOp = [this, R]() {
-    Type *ResTy = inferScalarType(R->getOperand(0));
-    unsigned NumOperands = R->getNumOperandsWithoutMask();
-    for (unsigned Op = 1; Op != NumOperands; ++Op) {
-      VPValue *OtherV = R->getOperand(Op);
-      assert(inferScalarType(OtherV) == ResTy &&
-             "different types inferred for different operands");
-      CachedTypes[OtherV] = ResTy;
-    }
-    return ResTy;
-  };
-
-  unsigned Opcode = R->getOpcode();
-  if (Instruction::isBinaryOp(Opcode) || Instruction::isUnaryOp(Opcode))
-    return SetResultTyFromOp();
-
-  switch (Opcode) {
-  case Instruction::PHI:
-    for (VPValue *Op : R->operands()) {
-      if (auto *VIR = dyn_cast<VPIRValue>(Op))
-        return VIR->getType();
-      if (auto *Ty = CachedTypes.lookup(Op))
-        return Ty;
-    }
-  LLVM_FALLTHROUGH;
-  case Instruction::ExtractElement:
-  case Instruction::InsertElement:
-  case Instruction::Freeze:
-  case VPInstruction::Broadcast:
-  case VPInstruction::ComputeReductionResult:
-  case VPInstruction::ExitingIVValue:
-  case VPInstruction::ExtractLastLane:
-  case VPInstruction::ExtractPenultimateElement:
-  case VPInstruction::ExtractLastPart:
-  case VPInstruction::ExtractLastActive:
-  case VPInstruction::PtrAdd:
-  case VPInstruction::WidePtrAdd:
-  case VPInstruction::ReductionStartVector:
-  case VPInstruction::ResumeForEpilogue:
-  case VPInstruction::Reverse:
-    return inferScalarType(R->getOperand(0));
-  case Instruction::Select: {
-    Type *ResTy = inferScalarType(R->getOperand(1));
-    VPValue *OtherV = R->getOperand(2);
-    assert(inferScalarType(OtherV) == ResTy &&
-           "different types inferred for different operands");
-    CachedTypes[OtherV] = ResTy;
-    return ResTy;
-  }
-  case Instruction::ICmp:
-  case Instruction::FCmp:
-  case VPInstruction::ActiveLaneMask:
-    assert(inferScalarType(R->getOperand(0)) ==
-               inferScalarType(R->getOperand(1)) &&
-           "different types inferred for different operands");
-    return IntegerType::get(Ctx, 1);
-  case VPInstruction::ExplicitVectorLength:
-    return Type::getIntNTy(Ctx, 32);
-  case VPInstruction::FirstOrderRecurrenceSplice:
-  case VPInstruction::Not:
-  case VPInstruction::CalculateTripCountMinusVF:
-  case VPInstruction::CanonicalIVIncrementForPart:
-  case VPInstruction::AnyOf:
-  case VPInstruction::BuildStructVector:
-  case VPInstruction::BuildVector:
-  case VPInstruction::Unpack:
-    return SetResultTyFromOp();
-  case VPInstruction::ExtractLane:
-    return inferScalarType(R->getOperand(1));
-  case VPInstruction::FirstActiveLane:
-  case VPInstruction::LastActiveLane:
-    // Assume that the maximum possible number of elements in a vector fits
-    // within the index type for the default address space.
-    return DL.getIndexType(Ctx, 0);
-  case VPInstruction::LogicalAnd:
-  case VPInstruction::LogicalOr:
-    assert(inferScalarType(R->getOperand(0))->isIntegerTy(1) &&
-           inferScalarType(R->getOperand(1))->isIntegerTy(1) &&
-           "LogicalAnd/Or operands should be bool");
-    return IntegerType::get(Ctx, 1);
-  case VPInstruction::MaskedCond:
-    assert(inferScalarType(R->getOperand(0))->isIntegerTy(1));
-    return IntegerType::get(Ctx, 1);
-  case VPInstruction::BranchOnCond:
-  case VPInstruction::BranchOnTwoConds:
-  case VPInstruction::BranchOnCount:
-  case Instruction::Store:
-  case Instruction::Switch:
-    return Type::getVoidTy(Ctx);
-  case Instruction::Load:
-    return cast<LoadInst>(R->getUnderlyingValue())->getType();
-  case Instruction::Alloca:
-    return cast<AllocaInst>(R->getUnderlyingValue())->getType();
-  case Instruction::Call: {
-    unsigned CallIdx = R->getNumOperandsWithoutMask() - 1;
-    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
-        ->getReturnType();
-  }
-  case Instruction::GetElementPtr:
-    return inferScalarType(R->getOperand(0));
-  case Instruction::ExtractValue:
-    return cast<ExtractValueInst>(R->getUnderlyingValue())->getType();
-  default:
-    break;
-  }
-  // Type inference not implemented for opcode.
-  LLVM_DEBUG({
-    dbgs() << "LV: Found unhandled opcode for: ";
-    R->getVPSingleValue()->dump();
-  });
-  llvm_unreachable("Unhandled opcode!");
-}
-
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   unsigned Opcode = R->getOpcode();
   if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
@@ -195,62 +79,6 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   llvm_unreachable("Unhandled opcode!");
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
-  unsigned Opcode = R->getUnderlyingInstr()->getOpcode();
-
-  if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
-      Instruction::isBitwiseLogicOp(Opcode)) {
-    Type *ResTy = inferScalarType(R->getOperand(0));
-    assert(ResTy == inferScalarType(R->getOperand(1)) &&
-           "inferred types for operands of binary op don't match");
-    CachedTypes[R->getOperand(1)] = ResTy;
-    return ResTy;
-  }
-
-  if (Instruction::isCast(Opcode))
-    return R->getUnderlyingInstr()->getType();
-
-  switch (Opcode) {
-  case Instruction::Call: {
-    unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
-    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
-        ->getReturnType();
-  }
-  case Instruction::Select: {
-    Type *ResTy = inferScalarType(R->getOperand(1));
-    assert(ResTy == inferScalarType(R->getOperand(2)) &&
-           "inferred types for operands of select op don't match");
-    CachedTypes[R->getOperand(2)] = ResTy;
-    return ResTy;
-  }
-  case Instruction::ICmp:
-  case Instruction::FCmp:
-    return IntegerType::get(Ctx, 1);
-  case Instruction::Alloca:
-  case Instruction::ExtractValue:
-    return R->getUnderlyingInstr()->getType();
-  case Instruction::Freeze:
-  case Instruction::FNeg:
-  case Instruction::GetElementPtr:
-    return inferScalarType(R->getOperand(0));
-  case Instruction::Load:
-    return cast<LoadInst>(R->getUnderlyingInstr())->getType();
-  case Instruction::Store:
-    // FIXME: VPReplicateRecipes with store opcodes still define a result
-    // VPValue, so we need to handle them here. Remove the code here once this
-    // is modeled accurately in VPlan.
-    return Type::getVoidTy(Ctx);
-  default:
-    break;
-  }
-  // Type inference not implemented for opcode.
-  LLVM_DEBUG({
-    dbgs() << "LV: Found unhandled opcode for: ";
-    R->getVPSingleValue()->dump();
-  });
-  llvm_unreachable("Unhandled opcode");
-}
-
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -260,7 +88,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
           VPScalarIVStepsRecipe, VPWidenCanonicalIVRecipe, VPWidenCastRecipe,
           VPWidenIntrinsicRecipe, VPWidenGEPRecipe, VPVectorPointerRecipe,
           VPVectorEndPointerRecipe, VPWidenCallRecipe, VPWidenLoadRecipe,
-          VPWidenLoadEVLRecipe, VPDerivedIVRecipe, VPHeaderPHIRecipe>(V)) {
+          VPWidenLoadEVLRecipe, VPDerivedIVRecipe, VPHeaderPHIRecipe,
+          VPInstruction, VPReplicateRecipe>(V)) {
     Type *Ty = V->getScalarType();
     assert(Ty && "Scalar type must be set by recipe construction");
     return Ty;
@@ -268,10 +97,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
 
   Type *ResultTy =
       TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe())
-          // VPInstructionWithType must be handled before VPInstruction.
-          .Case<VPInstructionWithType>(
-              [](const auto *R) { return R->getResultType(); })
-          .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe>(
+          .Case<VPBlendRecipe, VPWidenRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case([this](const VPReductionRecipe *R) {
             return inferScalarType(R->getChainOp());
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 8832c85b1dd02..57e438c1b03c0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -22,9 +22,7 @@ namespace llvm {
 class LLVMContext;
 class VPValue;
 class VPBlendRecipe;
-class VPInstruction;
 class VPWidenRecipe;
-class VPReplicateRecipe;
 class VPRecipeBase;
 class VPlan;
 class Value;
@@ -48,9 +46,7 @@ class VPTypeAnalysis {
   const DataLayout &DL;
 
   Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
-  Type *inferScalarTypeForRecipe(const VPInstruction *R);
   Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
-  Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
 
 public:
   VPTypeAnalysis(const VPlan &Plan)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 391c358b22fa3..bcb0d28283025 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -225,8 +225,8 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
       // Phi node's operands may not have been visited at this point. We create
       // an empty VPInstruction that we will fix once the whole plain CFG has
       // been built.
-      NewR =
-          VPIRBuilder.createScalarPhi({}, Phi->getDebugLoc(), "vec.phi", *Phi);
+      NewR = VPIRBuilder.createScalarPhi({}, Phi->getDebugLoc(), "vec.phi",
+                                         *Phi, Phi->getType());
       NewR->setUnderlyingValue(Phi);
       if (isHeaderBB(Phi->getParent(), LI->getLoopFor(Phi->getParent()))) {
         // Header phis need to be fixed after the VPBB for the latch has been
@@ -275,9 +275,9 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
       } else {
         // Build VPInstruction for any arbitrary Instruction without specific
         // representation in VPlan.
-        NewR =
-            VPIRBuilder.createNaryOp(Inst->getOpcode(), VPOperands, Inst,
-                                     VPIRFlags(*Inst), MD, Inst->getDebugLoc());
+        NewR = VPIRBuilder.createNaryOp(
+            Inst->getOpcode(), VPOperands, Inst, VPIRFlags(*Inst), MD,
+            Inst->getDebugLoc(), "", Inst->getType());
       }
     }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f9fc97794f997..0fbf067e15838 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -437,10 +437,104 @@ VPExpandSCEVRecipe::VPExpandSCEVRecipe(const SCEV *Expr)
     : VPSingleDefRecipe(VPRecipeBase::VPExpandSCEVSC, {}, Expr->getType()),
       Expr(Expr) {}
 
+Type *llvm::computeScalarTypeForInstruction(unsigned Opcode,
+                                            ArrayRef<VPValue *> Operands) {
+  assert(!Operands.empty() &&
+         "zero-operand VPInstruction opcodes must pass explicit ResultTy");
+  // Assert operand \p Idx (if present and typed) has type \p ExpectedTy.
+  auto AssertOperandType = [&Operands]([[maybe_unused]] unsigned Idx,
+                                       [[maybe_unused]] Type *ExpectedTy) {
+#ifndef NDEBUG
+    if (!ExpectedTy || Operands.size() <= Idx)
+      return;
+    Type *OpTy = getScalarTypeOrInfer(Operands[Idx]);
+    assert((!OpTy || OpTy == ExpectedTy) &&
+           "different types inferred for different operands");
+#endif
+  };
+
+  Type *Op0Ty = getScalarTypeOrInfer(Operands[0]);
+  LLVMContext &Ctx = Op0Ty->getContext();
+  switch (Opcode) {
+  case VPInstruction::BranchOnCond:
+  case VPInstruction::BranchOnTwoConds:
+  case VPInstruction::BranchOnCount:
+  case Instruction::Store:
+  case Instruction::Switch:
+    return Type::getVoidTy(Ctx);
+  case Instruction::ICmp...
[truncated]

fhahn added 2 commits May 24, 2026 13:19
Update VPInstruction and VPPhi to populate VPSingleDefValue's scalar
type. For most opcodes, the scalar type is determine from the operands,
via computeScalarTypeForInstruction, which roughly matches to removed
inference code. For some opcodes, like FirstActiveLane, the type must be
provided explicitly.
Update VPReplicateRecipe to populate VPSingleDefValue's scalar
type. For most opcodes, the scalar type is determine from the operands,
via computeScalarTypeForInstruction (from
llvm#199378).
For some opcodes, like Loads and casts, the type must be
provided explicitly.

Depends on llvm#199378 (included in
PR).
@fhahn fhahn force-pushed the vplan-scalar-type-vpreplicaterecipe branch from 00a83ff to 54fe153 Compare May 24, 2026 19:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant