-
Notifications
You must be signed in to change notification settings - Fork 54
[AIROCMLIR-445] Lower linalg.generic Backwards Conv into rock.bwd_data_conv
#2274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pr-template-migraphx-to-linalg-conv-backward
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
| #include "mlir/Dialect/Rock/IR/Rock.h" | ||
| #include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" | ||
| #include "mlir/Dialect/Rock/Tuning/ConvContext.h" | ||
| #include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
| #include "mlir/IR/AffineExpr.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
|
|
@@ -155,10 +156,13 @@ struct ConvFields { | |
|
|
||
| static int64_t getSpatialDim(rock::LinalgConvType type) { | ||
| switch (type) { | ||
| case rock::LinalgConvType::Conv1dBWDNgchGckh: | ||
| case rock::LinalgConvType::Conv1dNgchGkch: | ||
| return 1; | ||
| case rock::LinalgConvType::Conv2dBWDNgchwGckhw: | ||
| case rock::LinalgConvType::Conv2dNgchwGkchw: | ||
| return 2; | ||
| case rock::LinalgConvType::Conv3dBWDNgchwdGckhwd: | ||
| case rock::LinalgConvType::Conv3dNgchwdGkchwd: | ||
| return 3; | ||
| } | ||
|
|
@@ -167,7 +171,7 @@ static int64_t getSpatialDim(rock::LinalgConvType type) { | |
|
|
||
| /// Set filter_layout, input_layout, and output_layout on a rock.conv op. | ||
| /// Layouts match the linalg convention: GKC*, NGC*, NGK*. | ||
| static void setConvLayoutAttrs(OpBuilder &builder, rock::ConvOp cop, | ||
| static void setConvLayoutAttrs(OpBuilder &builder, Operation *cop, | ||
| int64_t spatialDim) { | ||
| auto *ctx = builder.getContext(); | ||
| auto setLayout = [&](StringRef attrName, ArrayRef<StringRef> prefix, | ||
|
|
@@ -247,16 +251,22 @@ struct ConvLinalgConverter final | |
| LogicalResult | ||
| matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override; | ||
| }; | ||
|
|
||
| struct BwdConvLinalgConverter final | ||
| : public OpConversionPattern<linalg::GenericOp> { | ||
| using OpConversionPattern<linalg::GenericOp>::OpConversionPattern; | ||
| using OpConversionPattern<linalg::GenericOp>::getTypeConverter; | ||
| using OpAdaptor = typename OpConversionPattern<linalg::GenericOp>::OpAdaptor; | ||
|
|
||
| private: | ||
| FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter, | ||
| linalg::GenericOp op) const; | ||
| LogicalResult | ||
| matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override; | ||
| }; | ||
| } // namespace | ||
|
|
||
| FailureOr<ConvFields> | ||
| ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter, | ||
| linalg::GenericOp op) const { | ||
| static FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter, | ||
| linalg::GenericOp op) { | ||
| auto name = op->getAttrOfType<rock::LinalgConvTypeAttr>("conv_op"); | ||
| if (!name) | ||
| return failure(); | ||
|
|
@@ -323,6 +333,120 @@ ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter, | |
| stride, dilation, perfConfig}; | ||
| } | ||
|
|
||
| LogicalResult BwdConvLinalgConverter::matchAndRewrite( | ||
| linalg::GenericOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const { | ||
| FailureOr<ConvFields> maybeConv = isConv(rewriter, op); | ||
| if (failed(maybeConv)) | ||
| return failure(); | ||
|
|
||
| ConvFields conv = *maybeConv; | ||
| Location loc = op.getLoc(); | ||
|
|
||
| // Making sure this is a backwards conv only | ||
| switch (conv.type) { | ||
| case rock::LinalgConvType::Conv1dBWDNgchGckh: | ||
| case rock::LinalgConvType::Conv2dBWDNgchwGckhw: | ||
| case rock::LinalgConvType::Conv3dBWDNgchwdGckhwd: | ||
| break; | ||
| default: | ||
| return failure(); | ||
| } | ||
| bool hasPadding = llvm::any_of(conv.padding, [](Attribute attr) { | ||
| return cast<IntegerAttr>(attr).getInt() != 0; | ||
| }); | ||
|
|
||
| RankedTensorType resultShape = | ||
| cast<RankedTensorType>(adaptor.getOutputs()[0].getType()); | ||
| tensor::ExtractSliceOp extractSlicePadding = nullptr; | ||
| tensor::CollapseShapeOp collapseGroupPadding = nullptr; | ||
| if (hasPadding) { | ||
| // To handle padding, the migraphx to linalg pipeline | ||
| // and it should look something like the following: | ||
| // linalg.generic ins(...) outs(%output) | ||
| // %collapse_group = tensor.collapse_shape %output .... | ||
| // %output = tensor.extract_slice %collapse_shape ... | ||
| if (!op->hasOneUse()) | ||
| return op.emitError("invalid padding code structure"); | ||
| collapseGroupPadding = dyn_cast<tensor::CollapseShapeOp>(*op->user_begin()); | ||
| if (!collapseGroupPadding || !collapseGroupPadding->hasOneUse()) | ||
| return op.emitError("invalid padding code structure"); | ||
|
|
||
| extractSlicePadding = | ||
| dyn_cast<tensor::ExtractSliceOp>(*collapseGroupPadding->user_begin()); | ||
| if (!extractSlicePadding) | ||
| return op.emitError("invalid padding code structure"); | ||
|
|
||
| // Take the padded output shape - HWD | ||
| auto lastFewShape = cast<RankedTensorType>(extractSlicePadding.getType()) | ||
| .getShape() | ||
| .drop_front(2); | ||
| // Take the first NGK | ||
| SmallVector<int64_t, 4> newShape(resultShape.getShape().take_front(3)); | ||
| newShape.insert(newShape.end(), lastFewShape.begin(), lastFewShape.end()); | ||
| resultShape = RankedTensorType::get(newShape, resultShape.getElementType()); | ||
| } | ||
|
|
||
| Value filter = adaptor.getOperands()[1]; | ||
| Value input = adaptor.getOperands()[0]; | ||
| auto output = | ||
| bufferization::AllocTensorOp::create(rewriter, loc, resultShape, {}); | ||
| auto cop = rock::ConvBwdDataOp::create( | ||
| rewriter, loc, output.getType(), filter, output, input, | ||
| /*features=*/nullptr, | ||
| /*blockSize=*/nullptr, | ||
| /*gridSize=*/nullptr, conv.padding, conv.stride, conv.dilation, | ||
| /*params=*/nullptr, rewriter.getIndexAttr(0), | ||
| /*usesV4R1=*/rewriter.getBoolAttr(false)); | ||
Mr-Anyone marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if(conv.perfConfig) | ||
| cop->setAttr("perf_config", conv.perfConfig); | ||
| setConvLayoutAttrs(rewriter, cop, getSpatialDim(conv.type)); | ||
|
|
||
|
|
||
| rock::ConvolutionContext ctx = rock::populateConvContext(cop); | ||
| auto strideDims = ctx.getStrideVal(); | ||
| auto dilationDims = ctx.getDilationVal(); | ||
| auto filterDims = ctx.getConvDims().fil; | ||
| // If there is no zeroinit kernel needed, then there is nothing more we need | ||
| // to do here. | ||
| if (!rock::isEveryElementWrittenBwdData(strideDims, dilationDims, | ||
| filterDims)) { | ||
| // FIXME: don't hard code this - see PR#1687 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this FIXME relevant?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Currently, I have hard-coded this to the first output. My understanding of MLIR is that there can be multiple outputs; in those cases, would we have to climb the use tree to figure out which of the outputs needs the rock.prefill attribute set? |
||
| func::FuncOp func = op->getParentOfType<func::FuncOp>(); | ||
| Attribute outputInitVal; | ||
| Type funcResType = func.getFunctionType().getResult(0); | ||
| auto shapedResType = cast<ShapedType>(funcResType); | ||
| Type elementType = shapedResType.getElementType(); | ||
| if (isa<FloatType>(elementType)) { | ||
| outputInitVal = rewriter.getFloatAttr(elementType, 0.0); | ||
| } else if (isa<IntegerType>(elementType)) { | ||
| outputInitVal = rewriter.getIntegerAttr(elementType, 0); | ||
| } else { | ||
| // We only expect integer and float types for now | ||
| assert(false && "Unsupported element type for prefill attribute"); | ||
Mr-Anyone marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| func.setResultAttr(0, rock::PrefillAttr::getMnemonic(), outputInitVal); | ||
| } | ||
Mr-Anyone marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if (hasPadding) { | ||
| assert(extractSlicePadding && collapseGroupPadding && | ||
| "these op should have be found from before"); | ||
| SmallVector<ReassociationIndices, 4> reassocations{{0}, {1, 2}}; | ||
| llvm::transform(llvm::seq<int64_t>(3, 3 + conv.spatialDim), | ||
| std::back_inserter(reassocations), | ||
| [](int64_t index) { return ReassociationIndices{index}; }); | ||
| tensor::CollapseShapeOp collapseGroupDim = | ||
| tensor::CollapseShapeOp::create(rewriter, loc, output, reassocations); | ||
| rewriter.eraseOp(op); | ||
| rewriter.eraseOp(collapseGroupPadding); | ||
| rewriter.replaceOp(extractSlicePadding, collapseGroupDim); | ||
| return success(); | ||
| } | ||
|
|
||
| rewriter.replaceOp(op, output); | ||
| return success(); | ||
|
Comment on lines
+392
to
+447
|
||
| } | ||
|
|
||
| LogicalResult ConvLinalgConverter::matchAndRewrite( | ||
| linalg::GenericOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const { | ||
|
|
@@ -333,6 +457,16 @@ LogicalResult ConvLinalgConverter::matchAndRewrite( | |
| ConvFields conv = *maybeConv; | ||
| Location loc = op.getLoc(); | ||
|
|
||
| // Making sure this is a forward conv only | ||
| switch (conv.type) { | ||
| case rock::LinalgConvType::Conv1dNgchGkch: | ||
| case rock::LinalgConvType::Conv2dNgchwGkchw: | ||
| case rock::LinalgConvType::Conv3dNgchwdGkchwd: | ||
| break; | ||
| default: | ||
| return failure(); | ||
| } | ||
|
|
||
| auto maybeInput = | ||
| removePaddingFromInput(rewriter, op, op.getOperand(0), conv.padding); | ||
| if (failed(maybeInput)) | ||
|
|
@@ -395,5 +529,6 @@ LogicalResult ConvLinalgConverter::matchAndRewrite( | |
| void mlir::rock::populateLinalgToRockConversionPattern( | ||
| RewritePatternSet &pattern, MLIRContext *context) { | ||
| pattern.add<MatmulConverter<linalg::BatchMatmulOp>, | ||
| MatmulConverter<linalg::MatmulOp>, ConvLinalgConverter>(context); | ||
| MatmulConverter<linalg::MatmulOp>, ConvLinalgConverter, | ||
| BwdConvLinalgConverter>(context); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| // RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-opt --linalg-to-rock -verify-diagnostics --split-input-file | FileCheck %s | ||
|
|
||
| // Output: NGCHW = 1x1x1x3x3, Filter: GCKHW = 1x1x1x3x3 | ||
| // stride=1, dilation=1, padding=[1,1,1,1], group=1 | ||
|
|
||
| #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d2, d3)> | ||
| #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d5, d4, d6, d7)> | ||
| #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d4, d2 + d6, d3 + d7)> | ||
| // CHECK-LABEL: func.func @mlir_bwd_data_conv( | ||
| // CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor | ||
| // CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg1]] | ||
| // CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg0]] | ||
| // CHECK-DAG: %[[alloc:.*]] = bufferization.alloc_tensor | ||
| // CHECK-DAG: %[[conv:.*]] = rock.conv_bwd_data(%[[expanded_0]], %[[alloc]], %[[expanded]]) | ||
| // CHECK-SAME: dilations = [1 : index, 1 : index] | ||
| // CHECK-SAME: filter_layout = ["g", "k", "c", "0", "1"] | ||
| // CHECK-SAME: input_layout = ["ni", "gi", "ci", "0i", "1i"] | ||
| // CHECK-SAME: output_layout = ["no", "go", "ko", "0o", "1o"] | ||
| // CHECK-SAME: padding = [1 : index, 1 : index, 1 : index, 1 : index] | ||
| // CHECK-SAME: strides = [1 : index, 1 : index] | ||
| // CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[alloc]] | ||
| // CHECK-DAG: %[[collapsed_1:.*]] = tensor.collapse_shape %[[collapsed]] | ||
| // CHECK-DAG: return %[[collapsed_1]] | ||
| func.func @mlir_bwd_data_conv(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>) -> tensor<9xf32> attributes {arch = "##TOKEN_ARCH##", kernel} { | ||
| %cst = arith.constant dense<0.000000e+00> : tensor<1x1x1x5x5xf32> | ||
| %expanded = tensor.expand_shape %arg1 [[0, 1, 2, 3, 4]] output_shape [1, 1, 1, 3, 3] : tensor<9xf32> into tensor<1x1x1x3x3xf32> | ||
| %expanded_0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [1, 1, 1, 3, 3] : tensor<9xf32> into tensor<1x1x1x3x3xf32> | ||
| %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%expanded, %expanded_0 : tensor<1x1x1x3x3xf32>, tensor<1x1x1x3x3xf32>) outs(%cst : tensor<1x1x1x5x5xf32>) attrs = {conv_op = #rock<LinalgConvType convbwd2d_ngchw_gckhw>, dilation = [1, 1], group = 1 : i64, pad = [1, 1, 1, 1], stride = [1, 1]} { | ||
| ^bb0(%in: f32, %in_2: f32, %out: f32): | ||
| %1 = arith.mulf %in, %in_2 : f32 | ||
| %2 = arith.addf %out, %1 : f32 | ||
| linalg.yield %2 : f32 | ||
| } -> tensor<1x1x1x5x5xf32> | ||
| %collapsed = tensor.collapse_shape %0 [[0], [1, 2], [3], [4]] : tensor<1x1x1x5x5xf32> into tensor<1x1x5x5xf32> | ||
| %extracted_slice = tensor.extract_slice %collapsed[0, 0, 1, 1] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x1x5x5xf32> to tensor<1x1x3x3xf32> | ||
| %collapsed_1 = tensor.collapse_shape %extracted_slice [[0, 1, 2, 3]] : tensor<1x1x3x3xf32> into tensor<9xf32> | ||
| return %collapsed_1 : tensor<9xf32> | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Can you add lit tests?