From 4939adf1fcf5e9c28e22a624ddca14bd366b324a Mon Sep 17 00:00:00 2001 From: Vincent Date: Wed, 25 Feb 2026 21:39:12 +0000 Subject: [PATCH 1/2] [AIROCMLIR-445] Lower 'migraphx.backwards_data_convolution` --- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 9 +- .../MIGraphXToLinalg/MIGraphXToLinalg.cpp | 241 +++++++++++++++++- .../migraphx-to-linalg-not-implemented.mlir | 6 - .../mixr-bwd-data-conv-asymmetric-stride.mlir | 1 + .../mixr-bwd-data-conv-dilation1-stride2.mlir | 1 + .../mixr-bwd-data-conv-dilation2-stride1.mlir | 1 + .../pr-e2e/mixr-bwd-data-conv-padding1.mlir | 1 + .../mixr-bwd-data-conv-stride2-dilation2.mlir | 1 + .../pr-e2e/mixr-bwd-data-conv-stride32.mlir | 1 + .../fusion/pr-e2e/mixr-bwd-data-conv.mlir | 1 + .../fusion/pr-e2e/mixr-bwd-data-conv3d.mlir | 2 + 11 files changed, 249 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index a469deb67067..fa4a48b60152 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -65,11 +65,18 @@ def LinalgConv_2D : I32EnumAttrCase<"Conv2dNgchwGkchw", 1, "conv2d_ngchw_gkchw">; def LinalgConv_3D : I32EnumAttrCase<"Conv3dNgchwdGkchwd", 2, "conv3d_ngchwd_gkchwd">; +def LinalgBwdConv1D + : I32EnumAttrCase<"Conv1dBWDNgchGckh", 3, "convbwd1d_ngch_gckh">; +def LinalgBwdConv2D + : I32EnumAttrCase<"Conv2dBWDNgchwGckhw", 4, "convbwd2d_ngchw_gckhw">; +def LinalgBwdConv3D + : I32EnumAttrCase<"Conv3dBWDNgchwdGckhwd", 5, "convbwd3d_ngchd_gckhw">; def LinalgConvType : Rock_I32Enum<"LinalgConvType", "Hints for the linalg.generic convolution ops used by linalg-to-rock lowering", - [LinalgConv_1D, LinalgConv_2D, LinalgConv_3D]>; + [LinalgConv_1D, LinalgConv_2D, LinalgConv_3D, + LinalgBwdConv1D, LinalgBwdConv2D, LinalgBwdConv3D]>; def LinalgConvTypeAttr : EnumAttr; diff --git a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp 
b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp index cb972bd39b83..55408230940d 100644 --- a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp +++ b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp @@ -107,6 +107,9 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite( "input shape is non standard or broadcast; cannot convert this shape"); } +//===----------------------------------------------------------------------===// +// Forward and Backward convolution converter +//===----------------------------------------------------------------------===// namespace { struct ConvConverter final : public OpConversionPattern { @@ -124,6 +127,21 @@ struct ConvConverter final migraphx::ConvolutionOp op, Value input, Value filter) const; }; + +struct BackwardConvConverter final + : public OpConversionPattern { + using OpConversionPattern< + migraphx::ConvolutionBwdDataOp>::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using OpAdaptor = + typename OpConversionPattern::OpAdaptor; + + LogicalResult + matchAndRewrite(migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +private: + LogicalResult emitBackwardConv(ConversionPatternRewriter& rewriter, migraphx::ConvolutionBwdDataOp op, Value input, Value filter) const; +}; } // namespace // Nice helper function for the linalg.generic op region @@ -137,18 +155,19 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) { } /// Emit convolution attributes on the newly created operation. 
-static void emitConvAttributes(migraphx::ConvolutionOp op, Value convOp, - Attribute strides, Attribute dilation, - Attribute pad, Attribute convOpName) { +static void emitConvAttributes(Value convOp, Attribute strides, + Attribute dilation, Attribute pad, + Attribute perfConfig, Attribute groupAttr, + Attribute convOpName) { Operation *newOp = convOp.getDefiningOp(); newOp->setAttr("pad", pad); - newOp->setAttr("group", op.getGroupAttr()); + newOp->setAttr("group", groupAttr); newOp->setAttr("stride", strides); newOp->setAttr("dilation", dilation); // Convert optional attributes - if (auto attr = (*op).template getAttrOfType("perf_config")) - newOp->setAttr("perf_config", attr); + if (perfConfig) + newOp->setAttr("perf_config", perfConfig); newOp->setAttr("conv_op", convOpName); } @@ -239,6 +258,107 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc, .getResult(0); } +/// Emit a grouped backward (transposed) convolution of any spatial rank. +/// Input shape: (batch, group, channel, spatial...), +/// filter shape: (group, filter, channel, kernel_spatial...) +/// +/// The loop structure mirrors the forward convolution, but with the +/// stride/dilation affine expression on the *output* indexing map: +/// +/// clang-format off +/// for n in batch: +/// for g in group: +/// for ih_0 in input_spatial_0: +/// for ih_1 in input_spatial_1: +/// // ... +/// for ih_{dim-1} in input_spatial_{dim-1}: +/// for f in filters: +/// reduction starts here +/// for c in channels: // reduction +/// for kh_0 in kernel_spatial_0: // reduction +/// for kh_1 in kernel_spatial_1: // reduction +/// // ... +/// result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] += +/// input[n,g,c,ih_0,...] * filter[c,g,f,kh_0,...] 
+/// clang-format on +static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, + Location loc, + RankedTensorType resultType, Value input, + Value filter, Value zero, + ArrayAttr strides, + ArrayAttr dilation) { + MLIRContext *ctx = rewriter.getContext(); + int64_t dim = cast(input.getType()).getRank() - 3; + SmallVector strideVals; + SmallVector dilationVals; + llvm::transform(strides.getValue(), std::back_inserter(strideVals), + [](Attribute attr) { + return cast(attr).getInt(); + }); + llvm::transform(dilation.getValue(), std::back_inserter(dilationVals), + [](Attribute attr) { + return cast(attr).getInt(); + }); + + // Iteration domain layout (mirrors emitGroupedConv): + // parallel: batch, group, filter, ih_0 .. ih_{dim-1} + // reduction: channel, kh_0 .. kh_{dim-1} + int64_t totalDims = 4 + 2 * dim; + SmallVector d; + for (int64_t i = 0; i < totalDims; ++i) + d.push_back(getAffineDimExpr(i, ctx)); + + AffineExpr batch = d[0], group = d[1], filterExpr = d[dim+2]; + AffineExpr channel = d[3 + dim]; + + SmallVector inputExprs = {batch, group, channel}; + for (int64_t i = 0; i < dim; ++i) + inputExprs.push_back(d[2 + i]); + + SmallVector filterExprs = {group, channel, filterExpr}; + for (int64_t i = 0; i < dim; ++i) + filterExprs.push_back(d[4 + dim + i]); + + SmallVector outputExprs = {batch, group, filterExpr}; + for (int64_t i = 0; i < dim; ++i) + outputExprs.push_back(d[2 + i] * strideVals[i] + + d[4 + dim + i] * dilationVals[i]); + + SmallVector indexingMaps = { + AffineMap::get(totalDims, /*symbolCount=*/0, inputExprs, ctx), + AffineMap::get(totalDims, /*symbolCount=*/0, filterExprs, ctx), + AffineMap::get(totalDims, /*symbolCount=*/0, outputExprs, ctx)}; + + SmallVector iteratorTypes(3 + dim, + utils::IteratorType::parallel); + iteratorTypes.append(1 + dim, utils::IteratorType::reduction); + + auto result = linalg::GenericOp::create(rewriter, loc, resultType, + ValueRange{input, filter}, zero, + indexingMaps, iteratorTypes, 
convBodyBuilder) + .getResult(0); + return result; +} + +/// Given the collapsed NF* result type and the group count, return the +/// expanded NGF* result type for the grouped linalg convolution. +static RankedTensorType +expandResultForGroupedConv(RankedTensorType resultType, int64_t group) { + ArrayRef resultShape = resultType.getShape(); + int64_t n = resultType.getDimSize(0); + int64_t newF = resultType.getDimSize(1) / group; + assert(resultType.getDimSize(1) % group == 0 && + "output channel must be divisible by group"); + + SmallVector newShape; + newShape.push_back(n); + newShape.push_back(group); + newShape.push_back(newF); + newShape.insert(newShape.end(), std::next(resultShape.begin(), 2), + resultShape.end()); + return RankedTensorType::get(newShape, resultType.getElementType()); +} + LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, migraphx::ConvolutionOp op, Value input, Value filter) const { @@ -280,8 +400,8 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter, zero, strides, dilation); - emitConvAttributes(op, result, strides, dilation, - op.getPaddingAttr(), + emitConvAttributes(result, strides, dilation, op.getPaddingAttr(), + op->getAttr("perf_config"), op.getGroupAttr(), resultConvOpName); // we must reshape the operand to what the type converter expects @@ -444,6 +564,108 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor, return emitConv(rewriter, op, input, filter); } +LogicalResult +BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter, + migraphx::ConvolutionBwdDataOp op, + Value input, Value filter) const { + Location loc = op.getLoc(); + int64_t group = op.getGroupAttr().getInt(); + int64_t spatialDim = cast(input.getType()).getRank() - + 3; // exclude batch (N), group (G), channel (C) + assert(spatialDim >= 1 && spatialDim <= 3 && + "this should be checked at 
matchAndRewrite"); + // To get the result shape, we must first add the padding + ArrayRef padding = op.getPaddingAttr().getValue(); + RankedTensorType originalResult = cast(getTypeConverter()->convertType(op.getResult())); + SmallVector resultShape (originalResult.getShape()); + SmallVector lowPads; + SmallVector highPads; + for(int64_t i = 0; i(padding[i]).getInt(); + int64_t highPad = cast(padding[i+spatialDim]).getInt(); + resultShape[2+i] += lowPad + highPad; + lowPads.push_back(lowPad); + highPads.push_back(highPad); + } + RankedTensorType resultType = RankedTensorType::get(resultShape,originalResult.getElementType()); + auto newResultType = expandResultForGroupedConv(resultType, group); + Value zero = arith::ConstantOp::create(rewriter, loc, newResultType, + rewriter.getZeroAttr(newResultType)); + + ArrayAttr strides = op.getStride(); + ArrayAttr dilation = op.getDilation(); + + Value result = emitGroupedBackwardConv(rewriter, loc, newResultType, input, + filter, zero, strides, dilation); + rock::LinalgConvType convType = + (spatialDim == 3) ? rock::LinalgConvType::Conv3dBWDNgchwdGckhwd + : (spatialDim == 2) ? 
rock::LinalgConvType::Conv2dBWDNgchwGckhw + : rock::LinalgConvType::Conv1dBWDNgchGckh; + emitConvAttributes( + result, strides, dilation, op.getPaddingAttr(), + op->getAttr("perf_config"), op.getGroupAttr(), + rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType)); + + // Collapse result from NGF* back to NF* + SmallVector reassociation{{0}, {1, 2}}; + llvm::for_each(llvm::seq(3, spatialDim + 3), + [&](int64_t index) { reassociation.push_back({index}); }); + auto finalResult = + tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation).getResult(); + + bool hasPadding = llvm::any_of(lowPads, [](int64_t p) { return p != 0; }) || + llvm::any_of(highPads, [](int64_t p) { return p != 0; }); + if (hasPadding) { + int64_t rank = originalResult.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(rank, rewriter.getIndexAttr(1)); + for (int64_t i = 0; i < rank; ++i) + sizes.push_back(rewriter.getIndexAttr(originalResult.getDimSize(i))); + for (int64_t i = 0; i < spatialDim; ++i) + offsets[2 + i] = rewriter.getIndexAttr(lowPads[i]); + finalResult = tensor::ExtractSliceOp::create( + rewriter, loc, originalResult, finalResult, offsets, sizes, + strides) + .getResult(); + } + + rewriter.replaceOp(op, finalResult); + return success(); +} + +LogicalResult BackwardConvConverter::matchAndRewrite( + migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Backward convolution lowering is similar to forward convolution and is lowered in two steps: + // 1. Expand the channel dimension into (group, channel_per_group), + // introducing + // a group dimension G. Input becomes NGC* (e.g. NGCL, NGCHW, NGCDHW) and + // filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW), matching the group attr. + // 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the + // result back to the original NFHW/NFDHW shape for the type converter.
+ Location loc = op.getLoc(); + Value input = adaptor.getInput(); + Value filter = adaptor.getFilter(); + RankedTensorType inputType = cast(input.getType()); + int64_t dim = inputType.getRank() - 2; + int64_t group = op.getGroupAttr().getInt(); + + if (dim > 3 || dim < 1) { + return op.emitError(Twine(dim) + "D conv is not supported for now"); + } + + if (inputType.getElementType() != op.getFilter().getType().getElementType() || + inputType.getElementType() != op.getResult().getType().getElementType()) { + return op.emitError( + "type casting between operands and result is unsupported for now"); + } + + input = expandGroupDim(rewriter, loc, input, /*isFilter=*/false, group, dim); + filter = expandGroupDim(rewriter, loc, filter, /*isFilter=*/true, group, dim); + + return emitBackwardConv(rewriter, op, input, filter); +} // TODO: add support for scaled gemms, and migraphx::DeQuantizeLinearConverter //===----------------------------------------------------------------------===// // Base kernels (gemm) @@ -1134,7 +1356,8 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns( MultiBroadcastConverter, LiteralConverter, ReshapeConverter, BooleanElementwiseConverter, BooleanElementwiseConverter, ClipConverter, - TransposeConverter, ConvConverter>(converter, patterns.getContext()); + TransposeConverter, ConvConverter, + BackwardConvConverter>(converter, patterns.getContext()); } void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns( diff --git a/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-not-implemented.mlir b/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-not-implemented.mlir index 711a3ad08232..7dfef1b77118 100644 --- a/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-not-implemented.mlir +++ b/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-not-implemented.mlir @@ -48,12 +48,6 @@ func.func @func_quant_convolution(%arg0: !migraphx.shaped<1x1xi8, 1x1>, %arg1: ! 
func.return } -func.func @func_backwards_data_convolution(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) { - // expected-error @+1{{failed to legalize operation 'migraphx.backwards_data_convolution'}} - migraphx.backwards_data_convolution %arg0, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0], stride = [1, 1]}: <1x1xf32, 1x1>, <1x1xf32, 1x1> -> <1x1xf32, 1x1> - func.return -} - func.func @func_batch_norm_inference(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) { // expected-error @+1{{failed to legalize operation 'migraphx.batch_norm_inference'}} migraphx.batch_norm_inference %arg0, %arg1, %arg1, %arg1, %arg1 {bn_mode = 0 : i64, epsilon = 1.0e-5 : f32, momentum = 0.9 : f32}: diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-asymmetric-stride.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-asymmetric-stride.mlir index d70e47cb738e..4563ff8d33cd 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-asymmetric-stride.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-asymmetric-stride.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver 
-kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { func.func @mlir_bwd_data_conv( diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir index dd2a15958aae..c2baa84d2195 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation1-stride2.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel 
-host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { func.func @mlir_bwd_data_conv(%arg0: !migraphx.shaped<1x512x32x32xf32, 524288x1024x32x1>, diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation2-stride1.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation2-stride1.mlir index a93c9f66ea11..ffcf23ccb574 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation2-stride1.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-dilation2-stride1.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver 
-kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { // CHECK: [1 1 1] diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-padding1.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-padding1.mlir index c7fb226f4e2e..a5bf5797f9e4 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-padding1.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-padding1.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets 
%arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { // CHECK: [1 1 1] diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride2-dilation2.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride2-dilation2.mlir index ca9e9cfeb2ed..6cb94fa05c44 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride2-dilation2.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride2-dilation2.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut 
mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { // CHECK: [1 1 1] diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride32.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride32.mlir index 17c2af1d8607..e8c7ae7e585a 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride32.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-stride32.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver 
-host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { // CHECK: [1 1 1] diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv.mlir index f436724c95a0..c28108b9a6f2 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv.mlir @@ -1,4 +1,5 @@ // RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s module { // CHECK: [1 1 1] diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv3d.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv3d.mlir index cae3bad86ea6..2c802cc7d5aa 100644 --- a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv3d.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv3d.mlir @@ -1,4 +1,5 @@ // RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-driver -kernel-pipeline migraphx,highlevel | rocmlir-gen -ph -print-results -rand none - | rocmlir-driver -arch %arch -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=MIXR +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx-linalg,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s 
--check-prefix=LINALG // The CPU lowering pipeline is currently broken for backwards data convolution // ops (lowering tosa.transpose_conv2d). As such, we do not currently have a way @@ -12,6 +13,7 @@ // TODO: We are actually generating a 2D conv with rocmlir-gen here since it does not support generating 3D // This should be a 3D conv once rocmlir-gen supports it. module { + // LINALG: [1 1 1] // MIXR: [1, 2, 3, 2, 1, 2, 4, 6, 4, 2, 3, 6, 9, 6, 3, 2, 4, 6, 4, 2, 1, 2, 3, 2, 1] // GEN: [1 1 1] func.func @mlir_bwd_data_conv( From 2d94bfc046cdeb78590d78f5b844cf449e2772bb Mon Sep 17 00:00:00 2001 From: Vincent Date: Mon, 16 Mar 2026 18:31:49 +0000 Subject: [PATCH 2/2] Address some comments --- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 2 +- .../MIGraphXToLinalg/MIGraphXToLinalg.cpp | 67 +++++++++++-------- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index fa4a48b60152..6964ca6316f8 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -70,7 +70,7 @@ def LinalgBwdConv1D def LinalgBwdConv2D : I32EnumAttrCase<"Conv2dBWDNgchwGckhw", 4, "convbwd2d_ngchw_gckhw">; def LinalgBwdConv3D - : I32EnumAttrCase<"Conv3dBWDNgchwdGckhwd", 5, "convbwd3d_ngchd_gckhw">; + : I32EnumAttrCase<"Conv3dBWDNgchwdGckhwd", 5, "convbwd3d_ngchwd_gckhwd">; def LinalgConvType : Rock_I32Enum<"LinalgConvType", diff --git a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp index 55408230940d..4dcbabf50a93 100644 --- a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp +++ b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp @@ -155,19 +155,19 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) { } /// Emit convolution attributes on the newly created operation. 
-static void emitConvAttributes(Value convOp, Attribute strides, +static void emitConvAttributes(Operation* migraphxOp, Value convOp, Attribute strides, Attribute dilation, Attribute pad, Attribute perfConfig, Attribute groupAttr, Attribute convOpName) { Operation *newOp = convOp.getDefiningOp(); newOp->setAttr("pad", pad); - newOp->setAttr("group", groupAttr); + newOp->setAttr("group", migraphxOp->getAttr("group")); newOp->setAttr("stride", strides); newOp->setAttr("dilation", dilation); // Convert optional attributes - if (perfConfig) - newOp->setAttr("perf_config", perfConfig); + if (migraphxOp->hasAttr("perf_config")) + newOp->setAttr("perf_config", migraphxOp->getAttr("perf_config")); newOp->setAttr("conv_op", convOpName); } @@ -279,7 +279,7 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc, /// for kh_1 in kernel_spatial_1: // reduction /// // ... /// result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] += -/// input[n,g,c,ih_0,...] * filter[c,g,f,kh_0,...] +/// input[n,g,c,ih_0,...] * filter[g,c,f,kh_0,...] /// clang-format on static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, Location loc, @@ -288,7 +288,7 @@ static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, ArrayAttr strides, ArrayAttr dilation) { MLIRContext *ctx = rewriter.getContext(); - int64_t dim = cast(input.getType()).getRank() - 3; + int64_t spatialDim = cast(input.getType()).getRank() - 3; SmallVector strideVals; SmallVector dilationVals; llvm::transform(strides.getValue(), std::back_inserter(strideVals), @@ -301,37 +301,48 @@ static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, }); // Iteration domain layout (mirrors emitGroupedConv): - // parallel: batch, group, filter, ih_0 .. ih_{dim-1} + // parallel: batch, group, ih_0 .. ih_{dim-1}, filter // reduction: channel, kh_0 .. 
 kh_{dim-1} - int64_t totalDims = 4 + 2 * dim; + // See the loop structure from above to see where these constants come from + const int64_t ihStart = 2; + const int64_t filterIdx = ihStart + spatialDim; + const int64_t channelIdx = filterIdx + 1; + const int64_t khStart = channelIdx + 1; + const int64_t totalDims = khStart + spatialDim; + const int64_t numParallel = channelIdx; + SmallVector d; for (int64_t i = 0; i < totalDims; ++i) d.push_back(getAffineDimExpr(i, ctx)); - AffineExpr batch = d[0], group = d[1], filterExpr = d[dim+2]; - AffineExpr channel = d[3 + dim]; + AffineExpr batch = d[0], group = d[1]; + AffineExpr outChannel = d[filterIdx]; + AffineExpr inChannel = d[channelIdx]; - SmallVector inputExprs = {batch, group, channel}; - for (int64_t i = 0; i < dim; ++i) - inputExprs.push_back(d[2 + i]); + SmallVector inputExprs = {batch, group, inChannel}; + for (int64_t i = 0; i < spatialDim; ++i) + inputExprs.push_back(d[ihStart + i]); - SmallVector filterExprs = {group, channel, filterExpr}; - for (int64_t i = 0; i < dim; ++i) - filterExprs.push_back(d[4 + dim + i]); + SmallVector filterExprs = {group, inChannel, outChannel}; + for (int64_t i = 0; i < spatialDim; ++i) + filterExprs.push_back(d[khStart + i]); - SmallVector outputExprs = {batch, group, filterExpr}; - for (int64_t i = 0; i < dim; ++i) - outputExprs.push_back(d[2 + i] * strideVals[i] + - d[4 + dim + i] * dilationVals[i]); + SmallVector outputExprs = {batch, group, outChannel}; + for (int64_t i = 0; i < spatialDim; ++i) { + AffineExpr ih_i = d[ihStart + i]; + AffineExpr kh_i = d[khStart + i]; + outputExprs.push_back(ih_i * strideVals[i] + kh_i * dilationVals[i]); + } SmallVector indexingMaps = { AffineMap::get(totalDims, /*symbolCount=*/0, inputExprs, ctx), AffineMap::get(totalDims, /*symbolCount=*/0, filterExprs, ctx), AffineMap::get(totalDims, /*symbolCount=*/0, outputExprs, ctx)}; - SmallVector iteratorTypes(3 + dim, + SmallVector iteratorTypes(numParallel, utils::IteratorType::parallel); - 
iteratorTypes.append(1 + dim, utils::IteratorType::reduction); + iteratorTypes.append(totalDims - numParallel, + utils::IteratorType::reduction); auto result = linalg::GenericOp::create(rewriter, loc, resultType, ValueRange{input, filter}, zero, @@ -341,7 +352,7 @@ static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, } /// Given the collapsed NF* result type and the group count, return the -/// expanded NGF* result type for the grouped linalg convolution. +/// expanded NGK* result type for the grouped linalg convolution. static RankedTensorType expandResultForGroupedConv(RankedTensorType resultType, int64_t group) { ArrayRef resultShape = resultType.getShape(); @@ -400,7 +411,7 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter, zero, strides, dilation); - emitConvAttributes(result, strides, dilation, op.getPaddingAttr(), + emitConvAttributes(op,result, strides, dilation, op.getPaddingAttr(), op->getAttr("perf_config"), op.getGroupAttr(), resultConvOpName); @@ -602,11 +613,11 @@ BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter, : (spatialDim == 2) ? 
rock::LinalgConvType::Conv2dBWDNgchwGckhw : rock::LinalgConvType::Conv1dBWDNgchGckh; emitConvAttributes( - result, strides, dilation, op.getPaddingAttr(), + op,result, strides, dilation, op.getPaddingAttr(), op->getAttr("perf_config"), op.getGroupAttr(), rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType)); - // Collapse result from NGF* back to NF* + // Collapse result from NGK* back to NK* SmallVector reassociation{{0}, {1, 2}}; llvm::for_each(llvm::seq(3, spatialDim + 3), [&](int64_t index) { reassociation.push_back({index}); }); @@ -637,12 +648,12 @@ BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter, LogicalResult BackwardConvConverter::matchAndRewrite( migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // Backward convolution lowering is similar to foward convolution and is lowered in three steps: + // Backward convolution lowering is similar to forward convolution and is lowered in three steps: // 1. Expand the channel dimension into (group, channel_per_group), // introducing // a group dimension G. Input becomes NGC* (e.g. NGCL, NGCHW, NGCDHW) and // filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW), matching the group attr. - // 2.. Emit the grouped linalg convolution (1D/2D/3D), then collapse the + // 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the // result back to the original NFHW/NFDHW shape for the type converter. Location loc = op.getLoc(); Value input = adaptor.getInput();