Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,18 @@ def LinalgConv_2D
: I32EnumAttrCase<"Conv2dNgchwGkchw", 1, "conv2d_ngchw_gkchw">;
def LinalgConv_3D
: I32EnumAttrCase<"Conv3dNgchwdGkchwd", 2, "conv3d_ngchwd_gkchwd">;
// Backward-data convolution cases. The string form encodes the expected
// input/filter layouts (e.g. "convbwd1d_ngch_gckh": input ngch, filter gckh),
// mirroring the forward cases above with one extra enum value per spatial
// rank (1D/2D/3D).
def LinalgBwdConv1D
: I32EnumAttrCase<"Conv1dBWDNgchGckh", 3, "convbwd1d_ngch_gckh">;
def LinalgBwdConv2D
: I32EnumAttrCase<"Conv2dBWDNgchwGckhw", 4, "convbwd2d_ngchw_gckhw">;
def LinalgBwdConv3D
: I32EnumAttrCase<"Conv3dBWDNgchwdGckhwd", 5, "convbwd3d_ngchwd_gckhwd">;

def LinalgConvType
: Rock_I32Enum<"LinalgConvType",
"Hints for the linalg.generic convolution ops used by linalg-to-rock lowering",
[LinalgConv_1D, LinalgConv_2D, LinalgConv_3D]>;
[LinalgConv_1D, LinalgConv_2D, LinalgConv_3D,
LinalgBwdConv1D, LinalgBwdConv2D, LinalgBwdConv3D]>;

def LinalgConvTypeAttr : EnumAttr<Rock_Dialect, LinalgConvType, "LinalgConvType">;

Expand Down
252 changes: 243 additions & 9 deletions mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite(
"input shape is non standard or broadcast; cannot convert this shape");
}

//===----------------------------------------------------------------------===//
// Forward and Backward convolution converter
//===----------------------------------------------------------------------===//
namespace {
struct ConvConverter final
: public OpConversionPattern<migraphx::ConvolutionOp> {
Expand All @@ -124,6 +127,21 @@ struct ConvConverter final
migraphx::ConvolutionOp op, Value input,
Value filter) const;
};

/// Lowers migraphx.backwards_data_convolution to a grouped linalg.generic,
/// mirroring the structure of ConvConverter above.
struct BackwardConvConverter final
    : public OpConversionPattern<migraphx::ConvolutionBwdDataOp> {
  using OpConversionPattern<
      migraphx::ConvolutionBwdDataOp>::OpConversionPattern;
  using OpConversionPattern<migraphx::ConvolutionBwdDataOp>::getTypeConverter;
  using OpAdaptor =
      typename OpConversionPattern<migraphx::ConvolutionBwdDataOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  /// Emit the grouped linalg backward convolution for `op` once `input` and
  /// `filter` have been expanded to their grouped (NGC* / GFC*) forms.
  LogicalResult emitBackwardConv(ConversionPatternRewriter &rewriter,
                                 migraphx::ConvolutionBwdDataOp op, Value input,
                                 Value filter) const;
};
} // namespace

// Nice helper function for the linalg.generic op region
Expand All @@ -137,18 +155,19 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) {
}

/// Emit convolution attributes on the newly created operation.
static void emitConvAttributes(migraphx::ConvolutionOp op, Value convOp,
Attribute strides, Attribute dilation,
Attribute pad, Attribute convOpName) {
static void emitConvAttributes(Operation* migraphxOp, Value convOp, Attribute strides,
Attribute dilation, Attribute pad,
Attribute perfConfig, Attribute groupAttr,
Attribute convOpName) {
Operation *newOp = convOp.getDefiningOp();
newOp->setAttr("pad", pad);
newOp->setAttr("group", op.getGroupAttr());
newOp->setAttr("group", migraphxOp->getAttr("group"));
newOp->setAttr("stride", strides);
newOp->setAttr("dilation", dilation);

// Convert optional attributes
if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
newOp->setAttr("perf_config", attr);
if (migraphxOp->hasAttr("perf_config"))
newOp->setAttr("perf_config", migraphxOp->getAttr("perf_config"));
newOp->setAttr("conv_op", convOpName);
}

Expand Down Expand Up @@ -239,6 +258,118 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc,
.getResult(0);
}

