[InstCombine] Fix vector_reduce_mul(sext <n x i1>) for odd n. #199401

Merged
jlebar merged 2 commits into llvm:main from jlebar:fix-013-vecreduce-mul-sext-i1-odd
May 24, 2026

@jlebar
Member

@jlebar jlebar commented May 24, 2026

Before this patch, instcombine folded

vector_reduce_mul(sext (<n x i1> val))

to

zext(vector_reduce_and(<n x i1> val)).

But this is incorrect when n is odd: if every lane is true, the reduction
multiplies n copies of -1, so the result is -1, not 1.
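
The parity argument can be checked with a small standalone model. The sketch below is not LLVM code: it assumes 8-bit result lanes and models the i1 vector as a `std::vector<bool>`, and every function name in it is hypothetical. It compares the wrapping product of the sign-extended lanes with the old fold (always zext) and with the parity-aware fold described in this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Models sext i1 -> i8: true becomes -1 (all bits set), false becomes 0.
int8_t sextI1(bool B) { return B ? int8_t(-1) : int8_t(0); }

// Models vector_reduce_mul over the sign-extended lanes, wrapping at 8 bits.
int8_t reduceMulSExt(const std::vector<bool> &Lanes) {
  uint8_t Product = 1;
  for (bool B : Lanes)
    Product = static_cast<uint8_t>(Product * static_cast<uint8_t>(sextI1(B)));
  return static_cast<int8_t>(Product);
}

// Models vector_reduce_and over the i1 lanes.
bool reduceAnd(const std::vector<bool> &Lanes) {
  for (bool B : Lanes)
    if (!B)
      return false;
  return true;
}

// Old fold: zext(vector_reduce_and(x)) regardless of the lane count.
int8_t oldFold(const std::vector<bool> &Lanes) {
  return reduceAnd(Lanes) ? int8_t(1) : int8_t(0);
}

// New fold: sext for an odd lane count, zext for an even lane count.
int8_t newFold(const std::vector<bool> &Lanes) {
  bool All = reduceAnd(Lanes);
  if (Lanes.size() & 1)
    return All ? int8_t(-1) : int8_t(0);
  return All ? int8_t(1) : int8_t(0);
}

// Compares a fold against the modeled reduction for every N-lane input.
bool foldMatchesForAllInputs(unsigned N,
                             int8_t (*Fold)(const std::vector<bool> &)) {
  for (uint32_t Mask = 0; Mask < (1u << N); ++Mask) {
    std::vector<bool> Lanes(N);
    for (unsigned I = 0; I < N; ++I)
      Lanes[I] = ((Mask >> I) & 1u) != 0;
    if (Fold(Lanes) != reduceMulSExt(Lanes))
      return false;
  }
  return true;
}

// Hypothetical self-check: the old fold only disagrees for odd lane counts.
void checkParityRule() {
  assert(!foldMatchesForAllInputs(3, oldFold));
  assert(foldMatchesForAllInputs(3, newFold));
  assert(foldMatchesForAllInputs(4, oldFold));
  assert(foldMatchesForAllInputs(4, newFold));
}
```

In this model the only input on which the old fold disagrees with the reduction is the all-true vector with an odd lane count; any false lane forces the product to 0 under both folds.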

This bug was found by a large run of Opus 4.7 looking for bugs in LLVM.

@jlebar jlebar requested a review from nikic as a code owner May 24, 2026 06:32
@llvmorg-github-actions llvmorg-github-actions Bot added llvm:instcombine llvm:transforms labels May 24, 2026
@llvmorg-github-actions

@llvm/pr-subscribers-llvm-transforms

Author: Justin Lebar (jlebar)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/199401.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+31-14)
  • (modified) llvm/test/Transforms/InstCombine/reduction-mul-sext-zext-i1.ll (+25)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 360326f47594d..b567d5a0b665b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -4111,14 +4111,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   }
   case Intrinsic::vector_reduce_mul: {
     if (IID == Intrinsic::vector_reduce_mul) {
-      // Multiplicative reduction over the vector with (potentially-extended)
-      // i1 element type is actually a (potentially zero-extended)
-      // logical `and` reduction over the original non-extended value:
-      //   vector_reduce_mul(?ext(<n x i1>))
-      //     -->
-      //   zext(vector_reduce_and(<n x i1>))
       Value *Arg = II->getArgOperand(0);
-      Value *Vect;
 
       if (Value *NewOp =
               simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
@@ -4126,13 +4119,37 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         return II;
       }
 
-      if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
-        if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
-          if (VTy->getElementType() == Builder.getInt1Ty()) {
-            Value *Res = Builder.CreateAndReduce(Vect);
-            Res = Builder.CreateZExt(Res, II->getType());
-            return replaceInstUsesWith(CI, Res);
-          }
+      auto IsI1Vec = [&](Value *V) {
+        auto *VTy = dyn_cast<VectorType>(V->getType());
+        return VTy && VTy->getElementType() == Builder.getInt1Ty();
+      };
+
+      // vector_reduce_mul(zext(<n x i1>)) --> zext(vector_reduce_and(<n x i1>))
+      Value *Vect;
+      if (match(Arg, m_ZExt(m_Value(Vect))) && IsI1Vec(Vect)) {
+        Value *Res = Builder.CreateAndReduce(Vect);
+        return replaceInstUsesWith(CI, Builder.CreateZExt(Res, II->getType()));
+      }
+
+      // vector_reduce_mul(sext(<n x i1>)) -->
+      //   sext(vector_reduce_and(<n x i1>)) if n is odd
+      //   zext(vector_reduce_and(<n x i1>)) if n is even.
+      // This is because if the vector is all `true`, we are multiplying n -1s.
+      // Therefore the answer is -1 if n is odd, or 1 if n is even.
+      if (match(Arg, m_SExt(m_Value(Vect)))) {
+        if (auto *VTy = dyn_cast<FixedVectorType>(Vect->getType());
+            VTy && VTy->getElementType() == Builder.getInt1Ty()) {
+          Value *Res = Builder.CreateAndReduce(Vect);
+          Res = (VTy->getNumElements() & 1)
+                    ? Builder.CreateSExt(Res, II->getType())
+                    : Builder.CreateZExt(Res, II->getType());
+          return replaceInstUsesWith(CI, Res);
+        }
+      }
+
+      // vector_reduce_mul(<n x i1>) --> vector_reduce_and(<n x i1>)
+      if (IsI1Vec(Arg)) {
+        return replaceInstUsesWith(CI, Builder.CreateAndReduce(Arg));
       }
     }
     [[fallthrough]];
diff --git a/llvm/test/Transforms/InstCombine/reduction-mul-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-mul-sext-zext-i1.ll
index f70820801602c..c867e076d9b4a 100644
--- a/llvm/test/Transforms/InstCombine/reduction-mul-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-mul-sext-zext-i1.ll
@@ -95,8 +95,33 @@ define i64 @reduce_mul_zext_external_use(<8 x i1> %x) {
   ret i64 %res
 }
 
+define i8 @reduce_mul_sext_odd_lanes(<3 x i1> %x) {
+; CHECK-LABEL: @reduce_mul_sext_odd_lanes(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <3 x i1> [[X:%.*]] to i3
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i3 [[TMP1]], -1
+; CHECK-NEXT:    [[RES:%.*]] = sext i1 [[TMP2]] to i8
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+  %sext = sext <3 x i1> %x to <3 x i8>
+  %res = call i8 @llvm.vector.reduce.mul.v3i8(<3 x i8> %sext)
+  ret i8 %res
+}
+
+define i8 @reduce_mul_zext_odd_lanes(<3 x i1> %x) {
+; CHECK-LABEL: @reduce_mul_zext_odd_lanes(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <3 x i1> [[X:%.*]] to i3
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i3 [[TMP1]], -1
+; CHECK-NEXT:    [[RES:%.*]] = zext i1 [[TMP2]] to i8
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+  %zext = zext <3 x i1> %x to <3 x i8>
+  %res = call i8 @llvm.vector.reduce.mul.v3i8(<3 x i8> %zext)
+  ret i8 %res
+}
+
 declare i1 @llvm.vector.reduce.mul.v8i32(<8 x i1> %a)
 declare i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> %a)
 declare i64 @llvm.vector.reduce.mul.v8i64(<8 x i64> %a)
 declare i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> %a)
 declare i8 @llvm.vector.reduce.mul.v128i8(<128 x i8> %a)
+declare i8 @llvm.vector.reduce.mul.v3i8(<3 x i8> %a)
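
The CHECK lines in the two new tests suggest that, after InstCombine runs, the `<3 x i1>` and-reduction appears as a bitcast to `i3` followed by `icmp eq i3 ..., -1`, with only the final extension differing between the sext and zext cases. The following is a rough standalone model of that shape, not LLVM code: all names are hypothetical, and the lane-to-bit order in `packLanes` is an assumption that does not affect an all-ones comparison.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Lanes3 = std::array<bool, 3>;

// Models bitcast <3 x i1> to i3. Placing lane 0 in the least significant
// bit is an assumption; the all-ones check below is order-independent.
uint8_t packLanes(const Lanes3 &X) {
  uint8_t Bits = 0;
  for (unsigned I = 0; I < 3; ++I)
    if (X[I])
      Bits = static_cast<uint8_t>(Bits | (1u << I));
  return Bits;
}

// Models icmp eq i3 %bits, -1: as an i3 value, -1 is the bit pattern 0b111.
bool allLanesSet(const Lanes3 &X) { return packLanes(X) == 0x7; }

// Models sext i1 to i8, the expected output shape for the odd-lane sext test.
int8_t sextResult(const Lanes3 &X) {
  return allLanesSet(X) ? int8_t(-1) : int8_t(0);
}

// Models zext i1 to i8, the expected output shape for the zext test.
int8_t zextResult(const Lanes3 &X) {
  return allLanesSet(X) ? int8_t(1) : int8_t(0);
}

// Hypothetical self-check covering the all-true and a mixed input.
void checkAllOnesCompare() {
  assert(sextResult({true, true, true}) == -1);
  assert(zextResult({true, true, true}) == 1);
  assert(sextResult({true, false, true}) == 0);
  assert(zextResult({true, false, true}) == 0);
}
```

Read this way, the two tests differ only in the last instruction: `sext i1` yields -1 for the all-true input, while `zext i1` yields 1.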

@jlebar jlebar requested a review from dtcxzyw May 24, 2026 09:39
@jlebar
Member Author

jlebar commented May 24, 2026

Thank you for all the reviews.

@jlebar jlebar changed the title [InstCombine] Fix vector_reduce_mul(sext <n x i1>). [InstCombine] Fix vector_reduce_mul(sext <n x i1>) for odd n. May 24, 2026
Member

@dtcxzyw dtcxzyw left a comment

LG

jlebar added 2 commits May 24, 2026 03:00
@jlebar jlebar force-pushed the fix-013-vecreduce-mul-sext-i1-odd branch from b4a0310 to 008728f Compare May 24, 2026 10:01
@jlebar jlebar enabled auto-merge (squash) May 24, 2026 10:01
@jlebar jlebar merged commit 46666d9 into llvm:main May 24, 2026
9 of 10 checks passed