Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 254 additions & 1 deletion mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -139,8 +140,260 @@ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
return success();
}

//===----------------------------------------------------------------------===//
// ConvLinalgConverter: linalg.generic (conv) -> rock.conv
//===----------------------------------------------------------------------===//

namespace {
/// Attributes extracted from a linalg.generic op that was tagged as a
/// convolution, already converted into the form rock.conv expects.
struct ConvFields {
  // Which convolution variant (1-D/2-D/3-D NGCHW-family) was tagged.
  rock::LinalgConvType type;
  // Number of spatial dimensions of the original conv (1, 2, or 3).
  int64_t spatialDim;
  // Per-spatial-dim conv parameters; padding is interleaved
  // [d0_low, d0_high, d1_low, d1_high, ...] and Conv1D entries are already
  // expanded with identity values for the extra Conv2D dimension.
  ArrayAttr padding, stride, dilation;
  // Optional tuning configuration forwarded to rock.conv; may be null.
  StringAttr perfConfig;
};
} // namespace

/// Map a LinalgConvType onto its number of spatial dimensions.
static int64_t getSpatialDim(rock::LinalgConvType type) {
  if (type == rock::LinalgConvType::Conv1dNgchGkch)
    return 1;
  if (type == rock::LinalgConvType::Conv2dNgchwGkchw)
    return 2;
  if (type == rock::LinalgConvType::Conv3dNgchwdGkchwd)
    return 3;
  llvm_unreachable("unknown LinalgConvType");
}

/// Set filter_layout, input_layout, and output_layout on a rock.conv op.
/// Layouts match the linalg convention: GKC*, NGC*, NGK*. Spatial
/// dimensions are named "0", "1", ... with a per-tensor suffix.
static void setConvLayoutAttrs(OpBuilder &builder, rock::ConvOp cop,
                               int64_t spatialDim) {
  // Build a layout array from the fixed non-spatial names followed by the
  // suffixed spatial dimension names.
  auto makeLayout = [&](ArrayRef<StringRef> nonSpatial,
                        StringRef suffix) -> ArrayAttr {
    SmallVector<Attribute> names;
    for (StringRef dim : nonSpatial)
      names.push_back(builder.getStringAttr(dim));
    for (int64_t i = 0; i < spatialDim; ++i)
      names.push_back(builder.getStringAttr(Twine(i) + suffix));
    return builder.getArrayAttr(names);
  };
  cop->setAttr("filter_layout", makeLayout({"g", "k", "c"}, ""));
  cop->setAttr("input_layout", makeLayout({"ni", "gi", "ci"}, "i"));
  cop->setAttr("output_layout", makeLayout({"no", "go", "ko"}, "o"));
}

/// Remove the tensor.pad + tensor.expand_shape pattern emitted by
/// migraphx-to-linalg, replacing it with just tensor.expand_shape on the
/// unpadded source. rock.conv handles padding internally.
///
/// Expected IR structure:
///   %padded = tensor.pad %original ...
///   %expanded = tensor.expand_shape %padded ...
/// Replaced with:
///   %expanded = tensor.expand_shape %original ...
///
/// Returns `in` unchanged when `padding` is all-zero, the rebuilt
/// expand_shape result otherwise, or failure (with an error on `op`) when
/// the expected IR structure is not found.
static FailureOr<Value>
removePaddingFromInput(ConversionPatternRewriter &rewriter,
                       linalg::GenericOp op, Value in, ArrayAttr padding) {
  // Nothing to strip when every pad amount is zero.
  bool hasPadding = llvm::any_of(padding.getValue(), [](Attribute attr) {
    return cast<IntegerAttr>(attr).getInt() != 0;
  });
  if (!hasPadding)
    return in;

  auto expanded = in.getDefiningOp<tensor::ExpandShapeOp>();
  if (!expanded) {
    op.emitError("unexpected padding code structure");
    return failure();
  }
  // The pad must feed only this expand_shape so it is safe to erase below.
  auto padded = expanded->getOperand(0).getDefiningOp<tensor::PadOp>();
  if (!padded || !padded->hasOneUse()) {
    op.emitError("unexpected padding code structure");
    return failure();
  }

  SmallVector<int64_t, 6> resultShape(expanded.getResultType().getShape());
  auto lowPad = padded.getStaticLow();
  auto highPad = padded.getStaticHigh();
  int64_t numPadDims = lowPad.size();
  int64_t numExpandedDims = resultShape.size();

  // Padding is defined in pre-expand space. The spatial dims are at the
  // tail of both tensors (expand_shape only splits an earlier dim), so
  // align from the end while subtracting the pad amounts back out.
  for (int64_t i = numPadDims - 1, j = numExpandedDims - 1; i >= 0 && j >= 0;
       --i, --j)
    resultShape[j] -= (lowPad[i] + highPad[i]);

  RankedTensorType newResultType = RankedTensorType::get(
      resultShape, padded.getResultType().getElementType());
  Value result = tensor::ExpandShapeOp::create(
      rewriter, expanded.getLoc(), newResultType, padded.getOperand(0),
      expanded.getReassociationIndices());
  rewriter.replaceOp(expanded, result);
  rewriter.eraseOp(padded);
  return result;
}