/// Emit a grouped backward (transposed) convolution of any spatial rank.
/// Input shape: (batch, group, channel, spatial...),
/// filter shape: (group, filter, channel, kernel_spatial...)
///
/// \p resultType is the expanded NGK* result type, \p zero the splat zero
/// init tensor of that type, and \p strides / \p dilation the per-spatial-dim
/// integer attributes.
///
/// The loop structure mirrors the forward convolution, but with the
/// stride/dilation affine expression on the *output* indexing map:
///
/// clang-format off
/// for n in batch:
///   for g in group:
///     for ih_0 in input_spatial_0:
///       for ih_1 in input_spatial_1:
///         // ...
///         for ih_{dim-1} in input_spatial_{dim-1}:
///           for f in filters:
///             // reduction starts here
///             for c in channels: // reduction
///               for kh_0 in kernel_spatial_0: // reduction
///                 for kh_1 in kernel_spatial_1: // reduction
///                   // ...
///                   result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] +=
///                       input[n,g,c,ih_0,...] * filter[g,c,f,kh_0,...]
/// clang-format on
static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter,
                                     Location loc, RankedTensorType resultType,
                                     Value input, Value filter, Value zero,
                                     ArrayAttr strides, ArrayAttr dilation) {
  MLIRContext *ctx = rewriter.getContext();
  // Exclude batch (N), group (G) and channel (C) from the rank.
  int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() - 3;
  SmallVector<int64_t, 4> strideVals;
  SmallVector<int64_t, 4> dilationVals;
  strideVals.reserve(strides.size());
  dilationVals.reserve(dilation.size());
  llvm::transform(
      strides.getValue(), std::back_inserter(strideVals),
      [](Attribute attr) { return cast<IntegerAttr>(attr).getInt(); });
  llvm::transform(
      dilation.getValue(), std::back_inserter(dilationVals),
      [](Attribute attr) { return cast<IntegerAttr>(attr).getInt(); });

  // Iteration domain layout (mirrors emitGroupedConv):
  //   parallel: batch, group, ih_0 .. ih_{dim-1}, filter
  //   reduction: channel, kh_0 .. kh_{dim-1}
  // See the loop structure above for where these constants come from.
  const int64_t ihStart = 2;
  const int64_t filterIdx = ihStart + spatialDim;
  const int64_t channelIdx = filterIdx + 1;
  const int64_t khStart = channelIdx + 1;
  const int64_t totalDims = khStart + spatialDim;
  const int64_t numParallel = channelIdx;

  SmallVector<AffineExpr> d;
  d.reserve(totalDims);
  for (int64_t i = 0; i < totalDims; ++i)
    d.push_back(getAffineDimExpr(i, ctx));

  AffineExpr batch = d[0], group = d[1];
  AffineExpr outChannel = d[filterIdx];
  AffineExpr inChannel = d[channelIdx];

  // input[n, g, c, ih_0, ...]
  SmallVector<AffineExpr> inputExprs = {batch, group, inChannel};
  for (int64_t i = 0; i < spatialDim; ++i)
    inputExprs.push_back(d[ihStart + i]);

  // filter[g, c, f, kh_0, ...]
  SmallVector<AffineExpr> filterExprs = {group, inChannel, outChannel};
  for (int64_t i = 0; i < spatialDim; ++i)
    filterExprs.push_back(d[khStart + i]);

  // output[n, g, f, ih_i * stride_i + kh_i * dilation_i, ...]
  SmallVector<AffineExpr> outputExprs = {batch, group, outChannel};
  for (int64_t i = 0; i < spatialDim; ++i) {
    AffineExpr ih_i = d[ihStart + i];
    AffineExpr kh_i = d[khStart + i];
    outputExprs.push_back(ih_i * strideVals[i] + kh_i * dilationVals[i]);
  }

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(totalDims, /*symbolCount=*/0, inputExprs, ctx),
      AffineMap::get(totalDims, /*symbolCount=*/0, filterExprs, ctx),
      AffineMap::get(totalDims, /*symbolCount=*/0, outputExprs, ctx)};

  SmallVector<utils::IteratorType> iteratorTypes(numParallel,
                                                 utils::IteratorType::parallel);
  iteratorTypes.append(totalDims - numParallel,
                       utils::IteratorType::reduction);

  return linalg::GenericOp::create(rewriter, loc, resultType,
                                   ValueRange{input, filter}, zero,
                                   indexingMaps, iteratorTypes,
                                   convBodyBuilder)
      .getResult(0);
}

