Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 143 additions & 8 deletions mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
#include "mlir/Dialect/Rock/Tuning/ConvContext.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -155,10 +156,13 @@ struct ConvFields {

static int64_t getSpatialDim(rock::LinalgConvType type) {
switch (type) {
case rock::LinalgConvType::Conv1dBWDNgchGckh:
case rock::LinalgConvType::Conv1dNgchGkch:
return 1;
case rock::LinalgConvType::Conv2dBWDNgchwGckhw:
case rock::LinalgConvType::Conv2dNgchwGkchw:
return 2;
case rock::LinalgConvType::Conv3dBWDNgchwdGckhwd:
case rock::LinalgConvType::Conv3dNgchwdGkchwd:
return 3;
}
Expand All @@ -167,7 +171,7 @@ static int64_t getSpatialDim(rock::LinalgConvType type) {

/// Set filter_layout, input_layout, and output_layout on a rock.conv op.
/// Layouts match the linalg convention: GKC*, NGC*, NGK*.
static void setConvLayoutAttrs(OpBuilder &builder, rock::ConvOp cop,
static void setConvLayoutAttrs(OpBuilder &builder, Operation *cop,
int64_t spatialDim) {
auto *ctx = builder.getContext();
auto setLayout = [&](StringRef attrName, ArrayRef<StringRef> prefix,
Expand Down Expand Up @@ -247,16 +251,22 @@ struct ConvLinalgConverter final
LogicalResult
matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Conversion pattern that lowers linalg.generic ops tagged as backward-data
/// convolutions (conv_op = Conv{1,2,3}dBWD*) into rock.conv_bwd_data.
/// Forward convolutions are handled by ConvLinalgConverter.
struct BwdConvLinalgConverter final
    : public OpConversionPattern<linalg::GenericOp> {
  using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
  using OpConversionPattern<linalg::GenericOp>::getTypeConverter;
  using OpAdaptor = typename OpConversionPattern<linalg::GenericOp>::OpAdaptor;

  // Note: conv matching is done by the file-scope static isConv() helper;
  // this struct intentionally declares no member isConv.
  LogicalResult
  matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

FailureOr<ConvFields>
ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter,
linalg::GenericOp op) const {
static FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter,
linalg::GenericOp op) {
auto name = op->getAttrOfType<rock::LinalgConvTypeAttr>("conv_op");
if (!name)
return failure();
Expand Down Expand Up @@ -323,6 +333,120 @@ ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter,
stride, dilation, perfConfig};
}

/// Rewrites a linalg.generic marked as a backward-data convolution into a
/// rock.conv_bwd_data op, reconstructing the padded output shape when the
/// surrounding IR carries the collapse_shape/extract_slice padding pattern,
/// and requesting a zero-prefill of the result when not every output element
/// is written by the backward-data kernel.
LogicalResult BwdConvLinalgConverter::matchAndRewrite(
    linalg::GenericOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  FailureOr<ConvFields> maybeConv = isConv(rewriter, op);
  if (failed(maybeConv))
    return failure();

  ConvFields conv = *maybeConv;
  Location loc = op.getLoc();

  // Making sure this is a backwards conv only; forward convolutions are
  // handled by ConvLinalgConverter.
  switch (conv.type) {
  case rock::LinalgConvType::Conv1dBWDNgchGckh:
  case rock::LinalgConvType::Conv2dBWDNgchwGckhw:
  case rock::LinalgConvType::Conv3dBWDNgchwdGckhwd:
    break;
  default:
    return failure();
  }
  bool hasPadding = llvm::any_of(conv.padding, [](Attribute attr) {
    return cast<IntegerAttr>(attr).getInt() != 0;
  });

  RankedTensorType resultShape =
      cast<RankedTensorType>(adaptor.getOutputs()[0].getType());
  tensor::ExtractSliceOp extractSlicePadding = nullptr;
  tensor::CollapseShapeOp collapseGroupPadding = nullptr;
  if (hasPadding) {
    // To handle padding, the migraphx-to-linalg pipeline emits a fixed
    // pattern of ops, and it should look something like the following:
    //   linalg.generic ins(...) outs(%output)
    //   %collapse_group = tensor.collapse_shape %output ....
    //   %output = tensor.extract_slice %collapse_shape ...
    if (!op->hasOneUse())
      return op.emitError("invalid padding code structure");
    collapseGroupPadding = dyn_cast<tensor::CollapseShapeOp>(*op->user_begin());
    if (!collapseGroupPadding || !collapseGroupPadding->hasOneUse())
      return op.emitError("invalid padding code structure");

    extractSlicePadding =
        dyn_cast<tensor::ExtractSliceOp>(*collapseGroupPadding->user_begin());
    if (!extractSlicePadding)
      return op.emitError("invalid padding code structure");

    // Take the padded output spatial dims (H/W/D) from the slice result...
    auto lastFewShape = cast<RankedTensorType>(extractSlicePadding.getType())
                            .getShape()
                            .drop_front(2);
    // ...and keep the leading N, G, K dims from the unpadded result type.
    SmallVector<int64_t, 4> newShape(resultShape.getShape().take_front(3));
    newShape.insert(newShape.end(), lastFewShape.begin(), lastFewShape.end());
    resultShape = RankedTensorType::get(newShape, resultShape.getElementType());
  }

  Value filter = adaptor.getOperands()[1];
  Value input = adaptor.getOperands()[0];
  auto output =
      bufferization::AllocTensorOp::create(rewriter, loc, resultShape, {});
  auto cop = rock::ConvBwdDataOp::create(
      rewriter, loc, output.getType(), filter, output, input,
      /*features=*/nullptr,
      /*blockSize=*/nullptr,
      /*gridSize=*/nullptr, conv.padding, conv.stride, conv.dilation,
      /*params=*/nullptr, rewriter.getIndexAttr(0),
      /*usesV4R1=*/rewriter.getBoolAttr(false));
  if (conv.perfConfig)
    cop->setAttr("perf_config", conv.perfConfig);
  setConvLayoutAttrs(rewriter, cop, getSpatialDim(conv.type));

  rock::ConvolutionContext ctx = rock::populateConvContext(cop);
  auto strideDims = ctx.getStrideVal();
  auto dilationDims = ctx.getDilationVal();
  auto filterDims = ctx.getConvDims().fil;
  // If there is no zeroinit kernel needed, then there is nothing more we need
  // to do here.
  if (!rock::isEveryElementWrittenBwdData(strideDims, dilationDims,
                                          filterDims)) {
    // FIXME: don't hard code this to function result 0 - see PR#1687.
    // NOTE(review): with multiple function results we would need to trace
    // which result corresponds to this conv's output before setting the
    // rock.prefill attribute — confirm single-result assumption holds here.
    func::FuncOp func = op->getParentOfType<func::FuncOp>();
    Attribute outputInitVal;
    Type funcResType = func.getFunctionType().getResult(0);
    auto shapedResType = cast<ShapedType>(funcResType);
    Type elementType = shapedResType.getElementType();
    if (isa<FloatType>(elementType)) {
      outputInitVal = rewriter.getFloatAttr(elementType, 0.0);
    } else if (isa<IntegerType>(elementType)) {
      outputInitVal = rewriter.getIntegerAttr(elementType, 0);
    } else {
      // We only expect integer and float types for now. Fail the rewrite
      // instead of asserting: assert(false) compiles out in release builds
      // and would leave outputInitVal unset.
      return op.emitError("unsupported element type for prefill attribute");
    }

    func.setResultAttr(0, rock::PrefillAttr::getMnemonic(), outputInitVal);
  }

  if (hasPadding) {
    assert(extractSlicePadding && collapseGroupPadding &&
           "these ops should have been found above");
    // Collapse N and (G, K) into the first two dims; spatial dims map 1:1.
    SmallVector<ReassociationIndices, 4> reassociations{{0}, {1, 2}};
    llvm::transform(llvm::seq<int64_t>(3, 3 + conv.spatialDim),
                    std::back_inserter(reassociations),
                    [](int64_t index) { return ReassociationIndices{index}; });
    tensor::CollapseShapeOp collapseGroupDim =
        tensor::CollapseShapeOp::create(rewriter, loc, output, reassociations);
    rewriter.eraseOp(op);
    rewriter.eraseOp(collapseGroupPadding);
    rewriter.replaceOp(extractSlicePadding, collapseGroupDim);
    return success();
  }

  rewriter.replaceOp(op, output);
  return success();
}

LogicalResult ConvLinalgConverter::matchAndRewrite(
linalg::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -333,6 +457,16 @@ LogicalResult ConvLinalgConverter::matchAndRewrite(
ConvFields conv = *maybeConv;
Location loc = op.getLoc();

// Making sure this is a forward conv only
switch (conv.type) {
case rock::LinalgConvType::Conv1dNgchGkch:
case rock::LinalgConvType::Conv2dNgchwGkchw:
case rock::LinalgConvType::Conv3dNgchwdGkchwd:
break;
default:
return failure();
}

auto maybeInput =
removePaddingFromInput(rewriter, op, op.getOperand(0), conv.padding);
if (failed(maybeInput))
Expand Down Expand Up @@ -395,5 +529,6 @@ LogicalResult ConvLinalgConverter::matchAndRewrite(
/// Registers all linalg-to-rock conversion patterns, including the
/// forward (ConvLinalgConverter) and backward-data (BwdConvLinalgConverter)
/// convolution lowerings.
void mlir::rock::populateLinalgToRockConversionPattern(
    RewritePatternSet &pattern, MLIRContext *context) {
  pattern.add<MatmulConverter<linalg::BatchMatmulOp>,
              MatmulConverter<linalg::MatmulOp>, ConvLinalgConverter,
              BwdConvLinalgConverter>(context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-opt --linalg-to-rock -verify-diagnostics --split-input-file | FileCheck %s

// Lit test: lowering of a padded 2-D backward-data convolution expressed as a
// linalg.generic (conv_op = convbwd2d_ngchw_gckhw) to rock.conv_bwd_data.
// Output: NGCHW = 1x1x1x3x3, Filter: GCKHW = 1x1x1x3x3
// stride=1, dilation=1, padding=[1,1,1,1], group=1

// Indexing maps of the generic: d0..d4 index the padded NGCHW output,
// d5..d7 are the reduction dims (channel and the two filter spatial dims).
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d5, d4, d6, d7)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d4, d2 + d6, d3 + d7)>
// CHECK-LABEL: func.func @mlir_bwd_data_conv(
// CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor
// CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg1]]
// CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg0]]
// CHECK-DAG: %[[alloc:.*]] = bufferization.alloc_tensor
// CHECK-DAG: %[[conv:.*]] = rock.conv_bwd_data(%[[expanded_0]], %[[alloc]], %[[expanded]])
// CHECK-SAME: dilations = [1 : index, 1 : index]
// CHECK-SAME: filter_layout = ["g", "k", "c", "0", "1"]
// CHECK-SAME: input_layout = ["ni", "gi", "ci", "0i", "1i"]
// CHECK-SAME: output_layout = ["no", "go", "ko", "0o", "1o"]
// CHECK-SAME: padding = [1 : index, 1 : index, 1 : index, 1 : index]
// CHECK-SAME: strides = [1 : index, 1 : index]
// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[alloc]]
// CHECK-DAG: %[[collapsed_1:.*]] = tensor.collapse_shape %[[collapsed]]
// CHECK-DAG: return %[[collapsed_1]]
// The collapse_shape + extract_slice after the generic form the padding
// pattern that the BwdConvLinalgConverter pattern expects to rewrite.
func.func @mlir_bwd_data_conv(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>) -> tensor<9xf32> attributes {arch = "##TOKEN_ARCH##", kernel} {
%cst = arith.constant dense<0.000000e+00> : tensor<1x1x1x5x5xf32>
%expanded = tensor.expand_shape %arg1 [[0, 1, 2, 3, 4]] output_shape [1, 1, 1, 3, 3] : tensor<9xf32> into tensor<1x1x1x3x3xf32>
%expanded_0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [1, 1, 1, 3, 3] : tensor<9xf32> into tensor<1x1x1x3x3xf32>
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%expanded, %expanded_0 : tensor<1x1x1x3x3xf32>, tensor<1x1x1x3x3xf32>) outs(%cst : tensor<1x1x1x5x5xf32>) attrs = {conv_op = #rock<LinalgConvType convbwd2d_ngchw_gckhw>, dilation = [1, 1], group = 1 : i64, pad = [1, 1, 1, 1], stride = [1, 1]} {
^bb0(%in: f32, %in_2: f32, %out: f32):
%1 = arith.mulf %in, %in_2 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
} -> tensor<1x1x1x5x5xf32>
%collapsed = tensor.collapse_shape %0 [[0], [1, 2], [3], [4]] : tensor<1x1x1x5x5xf32> into tensor<1x1x5x5xf32>
%extracted_slice = tensor.extract_slice %collapsed[0, 0, 1, 1] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x1x5x5xf32> to tensor<1x1x3x3xf32>
%collapsed_1 = tensor.collapse_shape %extracted_slice [[0, 1, 2, 3]] : tensor<1x1x3x3xf32> into tensor<9xf32>
return %collapsed_1 : tensor<9xf32>
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx-linalg,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s

module {
func.func @mlir_bwd_data_conv(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx-linalg,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s

module {
func.func @mlir_bwd_data_conv(%arg0: !migraphx.shaped<1x512x32x32xf32, 524288x1024x32x1>,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx-linalg,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s

module {
// CHECK: [1 1 1]
Expand Down
Loading