[InstCombine] Let foldEqOfParts accept arbitrary i1 values.

foldEqOfParts now takes Value* operands instead of ICmpInst*. Its
GetMatchPart helper additionally recognizes `trunc (xor x, y) to i1`
(for the ne/or form) and `not (trunc (xor x, y) to i1)` (for the eq/and
form), returning an IntPart built as {x or y, 0, 1}; any other non-icmp
value yields std::nullopt. The call from foldAndOrOfICmps is removed.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index c6f14018a750f5..b4033fc2a418a1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1185,14 +1185,27 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
 /// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01
 /// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01
 /// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer.
-Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
-                                       bool IsAnd) {
+Value *InstCombinerImpl::foldEqOfParts(Value *Cmp0, Value *Cmp1, bool IsAnd) {
   if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
     return nullptr;
 
   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-  auto GetMatchPart = [&](ICmpInst *Cmp,
+  auto GetMatchPart = [&](Value *CmpV,
                           unsigned OpNo) -> std::optional<IntPart> {
+    assert(CmpV->getType()->isIntOrIntVectorTy(1) && "Must be bool");
+
+    Value *X, *Y;
+    // icmp ne (and x, 1), (and y, 1) <=> trunc (xor x, y) to i1
+    // icmp eq (and x, 1), (and y, 1) <=> not (trunc (xor x, y) to i1)
+    if (Pred == CmpInst::ICMP_NE
+            ? match(CmpV, m_Trunc(m_Xor(m_Value(X), m_Value(Y))))
+            : match(CmpV, m_Not(m_Trunc(m_Xor(m_Value(X), m_Value(Y))))))
+      return {{OpNo == 0 ? X : Y, 0, 1}};
+
+    auto *Cmp = dyn_cast<ICmpInst>(CmpV);
+    if (!Cmp)
+      return std::nullopt;
+
     if (Pred == Cmp->getPredicate())
       return matchIntPart(Cmp->getOperand(OpNo));
 
@@ -3404,9 +3417,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
       return X;
   }
 
-  if (Value *X = foldEqOfParts(LHS, RHS, IsAnd))
-    return X;
-
   // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
   // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0)
   // TODO: Remove this and below when foldLogOpOfMaskedICmps can handle undefs.
@@ -3529,6 +3539,9 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
     if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))
       return Res;
 
+  if (Value *Res = foldEqOfParts(LHS, RHS, IsAnd))
+    return Res;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 9588930d7658c4..0508ed48fc19c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -412,7 +412,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                           bool IsAnd, bool IsLogical = false);
   Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor);
 
-  Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd);
+  Value *foldEqOfParts(Value *Cmp0, Value *Cmp1, bool IsAnd);
 
   Value *foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2,
                                      bool IsAnd);
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index 9494dd6bf8e5b5..d07c2e6a5be521 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1441,11 +1441,7 @@ define i1 @ne_optimized_highbits_cmp_todo_overlapping(i32 %x, i32 %y) {
 
 define i1 @and_trunc_i1(i8 %a1, i8 %a2) {
 ; CHECK-LABEL: @and_trunc_i1(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[A1:%.*]], [[A2:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[XOR]], 2
-; CHECK-NEXT:    [[LOBIT:%.*]] = trunc i8 [[XOR]] to i1
-; CHECK-NEXT:    [[LOBIT_INV:%.*]] = xor i1 [[LOBIT]], true
-; CHECK-NEXT:    [[AND:%.*]] = and i1 [[CMP]], [[LOBIT_INV]]
+; CHECK-NEXT:    [[AND:%.*]] = icmp eq i8 [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    ret i1 [[AND]]
 ;
   %xor = xor i8 %a1, %a2
@@ -1494,10 +1490,7 @@ define i1 @and_trunc_i1_wrong_operands(i8 %a1, i8 %a2, i8 %a3) {
 
 define i1 @or_trunc_i1(i64 %a1, i64 %a2) {
 ; CHECK-LABEL: @or_trunc_i1(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[A2:%.*]], [[A1:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[XOR]], 1
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i64 [[XOR]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[CMP]], [[TRUNC]]
+; CHECK-NEXT:    [[OR:%.*]] = icmp ne i64 [[A2:%.*]], [[A1:%.*]]
 ; CHECK-NEXT:    ret i1 [[OR]]
 ;
   %xor = xor i64 %a2, %a1