/// Given the collapsed NF* result type and the group count, return the
/// expanded NGK* result type for the grouped linalg convolution.
static RankedTensorType
expandResultForGroupedConv(RankedTensorType resultType, int64_t group) {
  ArrayRef<int64_t> resultShape = resultType.getShape();
  // Check the precondition before performing the split of F into (G, F/G).
  assert(resultType.getDimSize(1) % group == 0 &&
         "output channel must be divisible by group");
  int64_t n = resultType.getDimSize(0);
  int64_t newF = resultType.getDimSize(1) / group;

  SmallVector<int64_t, 4> newShape;
  newShape.reserve(resultShape.size() + 1);
  newShape.push_back(n);
  newShape.push_back(group);
  newShape.push_back(newF);
  // Spatial dims are carried over unchanged.
  newShape.append(std::next(resultShape.begin(), 2), resultShape.end());
  return RankedTensorType::get(newShape, resultType.getElementType());
}

LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter,
migraphx::ConvolutionOp op, Value input,
Value filter) const {
Expand Down Expand Up @@ -280,8 +411,8 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter,
Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter,
zero, strides, dilation);

emitConvAttributes(op, result, strides, dilation,
op.getPaddingAttr(),
emitConvAttributes(op,result, strides, dilation, op.getPaddingAttr(),
op->getAttr("perf_config"), op.getGroupAttr(),
resultConvOpName);

// we must reshape the operand to what the type converter expects
Expand Down Expand Up @@ -444,6 +575,108 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor,
return emitConv(rewriter, op, input, filter);
}

LogicalResult
BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter,
migraphx::ConvolutionBwdDataOp op,
Value input, Value filter) const {
Location loc = op.getLoc();
int64_t group = op.getGroupAttr().getInt();
int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() -
3; // exclude batch (N), group (G), channel (C)
assert(spatialDim >= 1 && spatialDim <= 3 &&
"this should be checked at matchAndRewrite");
// To get the result shape, we must first add the padding
ArrayRef<Attribute> padding = op.getPaddingAttr().getValue();
RankedTensorType originalResult = cast<RankedTensorType>(getTypeConverter()->convertType(op.getResult()));
SmallVector<int64_t, 4> resultShape (originalResult.getShape());
SmallVector<int64_t, 4> lowPads;
SmallVector<int64_t, 4> highPads;
for(int64_t i = 0; i<spatialDim; ++i){
int64_t lowPad = cast<IntegerAttr>(padding[i]).getInt();
int64_t highPad = cast<IntegerAttr>(padding[i+spatialDim]).getInt();
resultShape[2+i] += lowPad + highPad;
lowPads.push_back(lowPad);
highPads.push_back(highPad);
}
RankedTensorType resultType = RankedTensorType::get(resultShape,originalResult.getElementType());
auto newResultType = expandResultForGroupedConv(resultType, group);
Value zero = arith::ConstantOp::create(rewriter, loc, newResultType,
rewriter.getZeroAttr(newResultType));

ArrayAttr strides = op.getStride();
ArrayAttr dilation = op.getDilation();

Value result = emitGroupedBackwardConv(rewriter, loc, newResultType, input,
filter, zero, strides, dilation);
rock::LinalgConvType convType =
(spatialDim == 3) ? rock::LinalgConvType::Conv3dBWDNgchwdGckhwd
: (spatialDim == 2) ? rock::LinalgConvType::Conv2dBWDNgchwGckhw
: rock::LinalgConvType::Conv1dBWDNgchGckh;
emitConvAttributes(
op,result, strides, dilation, op.getPaddingAttr(),
op->getAttr("perf_config"), op.getGroupAttr(),
rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType));

// Collapse result from NGK* back to NK*
SmallVector<ReassociationIndices, 4> reassociation{{0}, {1, 2}};
llvm::for_each(llvm::seq<int64_t>(3, spatialDim + 3),
[&](int64_t index) { reassociation.push_back({index}); });
auto finalResult =
tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation).getResult();

