diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 472eccddd49..a30695e1f57 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -17,7 +17,9 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -297,6 +299,109 @@ struct InsertValueOpInterfaceReverse MGradientUtilsReverse *gutils) const {} }; +struct MemcpyOpInterfaceReverse + : public ReverseAutoDiffOpInterface::ExternalModel { + + static Type inferElemType(LLVM::MemcpyOp cp) { + if (auto t = cp->getAttrOfType("enzyme.elem_type")) + return t.getValue(); + auto walk = [](Value p) -> Type { + for (Operation *user : p.getUsers()) { + if (auto ld = dyn_cast(user)) + if (isa(ld.getType())) + return ld.getType(); + if (auto st = dyn_cast(user)) + if (isa(st.getValue().getType())) + return st.getValue().getType(); + } + return nullptr; + }; + if (Type t = walk(cp.getDst())) + return t; + if (Type t = walk(cp.getSrc())) + return t; + return Float64Type::get(cp.getContext()); + } + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + auto cp = cast(op); + if (gutils->isConstantValue(cp.getDst())) + return {}; + bool srcActive = !gutils->isConstantValue(cp.getSrc()); + OpBuilder cb(gutils->getNewFromOriginal(op)); + SmallVector caches; + caches.push_back( + gutils->initAndPushCache(gutils->invertPointerM(cp.getDst(), cb), cb)); + caches.push_back(gutils->initAndPushCache( + srcActive ? gutils->invertPointerM(cp.getSrc(), cb) + : gutils->getNewFromOriginal(cp.getSrc()), + cb)); + caches.push_back( + gutils->initAndPushCache(gutils->getNewFromOriginal(cp.getLen()), cb)); + return caches; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} + + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto cp = cast(op); + if (gutils->isConstantValue(cp.getDst())) + return success(); + bool srcActive = !gutils->isConstantValue(cp.getSrc()); + + Value dDst = gutils->popCache(caches[0], builder); + Value dSrc = gutils->popCache(caches[1], builder); + Value len = gutils->popCache(caches[2], builder); + + Type elemTy = inferElemType(cp); + auto adt = dyn_cast(elemTy); + if (!adt || !elemTy.isIntOrFloat()) + return op->emitError() + << "memcpy reverse: unsupported element type " << elemTy + << " (annotate enzyme.elem_type or lower to scalar stores)"; + + Location loc = op->getLoc(); + unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8; + + // n_elements = len / sizeof(elemTy) + Value byteSz = + LLVM::ConstantOp::create(builder, loc, len.getType(), + builder.getIntegerAttr(len.getType(), bytes)); + Value nInt = LLVM::SDivOp::create(builder, loc, len, byteSz); + Value n = + arith::IndexCastOp::create(builder, loc, builder.getIndexType(), nInt); + + Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); + Value c1 = arith::ConstantIndexOp::create(builder, loc, 1); + Value zeroElem = adt.createNullValue(builder, loc); + Type ptrTy = cp.getDst().getType(); + + auto forOp = scf::ForOp::create(builder, loc, c0, n, c1); + OpBuilder body(forOp.getBody()->getTerminator()); + Value ivIdx = forOp.getInductionVar(); + Value iv = arith::IndexCastOp::create(body, loc, len.getType(), ivIdx); + + Value gDst = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dDst, + ArrayRef{iv}); + Value vDst = LLVM::LoadOp::create(body, loc, elemTy, gDst); + if (srcActive) { + Value gSrc = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dSrc, + ArrayRef{iv}); + Value vSrc = LLVM::LoadOp::create(body, loc, elemTy, gSrc); + Value sum = adt.createAddOp(body, loc, vSrc, vDst); + LLVM::StoreOp::create(body, loc, sum, gSrc); + } + LLVM::StoreOp::create(body, loc, zeroElem, gDst); + + return success(); + } +}; + std::optional findPtrSize(Value ptr) { if (auto allocOp = ptr.getDefiningOp()) return allocOp.getSize(); @@ -467,6 +572,7 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface( *context); LLVM::InsertValueOp::attachInterface( *context); + LLVM::MemcpyOp::attachInterface(*context); LLVM::UnreachableOp::template attachInterface< detail::NoopRevAutoDiffInterface>(*context); LLVM::LLVMFuncOp::attachInterface( diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 859b1717e29..106bc2a4caf 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/PassManager.h" @@ -53,7 +54,8 @@ struct DifferentiatePass registry.insert(); + mlir::scf::SCFDialect, mlir::linalg::LinalgDialect, + mlir::enzyme::EnzymeDialect>(); } static std::vector mode_from_fn(FunctionOpInterface fn, diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 4a3697cb74c..e52d2c6f9a1 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -21,6 +21,7 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #define DEBUG_TYPE "enzyme" diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 2cbc20453c5..9505541d344 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> { "complex::ComplexDialect", "cf::ControlFlowDialect", "tensor::TensorDialect", + "scf::SCFDialect", "enzyme::EnzymeDialect", ]; let options = [ @@ -85,6 +86,7 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> { "arith::ArithDialect", "complex::ComplexDialect", "cf::ControlFlowDialect", + "scf::SCFDialect", "enzyme::EnzymeDialect" ]; let options = [ diff --git a/enzyme/test/MLIR/ReverseMode/memcpy.mlir b/enzyme/test/MLIR/ReverseMode/memcpy.mlir new file mode 100644 index 00000000000..4b90efe66d8 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/memcpy.mlir @@ -0,0 +1,64 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) { + "llvm.intr.memcpy"(%dst, %src, %n) + <{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}> + : (!llvm.ptr, !llvm.ptr, i64) -> () + return +} + +func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr, + %src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) { + enzyme.autodiff @copy1(%dst, %ddst, %src, %dsrc, %n) { + activity = [#enzyme, + #enzyme, + #enzyme], + ret_activity = [] + } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> () + return +} + +func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) { + "llvm.intr.memcpy"(%dst, %src, %n) + <{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}], + isVolatile = false}> + : (!llvm.ptr, !llvm.ptr, i64) -> () + return +} + +func.func @dcopy2(%dst: !llvm.ptr, %ddst: !llvm.ptr, + %src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) { + enzyme.autodiff @copy2(%dst, %ddst, %src, %dsrc, %n) { + activity = [#enzyme, + #enzyme, + #enzyme], + ret_activity = [] + } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> () + return +} + +// CHECK-LABEL: func.func private @diffecopy1( +// Forward: the primal memcpy is preserved. +// CHECK: "llvm.intr.memcpy" +// Reverse: n / sizeof(f64) element-wise loop, d_src[i] += d_dst[i]; d_dst[i]=0. +// CHECK: %[[BYTES:.+]] = llvm.mlir.constant(8 : i64) : i64 +// CHECK: llvm.sdiv %{{.+}}, %[[BYTES]] : i64 +// CHECK: arith.index_cast +// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: scf.for +// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64 +// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64 +// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64 +// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64 +// CHECK: %[[SUM:.+]] = arith.addf +// CHECK: llvm.store %[[SUM]], %{{.+}} : f64, !llvm.ptr +// CHECK: llvm.store %[[ZERO]], %{{.+}} : f64, !llvm.ptr + +// CHECK-LABEL: func.func private @diffecopy2( +// CHECK: "llvm.intr.memcpy" +// CHECK: llvm.mlir.constant(8 : i64) : i64 +// CHECK: %[[ZERO2:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: scf.for +// CHECK: %[[SUM2:.+]] = arith.addf +// CHECK: llvm.store %[[SUM2]], %{{.+}} : f64, !llvm.ptr +// CHECK: llvm.store %[[ZERO2]], %{{.+}} : f64, !llvm.ptr