diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td
index 08862b81..57384505 100644
--- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td
@@ -513,6 +513,10 @@ def FftMultiplier : GlobalExpr;
 
 // Derivative rules
+
+def : HLODerivative<"AbsOp", (Op $x), [(CheckedMul (DiffeRet), (Sign $x))] // TODO: support complex numbers
+                  >;
+
 def : HLODerivative<"AddOp", (Op $x, $y),
                     [
                       (DiffeRet),
diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index 94ce0706..5faf0bcc 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -469,6 +469,129 @@ class AutoDiffConcatenateRev
                           MGradientUtilsReverse *gutils) const {}
 };
 
+class AutoDiffGatherRev
+    : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffGatherRev,
+                                                       GatherOp> {
+public:
+  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
+    auto op = cast<GatherOp>(orig);
+    auto inDiffe = gutils->diffe(op, builder);
+    GatherDimensionNumbersAttr dims = op.getDimensionNumbers();
+
+    // The adjoint of a gather is a scatter with the same index mapping:
+    // gather's offset/collapsed/batching dims become scatter's
+    // update-window/inserted-window/batching dims.
+    auto scatterDims = ScatterDimensionNumbersAttr::get(
+        op.getContext(),
+        /*updateWindowDims=*/dims.getOffsetDims(),
+        /*insertedWindowDims=*/dims.getCollapsedSliceDims(),
+        /*inputBatchingDims=*/dims.getOperandBatchingDims(),
+        /*scatterIndicesBatchingDims=*/dims.getStartIndicesBatchingDims(),
+        /*scatterDimsToOperandDims=*/dims.getStartIndexMap(),
+        /*indexVectorDim=*/dims.getIndexVectorDim());
+    Value indices = gutils->getNewFromOriginal(op.getStartIndices());
+    auto operandType = cast<RankedTensorType>(op.getOperand().getType());
+
+    auto scatter = builder.create<ScatterOp>(
+        op.getLoc(), operandType, gutils->diffe(op.getOperand(), builder),
+        indices, inDiffe, scatterDims, op.getIndicesAreSorted());
+
+    // Add the update computation. This also performs the += accumulation
+    // into the existing operand gradient.
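+    // NOTE: when several gather indices address the same operand element,
+    // its adjoint must receive the sum of all corresponding output
+    // cotangents; the scatter's `add` update computation below provides
+    // exactly that accumulation.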
+    {
+      OpBuilder::InsertionGuard guard(builder);
+      // Rank-0 tensor of the operand element type, as required for the
+      // scatter region block arguments.
+      RankedTensorType zeroDType = operandType.clone(/*shape=*/std::nullopt);
+      SmallVector<Type> blockArgTypes(2, zeroDType);
+      Region &bodyRegion = scatter.getUpdateComputation();
+      Block::BlockArgListType blockArgs =
+          builder
+              .createBlock(&bodyRegion, bodyRegion.begin(), blockArgTypes,
+                           SmallVector<Location>(2, op.getLoc()))
+              ->getArguments();
+      Value sum =
+          builder.create<AddOp>(op.getLoc(), blockArgs[0], blockArgs[1]);
+      builder.create<ReturnOp>(op.getLoc(), sum);
+    }
+
+    gutils->setDiffe(op.getOperand(), scatter.getResult(0), builder);
+
+    return success();
+  }
+
+  SmallVector<Value> cacheValues(Operation *orig,
+                                 MGradientUtilsReverse *gutils) const {
+    return {};
+  }
+
+  void createShadowValues(Operation *op, OpBuilder &builder,
+                          MGradientUtilsReverse *gutils) const {}
+};
+
+class AutoDiffScatterRev
+    : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffScatterRev,
+                                                       ScatterOp> {
+public:
+  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
+    auto op = cast<ScatterOp>(orig);
+    auto inDiffe = gutils->diffe(op->getResult(0), builder);
+    if (op.getUpdates().size() != 1) {
+      return op.emitError() << "Expected exactly 1 update operand";
+    }
+
+    // Only scatters whose update computation is a plain add are supported:
+    // for those, the adjoint w.r.t. the updates is a gather of the result
+    // cotangent at the scattered positions.
+    Operation &innerOp = op.getUpdateComputation().front().front();
+    if (!isa<AddOp>(innerOp)) {
+      return op.emitError()
+             << "Unsupported operation in scatter rev autodiff: " << *orig;
+    }
+
+    Value updates = op.getUpdates().front();
+    ScatterDimensionNumbersAttr scatterDims = op.getScatterDimensionNumbers();
+    ArrayRef<int64_t> updateWindowDims = scatterDims.getUpdateWindowDims();
+    ArrayRef<int64_t> insertedWindowDims = scatterDims.getInsertedWindowDims();
+    ArrayRef<int64_t> inputBatchingDims = scatterDims.getInputBatchingDims();
+
+    auto gatherDims = GatherDimensionNumbersAttr::get(
+        op.getContext(),
+        /*offsetDims=*/updateWindowDims,
+        /*collapsedSliceDims=*/insertedWindowDims,
+        /*operandBatchingDims=*/inputBatchingDims,
+        /*startIndicesBatchingDims=*/
+        scatterDims.getScatterIndicesBatchingDims(),
+        /*startIndexMap=*/scatterDims.getScatterDimsToOperandDims(),
+        /*indexVectorDim=*/scatterDims.getIndexVectorDim());
+
+    auto operandType = cast<RankedTensorType>(inDiffe.getType());
+    auto updatesType = cast<RankedTensorType>(updates.getType());
+
+    // Compute slice sizes: 1 along inserted-window and batching dims, the
+    // extent of the matching update-window dim everywhere else.
+    SmallVector<int64_t> sliceSizes(operandType.getRank(), 1);
+    DenseSet<int64_t> skippedDims;
+    skippedDims.insert(insertedWindowDims.begin(), insertedWindowDims.end());
+    skippedDims.insert(inputBatchingDims.begin(), inputBatchingDims.end());
+
+    unsigned windowIdx = 0;
+    for (int64_t i = 0; i < operandType.getRank(); ++i)
+      if (!skippedDims.contains(i))
+        sliceSizes[i] = updatesType.getShape()[updateWindowDims[windowIdx++]];
+
+    Value indices = gutils->getNewFromOriginal(op.getScatterIndices());
+    auto gather =
+        builder.create<GatherOp>(op.getLoc(), inDiffe, indices, gatherDims,
+                                 sliceSizes, op.getIndicesAreSorted());
+    gutils->addToDiffe(updates, gather, builder);
+    return success();
+  }
+
+  SmallVector<Value> cacheValues(Operation *orig,
+                                 MGradientUtilsReverse *gutils) const {
+    return {};
+  }
+
+  void createShadowValues(Operation *op, OpBuilder &builder,
+                          MGradientUtilsReverse *gutils) const {}
+};
+
 struct SHLOConstantOpBatchInterface
     : public BatchOpInterface::ExternalModel<SHLOConstantOpBatchInterface,
                                              ConstantOp> {
@@ -512,6 +635,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
     SliceOp::attachInterface<AutoDiffSliceRev>(*context);
     ReduceOp::attachInterface<AutoDiffReduceRev>(*context);
     ConcatenateOp::attachInterface<AutoDiffConcatenateRev>(*context);
+    GatherOp::attachInterface<AutoDiffGatherRev>(*context);
+    ScatterOp::attachInterface<AutoDiffScatterRev>(*context);
     ConstantOp::attachInterface<SHLOConstantOpBatchInterface>(*context);
   });
}
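
Note on the semantics (not part of the diff): the two interfaces above are the standard VJP pair for gather/scatter. The reverse of a gather scatter-adds the output cotangent into the operand gradient, and the reverse of an add-scatter gathers the result cotangent at the scattered positions. A minimal JAX sketch of the duplicate-index accumulation that the `add` update computation handles (`f`, `x`, and `ct` are illustrative names, not from this change):

```python
import jax
import jax.numpy as jnp

# Row gather with a duplicated index; this lowers to stablehlo.gather.
def f(x):
    return x[jnp.array([0, 2, 0])]

x = jnp.arange(12.0).reshape(4, 3)
ct = jnp.ones((3, 3))  # cotangent for the gathered result

# The VJP of a gather is a scatter-add: row 0 is gathered twice, so its
# gradient accumulates to 2.0 per element; row 2 gets 1.0; rows 1 and 3 stay 0.
_, f_vjp = jax.vjp(f, x)
print(f_vjp(ct)[0])
```

The new interfaces implement the same pairing directly on the StableHLO ops, which is also why `AutoDiffScatterRev` bails out when the update computation is anything other than a single `AddOp`.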