From bffec329c951e40fc3cbd7e798c147bf64ecce4e Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Tue, 10 Sep 2024 21:53:10 -0400 Subject: [PATCH 1/4] Add derivative rule for stablehlo.scatter --- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 94ce0706..88f174c4 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -469,6 +469,68 @@ class AutoDiffConcatenateRev MGradientUtilsReverse *gutils) const {} }; +class AutoDiffScatterRev + : public ReverseAutoDiffOpInterface::ExternalModel { +public: + LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto op = cast(orig); + auto inDiffe = gutils->diffe(op->getResult(0), builder); + if (op.getUpdates().size() != 1) { + return op.emitError() << "Expected exactly 1 update operand"; + } + + Value updates = op.getUpdates().front(); + ScatterDimensionNumbersAttr scatterDims = op.getScatterDimensionNumbers(); + ArrayRef updateWindowDims = scatterDims.getUpdateWindowDims(); + ArrayRef insertedWindowDims = scatterDims.getInsertedWindowDims(); + ArrayRef inputBatchingDims = scatterDims.getInputBatchingDims(); + + auto gatherDims = GatherDimensionNumbersAttr::get( + op.getContext(), + /*offsetDims=*/updateWindowDims, + /*collapsedSliceDims=*/insertedWindowDims, + /*operandBatchingDims=*/inputBatchingDims, + /*startIndicesBatchingDims=*/ + scatterDims.getScatterIndicesBatchingDims(), + /*startIndexMap=*/scatterDims.getScatterDimsToOperandDims(), + /*indexVectorDim=*/scatterDims.getIndexVectorDim()); + + auto operandType = cast(inDiffe.getType()); + auto updatesType = cast(updates.getType()); + + // Compute slice sizes + SmallVector sliceSizes(operandType.getRank(), 1); + DenseSet skippedDims; + skippedDims.insert(insertedWindowDims.begin(), insertedWindowDims.end()); + skippedDims.insert(inputBatchingDims.begin(), inputBatchingDims.end()); + + unsigned windowIdx = 0; + for (int64_t i = 0; i < operandType.getRank(); ++i) { + if (!skippedDims.contains(i)) { + sliceSizes[i] = updatesType.getShape()[updateWindowDims[windowIdx++]]; + } + } + + Value indices = gutils->getNewFromOriginal(op.getScatterIndices()); + auto gather = + builder.create(op.getLoc(), inDiffe, indices, gatherDims, + sliceSizes, op.getIndicesAreSorted()); + gutils->addToDiffe(updates, gather, builder); + return success(); + } + + SmallVector cacheValues(Operation *orig, + MGradientUtilsReverse *gutils) const { + return {}; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + struct SHLOConstantOpBatchInterface : public BatchOpInterface::ExternalModel { @@ -512,6 +574,7 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( SliceOp::attachInterface(*context); ReduceOp::attachInterface(*context); ConcatenateOp::attachInterface(*context); + ScatterOp::attachInterface(*context); ConstantOp::attachInterface(*context); }); } From a8fe6cfce7e276016bcb63e7955e70951165ad52 Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Tue, 10 Sep 2024 23:19:51 -0400 Subject: [PATCH 2/4] Add derivative rule for gather --- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 64 +++++++++++++++++-- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 88f174c4..d9e90c88 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -469,6 +469,63 @@ class AutoDiffConcatenateRev MGradientUtilsReverse *gutils) const {} }; +class AutoDiffGatherRev + : public ReverseAutoDiffOpInterface::ExternalModel { +public: + LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto op = cast(orig); + auto inDiffe = gutils->diffe(op, builder); + GatherDimensionNumbersAttr dims = op.getDimensionNumbers(); + + auto scatterDims = ScatterDimensionNumbersAttr::get( + op.getContext(), + /*updateWindowDims=*/dims.getOffsetDims(), + /*insertedWindowDims=*/dims.getCollapsedSliceDims(), + /*inputBatchingDims=*/dims.getOperandBatchingDims(), + /*scatterIndicesBatchingDims=*/dims.getStartIndicesBatchingDims(), + /*scatterDimsToOperandDims=*/dims.getStartIndexMap(), + /*indexVectorDim=*/dims.getIndexVectorDim()); + Value indices = gutils->getNewFromOriginal(op.getStartIndices()); + auto operandType = cast(op.getOperand().getType()); + + auto scatter = builder.create( + op.getLoc(), operandType, gutils->diffe(op.getOperand(), builder), + indices, inDiffe, scatterDims, op.getIndicesAreSorted()); + + // Add the update computation. This includes the plus-equalling into the + // gradient. + { + OpBuilder::InsertionGuard guard(builder); + RankedTensorType zeroDType = operandType.clone(/*shape=*/std::nullopt); + SmallVector blockArgTypes(2, zeroDType); + Region &bodyRegion = scatter.getUpdateComputation(); + Block::BlockArgListType blockArgs = + builder + .createBlock(&bodyRegion, bodyRegion.begin(), blockArgTypes, + SmallVector(2, op.getLoc())) + ->getArguments(); + Value sum = + builder.create(op.getLoc(), blockArgs[0], blockArgs[1]); + builder.create(op.getLoc(), sum); + } + + gutils->setDiffe(op.getOperand(), scatter.getResult(0), builder); + + return success(); + } + + SmallVector cacheValues(Operation *orig, + MGradientUtilsReverse *gutils) const { + return {}; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + class AutoDiffScatterRev : public ReverseAutoDiffOpInterface::ExternalModel { @@ -508,11 +565,9 @@ class AutoDiffScatterRev skippedDims.insert(inputBatchingDims.begin(), inputBatchingDims.end()); unsigned windowIdx = 0; - for (int64_t i = 0; i < operandType.getRank(); ++i) { - if (!skippedDims.contains(i)) { + for (int64_t i = 0; i < operandType.getRank(); ++i) + if (!skippedDims.contains(i)) sliceSizes[i] = updatesType.getShape()[updateWindowDims[windowIdx++]]; - } - } Value indices = gutils->getNewFromOriginal(op.getScatterIndices()); auto gather = @@ -574,6 +629,7 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( SliceOp::attachInterface(*context); ReduceOp::attachInterface(*context); ConcatenateOp::attachInterface(*context); + GatherOp::attachInterface(*context); ScatterOp::attachInterface(*context); ConstantOp::attachInterface(*context); }); From 47bf82383009987937fcec9c8f7018a6d91a74f2 Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Thu, 5 Sep 2024 16:46:31 -0400 Subject: [PATCH 3/4] Add derivative rule for stablehlo.abs for real numbers --- src/enzyme_ad/jax/Implementations/HLODerivatives.td | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 08862b81..57384505 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -513,6 +513,10 @@ def FftMultiplier : GlobalExpr; // Derivative rules + +def : HLODerivative<"AbsOp", (Op $x), [(CheckedMul (DiffeRet), (Sign $x))] // TODO: support complex numbers + >; + def : HLODerivative<"AddOp", (Op $x, $y), [ (DiffeRet), From a1cc40d7b3977f39dc3fa1003f9525da3008954b Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Wed, 18 Sep 2024 15:07:42 -0400 Subject: [PATCH 4/4] Add explicit check that scatter body is supported --- .../Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index d9e90c88..5faf0bcc 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -539,6 +539,12 @@ class AutoDiffScatterRev return op.emitError() << "Expected exactly 1 update operand"; } + Operation &innerOp = op.getUpdateComputation().front().front(); + if (!isa(innerOp)) { + return op.emitError() + << "Unsupported operation in scatter rev autodiff: " << *orig; + } + Value updates = op.getUpdates().front(); ScatterDimensionNumbersAttr scatterDims = op.getScatterDimensionNumbers(); ArrayRef updateWindowDims = scatterDims.getUpdateWindowDims();