bool hasPadding = llvm::any_of(lowPads, [](int64_t p) { return p != 0; }) ||
llvm::any_of(highPads, [](int64_t p) { return p != 0; });
if (hasPadding) {
int64_t rank = originalResult.getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
for (int64_t i = 0; i < rank; ++i)
sizes.push_back(rewriter.getIndexAttr(originalResult.getDimSize(i)));
for (int64_t i = 0; i < spatialDim; ++i)
offsets[2 + i] = rewriter.getIndexAttr(lowPads[i]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember that conv_bwd can have negative padding values. Can you check whether this would work for that case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this would work if we are having a negative padding. This is because tensor.extract_slice doesn't expect negative values.

What are the semantics of negative padding? Is it the same as applying an output padding whose magnitude equals those negative values?

finalResult = tensor::ExtractSliceOp::create(
rewriter, loc, originalResult, finalResult, offsets, sizes,
strides)
.getResult();
}

rewriter.replaceOp(op, finalResult);
return success();
}

LogicalResult BackwardConvConverter::matchAndRewrite(
    migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Backward convolution lowering is similar to forward convolution and is
  // done in two steps:
  // 1. Expand the channel dimension into (group, channel_per_group),
  //    introducing a group dimension G. Input becomes NGC* (e.g. NGCL,
  //    NGCHW, NGCDHW) and filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW),
  //    matching the group attr.
  // 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the
  //    result back to the original NFHW/NFDHW shape for the type converter.
  Location loc = op.getLoc();
  Value input = adaptor.getInput();
  Value filter = adaptor.getFilter();
  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
  // Spatial rank of the not-yet-expanded input: exclude batch (N) and
  // channel (C).
  int64_t dim = inputType.getRank() - 2;
  int64_t group = op.getGroupAttr().getInt();

  // Only 1D/2D/3D spatial convolutions are supported.
  if (dim > 3 || dim < 1) {
    return op.emitError(Twine(dim) + "D conv is not supported for now");
  }

  if (inputType.getElementType() != op.getFilter().getType().getElementType() ||
      inputType.getElementType() != op.getResult().getType().getElementType()) {
    return op.emitError(
        "type casting between operands and result is unsupported for now");
  }

  input = expandGroupDim(rewriter, loc, input, /*isFilter=*/false, group, dim);
  filter = expandGroupDim(rewriter, loc, filter, /*isFilter=*/true, group, dim);

  return emitBackwardConv(rewriter, op, input, filter);
}
// TODO: add support for scaled gemms, and migraphx::DeQuantizeLinearConverter
//===----------------------------------------------------------------------===//
// Base kernels (gemm)
Expand Down Expand Up @@ -1134,7 +1367,8 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
MultiBroadcastConverter, LiteralConverter, ReshapeConverter,
BooleanElementwiseConverter<migraphx::Greater>,
BooleanElementwiseConverter<migraphx::Equal>, ClipConverter,
TransposeConverter, ConvConverter>(converter, patterns.getContext());
TransposeConverter, ConvConverter,
BackwardConvConverter>(converter, patterns.getContext());
}

void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ func.func @func_quant_convolution(%arg0: !migraphx.shaped<1x1xi8, 1x1>, %arg1: !
func.return
}

func.func @func_backwards_data_convolution(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) {
// expected-error @+1{{failed to legalize operation 'migraphx.backwards_data_convolution'}}
migraphx.backwards_data_convolution %arg0, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0], stride = [1, 1]}: <1x1xf32, 1x1>, <1x1xf32, 1x1> -> <1x1xf32, 1x1>
func.return
}

func.func @func_batch_norm_inference(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) {
// expected-error @+1{{failed to legalize operation 'migraphx.batch_norm_inference'}}
migraphx.batch_norm_inference %arg0, %arg1, %arg1, %arg1, %arg1 {bn_mode = 0 : i64, epsilon = 1.0e-5 : f32, momentum = 0.9 : f32}:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s

module {
func.func @mlir_bwd_data_conv(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s

module {
func.func @mlir_bwd_data_conv(%arg0: !migraphx.shaped<1x512x32x32xf32, 524288x1024x32x1>,
Expand Down
Loading