-
Notifications
You must be signed in to change notification settings - Fork 55
[AIROCMLIR-445] Lower linalg.generic convolution into rock
#2252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pr-template-migraphx-to-linalg-conv
Are you sure you want to change the base?
Changes from all commits
a8fefdd
a6c0222
6c7dabc
7abde9b
f58acdd
7de89e5
ebb7a41
6225bdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |||||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||||||
| #include "mlir/Dialect/Linalg/IR/Linalg.h" | ||||||
| #include "mlir/Dialect/Rock/IR/Rock.h" | ||||||
| #include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" | ||||||
| #include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||||
| #include "mlir/IR/AffineExpr.h" | ||||||
| #include "mlir/IR/PatternMatch.h" | ||||||
|
|
@@ -139,8 +140,260 @@ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite( | |||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| //===----------------------------------------------------------------------===// | ||||||
| // ConvLinalgConverter: linalg.generic (conv) -> rock.conv | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| namespace { | ||||||
| struct ConvFields { | ||||||
| rock::LinalgConvType type; | ||||||
| int64_t spatialDim; | ||||||
| ArrayAttr padding, stride, dilation; | ||||||
| StringAttr perfConfig; | ||||||
| }; | ||||||
| } // namespace | ||||||
|
|
||||||
| static int64_t getSpatialDim(rock::LinalgConvType type) { | ||||||
| switch (type) { | ||||||
| case rock::LinalgConvType::Conv1dNgchGkch: | ||||||
| return 1; | ||||||
| case rock::LinalgConvType::Conv2dNgchwGkchw: | ||||||
| return 2; | ||||||
| case rock::LinalgConvType::Conv3dNgchwdGkchwd: | ||||||
| return 3; | ||||||
| } | ||||||
| llvm_unreachable("unknown LinalgConvType"); | ||||||
| } | ||||||
|
|
||||||
| /// Set filter_layout, input_layout, and output_layout on a rock.conv op. | ||||||
| /// Layouts match the linalg convention: GKC*, NGC*, NGK*. | ||||||
| static void setConvLayoutAttrs(OpBuilder &builder, rock::ConvOp cop, | ||||||
| int64_t spatialDim) { | ||||||
| auto *ctx = builder.getContext(); | ||||||
| auto setLayout = [&](StringRef attrName, ArrayRef<StringRef> prefix, | ||||||
| StringRef suffix) { | ||||||
| SmallVector<Attribute> layout; | ||||||
| for (StringRef dim : prefix) | ||||||
| layout.push_back(StringAttr::get(ctx, dim)); | ||||||
| for (int64_t i = 0; i < spatialDim; ++i) | ||||||
| layout.push_back(StringAttr::get(ctx, Twine(i) + suffix)); | ||||||
| cop->setAttr(attrName, builder.getArrayAttr(layout)); | ||||||
| }; | ||||||
| setLayout("filter_layout", {"g", "k", "c"}, ""); | ||||||
| setLayout("input_layout", {"ni", "gi", "ci"}, "i"); | ||||||
| setLayout("output_layout", {"no", "go", "ko"}, "o"); | ||||||
| } | ||||||
|
|
||||||
| /// Remove the tensor.pad + tensor.expand_shape pattern emitted by | ||||||
| /// migraphx-to-linalg, replacing it with just tensor.expand_shape on the | ||||||
| /// unpadded source. rock.conv handles padding internally. | ||||||
| /// | ||||||
| /// Expected IR structure: | ||||||
| /// %padded = tensor.pad %original ... | ||||||
| /// %expanded = tensor.expand_shape %padded ... | ||||||
| /// Replaced with: | ||||||
| /// %expanded = tensor.expand_shape %original ... | ||||||
| static FailureOr<Value> | ||||||
| removePaddingFromInput(ConversionPatternRewriter &rewriter, | ||||||
| linalg::GenericOp op, Value in, ArrayAttr padding) { | ||||||
| bool hasPadding = llvm::any_of(padding.getValue(), [](Attribute attr) { | ||||||
| return cast<IntegerAttr>(attr).getInt() != 0; | ||||||
| }); | ||||||
| if (!hasPadding) | ||||||
| return in; | ||||||
|
|
||||||
| auto expanded = in.getDefiningOp<tensor::ExpandShapeOp>(); | ||||||
| if (!expanded) { | ||||||
| op.emitError("unexpected padding code structure"); | ||||||
| return failure(); | ||||||
| } | ||||||
| auto padded = expanded->getOperand(0).getDefiningOp<tensor::PadOp>(); | ||||||
| if (!padded || !padded->hasOneUse()) { | ||||||
| op.emitError("unexpected padding code structure"); | ||||||
| return failure(); | ||||||
| } | ||||||
|
|
||||||
| SmallVector<int64_t, 6> resultShape(expanded.getResultType().getShape()); | ||||||
| auto lowPad = padded.getStaticLow(); | ||||||
| auto highPad = padded.getStaticHigh(); | ||||||
| int64_t numPadDims = lowPad.size(); | ||||||
| int64_t numExpandedDims = resultShape.size(); | ||||||
|
|
||||||
| // Padding is defined in pre-expand space. The spatial dims are at the | ||||||
| // tail of both tensors (expand_shape only splits an earlier dim), so | ||||||
| // align from the end. | ||||||
| for (int64_t i = numPadDims - 1, j = numExpandedDims - 1; i >= 0 && j >= 0; | ||||||
| --i, --j) { | ||||||
| resultShape[j] -= (lowPad[i] + highPad[i]); | ||||||
| } | ||||||
|
|
||||||
| RankedTensorType newResultType = RankedTensorType::get( | ||||||
| resultShape, padded.getResultType().getElementType()); | ||||||
| Value result = tensor::ExpandShapeOp::create( | ||||||
| rewriter, expanded.getLoc(), newResultType, padded.getOperand(0), | ||||||
| expanded.getReassociationIndices()); | ||||||
| rewriter.replaceOp(expanded, result); | ||||||
| rewriter.eraseOp(padded); | ||||||
| return result; | ||||||
| } | ||||||
|
|
||||||
| namespace { | ||||||
| struct ConvLinalgConverter final | ||||||
| : public OpConversionPattern<linalg::GenericOp> { | ||||||
| using OpConversionPattern<linalg::GenericOp>::OpConversionPattern; | ||||||
| using OpConversionPattern<linalg::GenericOp>::getTypeConverter; | ||||||
| using OpAdaptor = typename OpConversionPattern<linalg::GenericOp>::OpAdaptor; | ||||||
|
|
||||||
| LogicalResult | ||||||
| matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor, | ||||||
| ConversionPatternRewriter &rewriter) const override; | ||||||
|
|
||||||
| private: | ||||||
| FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter, | ||||||
| linalg::GenericOp op) const; | ||||||
| }; | ||||||
| } // namespace | ||||||
|
|
||||||
| FailureOr<ConvFields> | ||||||
| ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter, | ||||||
| linalg::GenericOp op) const { | ||||||
| auto name = op->getAttrOfType<rock::LinalgConvTypeAttr>("conv_op"); | ||||||
| if (!name) | ||||||
| return failure(); | ||||||
| rock::LinalgConvType convType = name.getValue(); | ||||||
| int64_t spatialDim = getSpatialDim(convType); | ||||||
| // Conv1D is broadcasted into Conv2D. To check for error, we | ||||||
| // use effectiveDim instead because it one more stride/dilation | ||||||
| // in the expanded dimension | ||||||
| int64_t effectiveDim = (spatialDim == 1) ? spatialDim + 1 : spatialDim; | ||||||
|
|
||||||
| auto convertToArrayAttr = | ||||||
| [&](Attribute arr, ArrayRef<int64_t> dimOneDefaults = {}) -> ArrayAttr { | ||||||
| if(!arr || !isa<ArrayAttr>(arr)){ | ||||||
| return ArrayAttr {}; | ||||||
| } | ||||||
|
|
||||||
| SmallVector<int64_t, 4> values; | ||||||
| llvm::transform( | ||||||
| cast<ArrayAttr>(arr).getValue(), std::back_inserter(values), | ||||||
| [](Attribute val) { return cast<IntegerAttr>(val).getInt(); }); | ||||||
Mr-Anyone marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| // Conv1D is expanded into Conv2D: append identity defaults for the | ||||||
| // extra spatial dimension (stride=1, dilation=1, pad=0). | ||||||
| if (spatialDim == 1) | ||||||
| values.insert(values.end(), dimOneDefaults.begin(), dimOneDefaults.end()); | ||||||
| return rewriter.getIndexArrayAttr(values); | ||||||
| }; | ||||||
|
|
||||||
| auto dilation = | ||||||
| convertToArrayAttr(op->getAttr("dilation"), /*dimOneDefaults=*/{1}); | ||||||
| auto stride = | ||||||
| convertToArrayAttr(op->getAttr("stride"), /*dimOneDefaults=*/{1}); | ||||||
| if (!dilation || !stride || (int64_t)dilation.size() != effectiveDim || (int64_t)stride.size() != effectiveDim){ | ||||||
| op.emitError("invalid dilation or stride"); | ||||||
| return failure(); | ||||||
| } | ||||||
|
|
||||||
| // Input format: [dim0_low, dim1_low, ..., dim0_high, dim1_high, ...] | ||||||
| // Rock format: [dim0_low, dim0_high, dim1_low, dim1_high, ...] | ||||||
| auto originalPadding = convertToArrayAttr(op->getAttr("pad")); | ||||||
| if(!originalPadding){ | ||||||
| op.emitError("no padding found"); | ||||||
| return failure(); | ||||||
| } | ||||||
| int64_t numSpatial = originalPadding.size() / 2; | ||||||
| SmallVector<Attribute, 8> interleavedPad; | ||||||
| for (int64_t i = 0; i < numSpatial; ++i) { | ||||||
| interleavedPad.push_back(originalPadding[i]); | ||||||
| interleavedPad.push_back(originalPadding[numSpatial + i]); | ||||||
| } | ||||||
| // Conv1D is expanded into Conv2D | ||||||
| if (spatialDim == 1) { | ||||||
| interleavedPad.push_back(rewriter.getIndexAttr(0)); | ||||||
| interleavedPad.push_back(rewriter.getIndexAttr(0)); | ||||||
| } | ||||||
| auto padding = rewriter.getArrayAttr(interleavedPad); | ||||||
| // note that Conv1D is expanded into Conv2D | ||||||
| if(effectiveDim*2 != (int64_t)padding.size()){ | ||||||
| op.emitError("invalid number of padding"); | ||||||
| return failure(); | ||||||
| } | ||||||
|
|
||||||
| StringAttr perfConfig = op->getAttrOfType<StringAttr>("perf_config"); | ||||||
| return ConvFields{convType, spatialDim, padding, | ||||||
| stride, dilation, perfConfig}; | ||||||
| } | ||||||
|
|
||||||
| LogicalResult ConvLinalgConverter::matchAndRewrite( | ||||||
| linalg::GenericOp op, OpAdaptor adaptor, | ||||||
| ConversionPatternRewriter &rewriter) const { | ||||||
| FailureOr<ConvFields> maybeConv = isConv(rewriter, op); | ||||||
| if (failed(maybeConv)) | ||||||
| return failure(); | ||||||
|
|
||||||
| ConvFields conv = *maybeConv; | ||||||
| Location loc = op.getLoc(); | ||||||
|
|
||||||
| auto maybeInput = | ||||||
| removePaddingFromInput(rewriter, op, op.getOperand(0), conv.padding); | ||||||
| if (failed(maybeInput)) | ||||||
| return failure(); | ||||||
|
|
||||||
| Value input = *maybeInput; | ||||||
| Value filter = op.getOperand(1); | ||||||
|
|
||||||
| // Conv1D is expanded into Conv2D: unmerge the single spatial dim | ||||||
| // into (spatial, W=1) for filter and input. | ||||||
| int64_t effectiveSpatialDim = conv.spatialDim; | ||||||
| if (conv.spatialDim == 1) { | ||||||
| effectiveSpatialDim = 2; | ||||||
| auto filterShape = cast<RankedTensorType>(filter.getType()).getShape(); | ||||||
| rock::BottomUpTMBuilder builder(rewriter, {"g", "k", "c", "0"}, filterShape, | ||||||
| loc); | ||||||
| builder.passThrough({"gf", "kf", "cf"}, {0, 1, 2}, {"g", "k", "c"}); | ||||||
| builder.unmerge({"0f", "1f"}, {3, 4}, "0", {filterShape[3], 1}); | ||||||
| filter = rock::TransformOp::create(rewriter, loc, filter, builder.get()); | ||||||
|
|
||||||
| auto inputShape = cast<RankedTensorType>(input.getType()).getShape(); | ||||||
| rock::BottomUpTMBuilder b(rewriter, {"n", "g", "c", "0"}, inputShape, loc); | ||||||
| b.passThrough({"nu", "gu", "cu"}, {0, 1, 2}, {"n", "g", "c"}); | ||||||
| b.unmerge({"0u", "1u"}, {3, 4}, "0", {inputShape[3], 1}); | ||||||
| input = rock::TransformOp::create(rewriter, loc, input, b.get()); | ||||||
| } | ||||||
|
|
||||||
| RankedTensorType linalgResultType = | ||||||
| cast<RankedTensorType>(op.getResult(0).getType()); | ||||||
| SmallVector<int64_t> rockShape(linalgResultType.getShape()); | ||||||
| if (conv.spatialDim == 1) | ||||||
| rockShape.push_back(1); | ||||||
| RankedTensorType rockResultType = | ||||||
| RankedTensorType::get(rockShape, linalgResultType.getElementType()); | ||||||
| Value output = | ||||||
| bufferization::AllocTensorOp::create(rewriter, loc, rockResultType, {}); | ||||||
| auto cop = rock::ConvOp::create(rewriter, loc, rockResultType, filter, input, | ||||||
| output, /*features=*/nullptr, | ||||||
| /*blockSize=*/nullptr, /*gridSize=*/nullptr, | ||||||
| conv.padding, conv.stride, conv.dilation, | ||||||
| /*params=*/nullptr); | ||||||
| // TODO: add splitk | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: why is this a todo ? I am not sure if splitk requires any special treatment during conversion
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that we are doing a tree combing to set the mhal.read_access and original_func: see here: rocMLIR/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp Lines 637 to 638 in d3a76d7
|
||||||
| if (conv.perfConfig) | ||||||
| cop->setAttr("perf_config", conv.perfConfig); | ||||||
| setConvLayoutAttrs(rewriter, cop, effectiveSpatialDim); | ||||||
|
|
||||||
| Value result = cop.getResult(); | ||||||
| if (conv.spatialDim == 1) { | ||||||
| auto shape = cast<RankedTensorType>(result.getType()).getShape(); | ||||||
| rock::BottomUpTMBuilder b(rewriter, {"n", "g", "k", "0", "1"}, shape, loc); | ||||||
| b.passThrough({"no", "go", "ko"}, {0, 1, 2}, {"n", "g", "k"}); | ||||||
| b.merge("0o", 3, {"0", "1"}); | ||||||
| result = rock::TransformOp::create(rewriter, loc, result, b.get()); | ||||||
| } | ||||||
|
|
||||||
| rewriter.replaceOp(op, result); | ||||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| void mlir::rock::populateLinalgToRockConversionPattern( | ||||||
| RewritePatternSet &pattern, MLIRContext *context) { | ||||||
| pattern.add<MatmulConverter<linalg::BatchMatmulOp>, | ||||||
| MatmulConverter<linalg::MatmulOp>>(context); | ||||||
| MatmulConverter<linalg::MatmulOp>, ConvLinalgConverter>(context); | ||||||
| } | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -47,8 +47,15 @@ static void populateLinalgToRockDialectConversion(ConversionTarget &target) { | |||||||||
| if (!linalgOp) { | ||||||||||
| return std::nullopt; | ||||||||||
| } | ||||||||||
| return linalg::isElementwise(linalgOp) || isa<linalg::GenericOp>(op) || | ||||||||||
| isa<linalg::YieldOp>(op); | ||||||||||
|
|
||||||||||
| // Convolution has attributes. | ||||||||||
| linalg::GenericOp castedOp = dyn_cast<linalg::GenericOp>(op); | ||||||||||
| if (castedOp && castedOp->hasAttr("conv_op")) { | ||||||||||
| return false; | ||||||||||
| } | ||||||||||
|
Comment on lines
+52
to
+55
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In theory this could lead to non-conv
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That could work for now. Originally, my intention was to catch all linalg.generic that has a reduction iterator since rocMLIR/mlir/lib/Dialect/Rock/Transforms/Regularize.cpp Lines 138 to 141 in 51df5f4
|
||||||||||
|
|
||||||||||
| return linalg::isElementwise(linalgOp) || isa<linalg::YieldOp>(op) || | ||||||||||
| castedOp; | ||||||||||
| }); | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably good to sanity check/assert that padded only has a single use before it gets deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved it to line 211.