Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

Expand Down Expand Up @@ -297,6 +299,109 @@ struct InsertValueOpInterfaceReverse
MGradientUtilsReverse *gutils) const {}
};

struct MemcpyOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<MemcpyOpInterfaceReverse,
LLVM::MemcpyOp> {

static Type inferElemType(LLVM::MemcpyOp cp) {
if (auto t = cp->getAttrOfType<TypeAttr>("enzyme.elem_type"))
return t.getValue();
auto walk = [](Value p) -> Type {
for (Operation *user : p.getUsers()) {
if (auto ld = dyn_cast<LLVM::LoadOp>(user))
if (isa<AutoDiffTypeInterface>(ld.getType()))
return ld.getType();
if (auto st = dyn_cast<LLVM::StoreOp>(user))
if (isa<AutoDiffTypeInterface>(st.getValue().getType()))
return st.getValue().getType();
}
return nullptr;
};
if (Type t = walk(cp.getDst()))
return t;
if (Type t = walk(cp.getSrc()))
return t;
return Float64Type::get(cp.getContext());
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
auto cp = cast<LLVM::MemcpyOp>(op);
if (gutils->isConstantValue(cp.getDst()))
return {};
bool srcActive = !gutils->isConstantValue(cp.getSrc());
OpBuilder cb(gutils->getNewFromOriginal(op));
SmallVector<Value> caches;
caches.push_back(
gutils->initAndPushCache(gutils->invertPointerM(cp.getDst(), cb), cb));
caches.push_back(gutils->initAndPushCache(
srcActive ? gutils->invertPointerM(cp.getSrc(), cb)
: gutils->getNewFromOriginal(cp.getSrc()),
cb));
caches.push_back(
gutils->initAndPushCache(gutils->getNewFromOriginal(cp.getLen()), cb));
return caches;
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}

LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto cp = cast<LLVM::MemcpyOp>(op);
if (gutils->isConstantValue(cp.getDst()))
return success();
bool srcActive = !gutils->isConstantValue(cp.getSrc());

Value dDst = gutils->popCache(caches[0], builder);
Value dSrc = gutils->popCache(caches[1], builder);
Value len = gutils->popCache(caches[2], builder);

Type elemTy = inferElemType(cp);
auto adt = dyn_cast<AutoDiffTypeInterface>(elemTy);
if (!adt || !elemTy.isIntOrFloat())
return op->emitError()
<< "memcpy reverse: unsupported element type " << elemTy
<< " (annotate enzyme.elem_type or lower to scalar stores)";

Location loc = op->getLoc();
unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8;

// n_elements = len / sizeof(elemTy)
Value byteSz =
LLVM::ConstantOp::create(builder, loc, len.getType(),
builder.getIntegerAttr(len.getType(), bytes));
Value nInt = LLVM::SDivOp::create(builder, loc, len, byteSz);
Value n =
arith::IndexCastOp::create(builder, loc, builder.getIndexType(), nInt);

Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
Value c1 = arith::ConstantIndexOp::create(builder, loc, 1);
Value zeroElem = adt.createNullValue(builder, loc);
Type ptrTy = cp.getDst().getType();

auto forOp = scf::ForOp::create(builder, loc, c0, n, c1);
OpBuilder body(forOp.getBody()->getTerminator());
Value ivIdx = forOp.getInductionVar();
Value iv = arith::IndexCastOp::create(body, loc, len.getType(), ivIdx);

Value gDst = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dDst,
ArrayRef<LLVM::GEPArg>{iv});
Value vDst = LLVM::LoadOp::create(body, loc, elemTy, gDst);
if (srcActive) {
Value gSrc = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dSrc,
ArrayRef<LLVM::GEPArg>{iv});
Value vSrc = LLVM::LoadOp::create(body, loc, elemTy, gSrc);
Value sum = adt.createAddOp(body, loc, vSrc, vDst);
LLVM::StoreOp::create(body, loc, sum, gSrc);
}
LLVM::StoreOp::create(body, loc, zeroElem, gDst);

return success();
}
};

std::optional<Value> findPtrSize(Value ptr) {
if (auto allocOp = ptr.getDefiningOp<llvm_ext::AllocOp>())
return allocOp.getSize();
Expand Down Expand Up @@ -467,6 +572,7 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
*context);
LLVM::InsertValueOp::attachInterface<InsertValueOpInterfaceReverse>(
*context);
LLVM::MemcpyOp::attachInterface<MemcpyOpInterfaceReverse>(*context);
LLVM::UnreachableOp::template attachInterface<
detail::NoopRevAutoDiffInterface<LLVM::UnreachableOp>>(*context);
LLVM::LLVMFuncOp::attachInterface<AutoDiffLLVMFuncOpFunctionInterface>(
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -53,7 +54,8 @@ struct DifferentiatePass

registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
mlir::linalg::LinalgDialect, mlir::enzyme::EnzymeDialect>();
mlir::scf::SCFDialect, mlir::linalg::LinalgDialect,
mlir::enzyme::EnzymeDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Interfaces/FunctionInterfaces.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#define DEBUG_TYPE "enzyme"

Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> {
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"tensor::TensorDialect",
"scf::SCFDialect",
"enzyme::EnzymeDialect",
];
let options = [
Expand Down Expand Up @@ -85,6 +86,7 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> {
"arith::ArithDialect",
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"scf::SCFDialect",
"enzyme::EnzymeDialect"
];
let options = [
Expand Down
64 changes: 64 additions & 0 deletions enzyme/test/MLIR/ReverseMode/memcpy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: %eopt --enzyme %s | FileCheck %s

func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
"llvm.intr.memcpy"(%dst, %src, %n)
<{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}>
: (!llvm.ptr, !llvm.ptr, i64) -> ()
return
}

func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr,
%src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
enzyme.autodiff @copy1(%dst, %ddst, %src, %dsrc, %n) {
activity = [#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_const>],
ret_activity = []
} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
return
}

func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
"llvm.intr.memcpy"(%dst, %src, %n)
<{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}],
isVolatile = false}>
: (!llvm.ptr, !llvm.ptr, i64) -> ()
return
}

func.func @dcopy2(%dst: !llvm.ptr, %ddst: !llvm.ptr,
%src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
enzyme.autodiff @copy2(%dst, %ddst, %src, %dsrc, %n) {
activity = [#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_const>],
ret_activity = []
} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
return
}

// CHECK-LABEL: func.func private @diffecopy1(
// Forward: the primal memcpy is preserved.
// CHECK: "llvm.intr.memcpy"
// Reverse: n / sizeof(f64) element-wise loop, d_src[i] += d_dst[i]; d_dst[i]=0.
// CHECK: %[[BYTES:.+]] = llvm.mlir.constant(8 : i64) : i64
// CHECK: llvm.sdiv %{{.+}}, %[[BYTES]] : i64
// CHECK: arith.index_cast
// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: scf.for
// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
// CHECK: %[[SUM:.+]] = arith.addf
// CHECK: llvm.store %[[SUM]], %{{.+}} : f64, !llvm.ptr
// CHECK: llvm.store %[[ZERO]], %{{.+}} : f64, !llvm.ptr

// CHECK-LABEL: func.func private @diffecopy2(
// CHECK: "llvm.intr.memcpy"
// CHECK: llvm.mlir.constant(8 : i64) : i64
// CHECK: %[[ZERO2:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: scf.for
// CHECK: %[[SUM2:.+]] = arith.addf
// CHECK: llvm.store %[[SUM2]], %{{.+}} : f64, !llvm.ptr
// CHECK: llvm.store %[[ZERO2]], %{{.+}} : f64, !llvm.ptr
Loading