namespace {
/// Converts linalg.generic ops carrying the "conv_op" attribute into
/// rock.conv, reconstructing the layout, padding, stride, and dilation
/// attributes rock expects.
struct ConvLinalgConverter final
    : public OpConversionPattern<linalg::GenericOp> {
  using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
  using OpConversionPattern<linalg::GenericOp>::getTypeConverter;
  using OpAdaptor = typename OpConversionPattern<linalg::GenericOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  /// Checks whether `op` is a tagged convolution and, if so, extracts its
  /// conv attributes; emits an error and fails on malformed attributes.
  FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter,
                               linalg::GenericOp op) const;
};
} // namespace

/// Determine whether `op` is a linalg.generic tagged as a convolution (via
/// the "conv_op" attribute) and, if so, extract its parameters in the form
/// rock.conv expects.
///
/// Conv1D is expanded into Conv2D: the stride, dilation, and padding arrays
/// gain identity entries (stride=1, dilation=1, pad=0) for the extra unit
/// spatial dimension.
FailureOr<ConvFields>
ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter,
                            linalg::GenericOp op) const {
  auto name = op->getAttrOfType<rock::LinalgConvTypeAttr>("conv_op");
  if (!name)
    return failure();
  rock::LinalgConvType convType = name.getValue();
  int64_t spatialDim = getSpatialDim(convType);
  // Conv1D is broadcast into Conv2D, so validate attribute sizes against the
  // effective (post-expansion) spatial rank, which carries one extra
  // stride/dilation entry for the expanded dimension.
  int64_t effectiveDim = (spatialDim == 1) ? spatialDim + 1 : spatialDim;

  // Convert an integer ArrayAttr into an index ArrayAttr, appending
  // `dimOneDefaults` when a Conv1D is being expanded into Conv2D. Returns a
  // null attribute when `arr` is missing or not an array.
  auto convertToArrayAttr =
      [&](Attribute arr, ArrayRef<int64_t> dimOneDefaults = {}) -> ArrayAttr {
    if (!arr || !isa<ArrayAttr>(arr))
      return ArrayAttr();

    SmallVector<int64_t, 4> values;
    llvm::transform(
        cast<ArrayAttr>(arr).getValue(), std::back_inserter(values),
        [](Attribute val) { return cast<IntegerAttr>(val).getInt(); });
    // Conv1D is expanded into Conv2D: append identity defaults for the
    // extra spatial dimension (stride=1, dilation=1, pad=0).
    if (spatialDim == 1)
      values.append(dimOneDefaults.begin(), dimOneDefaults.end());
    return rewriter.getIndexArrayAttr(values);
  };

  auto dilation =
      convertToArrayAttr(op->getAttr("dilation"), /*dimOneDefaults=*/{1});
  auto stride =
      convertToArrayAttr(op->getAttr("stride"), /*dimOneDefaults=*/{1});
  if (!dilation || !stride || (int64_t)dilation.size() != effectiveDim ||
      (int64_t)stride.size() != effectiveDim) {
    op.emitError("invalid dilation or stride");
    return failure();
  }

  // Input format: [dim0_low, dim1_low, ..., dim0_high, dim1_high, ...]
  // Rock format:  [dim0_low, dim0_high, dim1_low, dim1_high, ...]
  auto originalPadding = convertToArrayAttr(op->getAttr("pad"));
  if (!originalPadding) {
    op.emitError("no padding found");
    return failure();
  }
  int64_t numSpatial = originalPadding.size() / 2;
  SmallVector<Attribute, 8> interleavedPad;
  for (int64_t i = 0; i < numSpatial; ++i) {
    interleavedPad.push_back(originalPadding[i]);
    interleavedPad.push_back(originalPadding[numSpatial + i]);
  }
  // Conv1D is expanded into Conv2D: zero padding for the unit dimension.
  if (spatialDim == 1) {
    interleavedPad.push_back(rewriter.getIndexAttr(0));
    interleavedPad.push_back(rewriter.getIndexAttr(0));
  }
  auto padding = rewriter.getArrayAttr(interleavedPad);
  // Note that Conv1D has already been expanded into Conv2D at this point.
  if (effectiveDim * 2 != (int64_t)padding.size()) {
    op.emitError("invalid number of padding entries");
    return failure();
  }

  // perf_config is optional; a null StringAttr is handled by the caller.
  StringAttr perfConfig = op->getAttrOfType<StringAttr>("perf_config");
  return ConvFields{convType, spatialDim, padding,
                    stride,   dilation,   perfConfig};
}

