Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLIR AA: wip rebase #1985

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,18 @@ void printActivityAnalysisResults(const DataFlowSolver &solver,
});
};

if (value.getDefiningOp<LLVM::PoisonOp>() || value.getDefiningOp<LLVM::ZeroOp>() || value.getDefiningOp<LLVM::UndefOp>()) {
return true;
}

// If this triggers, investigate why the alias classes weren't computed.
// If they weren't computed legitimately, treat the value as
// conservatively non-constant or change the return type to be tri-state.
if (aliasClassLattice->isUndefined()) {
//llvm::errs() << *callee->getParentOp() << "\n";
llvm::errs() << value.getDefiningOp() << " undef alias latice " << value << "\n";
return false;
}
assert(!aliasClassLattice->isUndefined() &&
"didn't compute alias classes");

Expand All @@ -889,6 +898,12 @@ void printActivityAnalysisResults(const DataFlowSolver &solver,
if (fma->hasActiveData(aliasClass) &&
bma->activeDataFlowsOut(aliasClass))
return false;

if (pointsToSets->getPointsTo(aliasClass).isUndefined()) {

llvm::errs() << value.getDefiningOp() << " undef pointrstoalias latice " << value << "\n";
return false;
}

// If this triggers, investigate why points-to sets couldn't be
// computed. Treat conservatively as "unknown" if necessary.
Expand Down
11 changes: 10 additions & 1 deletion enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ static bool isPointerLike(Type type) {
return isa<MemRefType, LLVM::LLVMPointerType>(type);
}

/*
void mlir::enzyme::PointsToSets::dumpSet() const {
for (const auto &pair : map) {
llvm::errs() << pair.first.getAsOpaquePointer() << "\n";
}
}
*/

//===----------------------------------------------------------------------===//
// PointsToAnalysis
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -108,7 +116,8 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
llvm::errs() << "\n";
valuesCopy.print(llvm::errs());
llvm::errs() << "\n";
assert(valuesCopy == values &&
if (valuesCopy != values)
llvm::errs() <<
"attempting to replace a pointsTo entry with an alias class "
"set that is ordered _before_ the existing one -> "
"non-monotonous update ");
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class PointsToSets : public MapOfSetsLattice<DistinctAttr, DistinctAttr> {

const AliasClassSet &getPointsTo(DistinctAttr id) const { return lookup(id); }

//void dumpSets() const;
private:
/// Update all alias classes in `keysToUpdate` to additionally point to alias
/// classes in `values`. Handle undefined keys optimistically (ignore) and
Expand All @@ -159,6 +160,7 @@ class PointsToSets : public MapOfSetsLattice<DistinctAttr, DistinctAttr> {
/// in the lattice, not only the replacements described above.
ChangeResult update(const AliasClassSet &keysToUpdate,
const AliasClassSet &values, bool replace);

};

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def : InactiveOp<"LLVM", "Prefetch">;
def : InactiveOp<"LLVM", "MemsetOp">;

def : InactiveOp<"LLVM", "UndefOp">;
def : InactiveOp<"LLVM", "PoisonOp">;
def : InactiveOp<"LLVM", "ConstantOp">;
def : InactiveOp<"LLVM", "UnreachableOp">;

Expand Down
10 changes: 10 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> {
def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> {
let summary = "Print the results of activity analysis";
let constructor = "mlir::enzyme::createPrintActivityAnalysisPass()";
let dependentDialects = [
"enzyme::EnzymeDialect"
];
let options = [
ListOption<
/*C++ variable name=*/"funcsToAnalyze",
Expand Down Expand Up @@ -113,6 +116,13 @@ def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> {
/*default=*/"true",
/*description=*/"Whether to use the new Dataflow activity analysis"
>,
Option<
/*C++ variable name=*/"relative",
/*CLI argument=*/"relative",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Use relative bottom-up activity analysis"
>,
Option<
/*C++ variable name=*/"inactiveArgs",
/*CLI argument=*/"inactive-args",
Expand Down
32 changes: 25 additions & 7 deletions enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "Analysis/ActivityAnalysis.h"
#include "Analysis/ActivityAnnotations.h"
#include "Analysis/DataFlowActivityAnalysis.h"
#include "Dialect/Ops.h"
#include "Interfaces/EnzymeLogic.h"
Expand Down Expand Up @@ -200,6 +201,11 @@ struct PrintActivityAnalysisPass
}

void runOnOperation() override {
enzyme::ActivityPrinterConfig config;
config.annotate = annotate;
config.inferFromAutodiff = false;
config.verbose = verbose;

auto moduleOp = cast<ModuleOp>(getOperation());

if (annotate && dataflow) {
Expand Down Expand Up @@ -238,23 +244,31 @@ struct PrintActivityAnalysisPass
// supplied annotation. First argument is the callee
inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call,
argActivities, resultActivities);
runActivityAnalysis(dataflow, callee, argActivities, resultActivities,
/*print=*/true, verbose, annotate);
if (relative) {
enzyme::runActivityAnnotations(callee, argActivities, config);
} else {
runActivityAnalysis(dataflow, callee, argActivities, resultActivities,
/*print=*/true, verbose, annotate);
}
}
return;
}

if (funcsToAnalyze.empty()) {
moduleOp.walk([this](FunctionOpInterface callee) {
moduleOp.walk([this, &config](FunctionOpInterface callee) {
if (callee.isExternal() || callee.isPrivate())
return;

SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
resultActivities{callee.getNumResults()};
initializeArgAndResActivities(callee, argActivities, resultActivities);

runActivityAnalysis(dataflow, callee, argActivities, resultActivities,
/*print=*/true, verbose, annotate);
if (relative) {
enzyme::runActivityAnnotations(callee, argActivities, config);
} else {
runActivityAnalysis(dataflow, callee, argActivities, resultActivities,
/*print=*/true, verbose, annotate);
}
});
return;
}
Expand All @@ -276,8 +290,12 @@ struct PrintActivityAnalysisPass
resultActivities{callee.getNumResults()};
initializeArgAndResActivities(callee, argActivities, resultActivities);

runActivityAnalysis(dataflow, callee, argActivities, resultActivities,
/*print=*/true, verbose, annotate);
if (relative) {
enzyme::runActivityAnnotations(callee, argActivities, config);
} else {
runActivityAnalysis(dataflow, callee, argActivities, resultActivities,
/*print=*/true, verbose, annotate);
}
}
}
};
Expand Down
Loading
Loading