Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JLInstSimplify multi arg #2038

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 99 additions & 14 deletions enzyme/Enzyme/JLInstSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,58 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst) {
return true;
}

static inline SetVector<llvm::Value *> getBaseObjects(llvm::Value *V,
bool offsetAllowed) {
SetVector<llvm::Value *> results;

SmallPtrSet<llvm::Value *, 2> seen;
SmallVector<llvm::Value *, 1> todo = {V};

while (todo.size()) {
auto cur = todo.back();
todo.pop_back();
if (seen.count(cur))
continue;
seen.insert(cur);
auto obj = getBaseObject(cur, offsetAllowed);
if (auto PN = dyn_cast<PHINode>(obj)) {
for (auto &val : PN->incoming_values()) {
todo.push_back(val);
}
continue;
}
if (auto SI = dyn_cast<SelectInst>(obj)) {
todo.push_back(SI->getTrueValue());
todo.push_back(SI->getFalseValue());
continue;
}
results.insert(obj);
}
return results;
}

bool noaliased_or_arg(SetVector<llvm::Value *> &lhs_v,
SetVector<llvm::Value *> &rhs_v) {
for (auto lhs : lhs_v) {
auto lhs_na = isNoAlias(lhs);
auto lhs_arg = isa<Argument>(lhs);

// This LHS value is neither noalias or an argument
if (!lhs_na && !lhs_arg)
return false;

for (auto rhs : rhs_v) {
if (lhs == rhs)
return false;
if (isNoAlias(lhs))
continue;
if (!lhs_na && !isa<Argument>(rhs))
return false;
}
}
return true;
}

bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
llvm::AAResults &AA, llvm::LoopInfo &LI) {
bool changed = false;
Expand Down Expand Up @@ -175,33 +227,59 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
}

if (legal) {
auto lhs = getBaseObject(I.getOperand(0), /*offsetAllowed*/ false);
auto rhs = getBaseObject(I.getOperand(1), /*offsetAllowed*/ false);
if (lhs == rhs) {
auto lhs_v = getBaseObjects(I.getOperand(0), /*offsetAllowed*/ false);
auto rhs_v = getBaseObjects(I.getOperand(1), /*offsetAllowed*/ false);
if (lhs_v.size() == 1 && rhs_v.size() == 1 && lhs_v[0] == rhs_v[0]) {
auto repval = ICmpInst::isTrueWhenEqual(pred)
? ConstantInt::get(I.getType(), 1)
: ConstantInt::get(I.getType(), 0);
I.replaceAllUsesWith(repval);
changed = true;
continue;
}
if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa<Argument>(rhs))) ||
(isNoAlias(rhs) && isa<Argument>(lhs))) {
if (noaliased_or_arg(lhs_v, rhs_v)) {
auto repval = ICmpInst::isTrueWhenEqual(pred)
? ConstantInt::get(I.getType(), 0)
: ConstantInt::get(I.getType(), 1);
I.replaceAllUsesWith(repval);
changed = true;
continue;
}
auto llhs = dyn_cast<LoadInst>(lhs);
auto lrhs = dyn_cast<LoadInst>(rhs);
if (llhs && lrhs && isa<PointerType>(llhs->getType()) &&
isa<PointerType>(lrhs->getType())) {
auto lhsv =
getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false);
auto rhsv =
getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false);
bool loadlegal = true;
SmallVector<LoadInst *, 1> llhs, lrhs;
for (auto lhs : lhs_v) {
auto ld = dyn_cast<LoadInst>(lhs);
if (!ld || !isa<PointerType>(ld->getType())) {
loadlegal = false;
break;
}
llhs.push_back(ld);
}
for (auto rhs : rhs_v) {
auto ld = dyn_cast<LoadInst>(rhs);
if (!ld || !isa<PointerType>(ld->getType())) {
loadlegal = false;
break;
}
lrhs.push_back(ld);
}
SetVector<Value *> llhs_s, lrhs_s;
for (auto v : llhs) {
for (auto obj :
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
llhs_s.insert(obj);
}
}
for (auto v : lrhs) {
for (auto obj :
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
lrhs_s.insert(obj);
}
}
// TODO handle multi size
if (llhs_s.size() == 1 && lrhs_s.size() == 1 && loadlegal) {
auto lhsv = llhs_s[0];
auto rhsv = lrhs_s[0];
if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa<Argument>(rhsv) ||
notCapturedBefore(lhsv, &I))) ||
(isNoAlias(rhsv) &&
Expand All @@ -225,7 +303,14 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;

for (auto LI : {llhs, lrhs})
for (auto LI : llhs)
if (writesToMemoryReadBy(AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
overwritten = true;
return /*earlyBreak*/ true;
}
for (auto LI : lrhs)
if (writesToMemoryReadBy(AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
Expand Down
26 changes: 26 additions & 0 deletions enzyme/test/Enzyme/JLSimplify/yesptr2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -jl-inst-simplify -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="jl-inst-simplify" -S | FileCheck %s

declare i8** @malloc(i64)

define fastcc i1 @augmented_julia__affine_normalize_1484(i1 %c) {
%i5 = call noalias i8** @malloc(i64 16)
br i1 %c, label %tval, label %fval

tval:
%j29 = load i8*, i8** %i5, align 8
br label %end

fval:
%k29 = load i8*, i8** %i5, align 8
br label %end

end:
%i29 = phi i8* [ %j29, %tval ], [ %k29, %fval ]
%i31 = call noalias nonnull i8* addrspace(10)* inttoptr (i64 137352001798896 to i8* addrspace(10)* ({} addrspace(10)*, i64, i64)*)({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 137351863426640 to {}*) to {} addrspace(10)*), i64 10, i64 10)
%i35 = load i8*, i8* addrspace(10)* %i31, align 8
%i39 = icmp ne i8* %i35, %i29
ret i1 %i39
}

; CHECK: ret i1 true
Loading