/// Lower a linalg.generic tagged as a convolution to rock.conv.
///
/// Conv1D is expanded into Conv2D by unmerging the single spatial dimension
/// into (spatial, 1) on the filter and input; the rock.conv result is merged
/// back down afterwards so the replacement value matches the original
/// linalg result type.
LogicalResult ConvLinalgConverter::matchAndRewrite(
    linalg::GenericOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  FailureOr<ConvFields> maybeConv = isConv(rewriter, op);
  if (failed(maybeConv))
    return failure();

  ConvFields conv = *maybeConv;
  Location loc = op.getLoc();

  // rock.conv performs padding itself, so strip any tensor.pad feeding the
  // input operand.
  auto maybeInput =
      removePaddingFromInput(rewriter, op, op.getOperand(0), conv.padding);
  if (failed(maybeInput))
    return failure();

  Value input = *maybeInput;
  Value filter = op.getOperand(1);

  // Conv1D is expanded into Conv2D: unmerge the single spatial dim
  // into (spatial, W=1) for filter and input.
  int64_t effectiveSpatialDim = conv.spatialDim;
  if (conv.spatialDim == 1) {
    effectiveSpatialDim = 2;
    auto filterShape = cast<RankedTensorType>(filter.getType()).getShape();
    rock::BottomUpTMBuilder builder(rewriter, {"g", "k", "c", "0"}, filterShape,
                                    loc);
    builder.passThrough({"gf", "kf", "cf"}, {0, 1, 2}, {"g", "k", "c"});
    builder.unmerge({"0f", "1f"}, {3, 4}, "0", {filterShape[3], 1});
    filter = rock::TransformOp::create(rewriter, loc, filter, builder.get());

    auto inputShape = cast<RankedTensorType>(input.getType()).getShape();
    rock::BottomUpTMBuilder b(rewriter, {"n", "g", "c", "0"}, inputShape, loc);
    b.passThrough({"nu", "gu", "cu"}, {0, 1, 2}, {"n", "g", "c"});
    b.unmerge({"0u", "1u"}, {3, 4}, "0", {inputShape[3], 1});
    input = rock::TransformOp::create(rewriter, loc, input, b.get());
  }

  // For expanded Conv1D the rock.conv result carries a trailing unit
  // dimension, which is merged away again below.
  RankedTensorType linalgResultType =
      cast<RankedTensorType>(op.getResult(0).getType());
  SmallVector<int64_t> rockShape(linalgResultType.getShape());
  if (conv.spatialDim == 1)
    rockShape.push_back(1);
  RankedTensorType rockResultType =
      RankedTensorType::get(rockShape, linalgResultType.getElementType());
  Value output =
      bufferization::AllocTensorOp::create(rewriter, loc, rockResultType, {});
  auto cop = rock::ConvOp::create(rewriter, loc, rockResultType, filter, input,
                                  output, /*features=*/nullptr,
                                  /*blockSize=*/nullptr, /*gridSize=*/nullptr,
                                  conv.padding, conv.stride, conv.dilation,
                                  /*params=*/nullptr);
  // TODO: add splitk
  if (conv.perfConfig)
    cop->setAttr("perf_config", conv.perfConfig);
  setConvLayoutAttrs(rewriter, cop, effectiveSpatialDim);

  Value result = cop.getResult();
  if (conv.spatialDim == 1) {
    auto shape = cast<RankedTensorType>(result.getType()).getShape();
    rock::BottomUpTMBuilder b(rewriter, {"n", "g", "k", "0", "1"}, shape, loc);
    b.passThrough({"no", "go", "ko"}, {0, 1, 2}, {"n", "g", "k"});
    b.merge("0o", 3, {"0", "1"});
    result = rock::TransformOp::create(rewriter, loc, result, b.get());
  }

  rewriter.replaceOp(op, result);
  return success();
}

/// Register the linalg-to-rock conversion patterns: the matmul converters
/// and the tagged-convolution converter.
void mlir::rock::populateLinalgToRockConversionPattern(
    RewritePatternSet &pattern, MLIRContext *context) {
  pattern.add<MatmulConverter<linalg::BatchMatmulOp>,
              MatmulConverter<linalg::MatmulOp>, ConvLinalgConverter>(context);
}
11 changes: 9 additions & 2 deletions mlir/lib/Conversion/LinalgToRock/LinalgToRockPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,15 @@ static void populateLinalgToRockDialectConversion(ConversionTarget &target) {
if (!linalgOp) {
return std::nullopt;
}
return linalg::isElementwise(linalgOp) || isa<linalg::GenericOp>(op) ||
isa<linalg::YieldOp>(op);

// Convolution has attributes.
linalg::GenericOp castedOp = dyn_cast<linalg::GenericOp>(op);
if (castedOp && castedOp->hasAttr("conv_op")) {
return false;
}
Comment on lines +52 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory this could lead to non-conv linalg.generics to be accidentally picked up here. Could we look for generics that were marked with the conv_op attribute instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could work for now.

Originally, my intention was to catch all linalg.generic that has a reduction iterator since rock-regularize doesn't support it.

for (utils::IteratorType iterType : lgop.getIteratorTypesArray()) {
if (!linalg::isParallelIterator(iterType))
return lgop.emitError("Only fully parallel supported");
}


return linalg::isElementwise(linalgOp) || isa<linalg::YieldOp>(op) ||
castedOp;
});
}

Expand Down
Loading