-
Notifications
You must be signed in to change notification settings - Fork 55
[AIROCMLIR-445] Lower migraphx.backwards_data_convolution
#2256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pr-template-migraphx-to-linalg-conv-2
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -107,6 +107,9 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite( | |
| "input shape is non standard or broadcast; cannot convert this shape"); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Forward and Backward convolution converter | ||
| //===----------------------------------------------------------------------===// | ||
| namespace { | ||
| struct ConvConverter final | ||
| : public OpConversionPattern<migraphx::ConvolutionOp> { | ||
|
|
@@ -124,6 +127,21 @@ struct ConvConverter final | |
| migraphx::ConvolutionOp op, Value input, | ||
| Value filter) const; | ||
| }; | ||
|
|
||
| struct BackwardConvConverter final | ||
| : public OpConversionPattern<migraphx::ConvolutionBwdDataOp> { | ||
| using OpConversionPattern< | ||
| migraphx::ConvolutionBwdDataOp>::OpConversionPattern; | ||
| using OpConversionPattern<migraphx::ConvolutionBwdDataOp>::getTypeConverter; | ||
| using OpAdaptor = | ||
| typename OpConversionPattern<migraphx::ConvolutionBwdDataOp>::OpAdaptor; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override; | ||
| private: | ||
| LogicalResult emitBackwardConv(ConversionPatternRewriter& rewriter, migraphx::ConvolutionBwdDataOp op, Value input, Value filter) const; | ||
| }; | ||
| } // namespace | ||
|
|
||
| // Nice helper function for the linalg.generic op region | ||
|
|
@@ -137,18 +155,19 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) { | |
| } | ||
|
|
||
| /// Emit convolution attributes on the newly created operation. | ||
| static void emitConvAttributes(migraphx::ConvolutionOp op, Value convOp, | ||
| Attribute strides, Attribute dilation, | ||
| Attribute pad, Attribute convOpName) { | ||
| static void emitConvAttributes(Operation* migraphxOp, Value convOp, Attribute strides, | ||
| Attribute dilation, Attribute pad, | ||
| Attribute perfConfig, Attribute groupAttr, | ||
| Attribute convOpName) { | ||
| Operation *newOp = convOp.getDefiningOp(); | ||
| newOp->setAttr("pad", pad); | ||
| newOp->setAttr("group", op.getGroupAttr()); | ||
| newOp->setAttr("group", migraphxOp->getAttr("group")); | ||
| newOp->setAttr("stride", strides); | ||
| newOp->setAttr("dilation", dilation); | ||
|
|
||
| // Convert optional attributes | ||
| if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config")) | ||
| newOp->setAttr("perf_config", attr); | ||
| if (migraphxOp->hasAttr("perf_config")) | ||
| newOp->setAttr("perf_config", migraphxOp->getAttr("perf_config")); | ||
| newOp->setAttr("conv_op", convOpName); | ||
| } | ||
|
|
||
|
|
@@ -239,6 +258,118 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc, | |
| .getResult(0); | ||
| } | ||
|
|
||
| /// Emit a grouped backward (transposed) convolution of any spatial rank. | ||
| /// Input shape: (batch, group, channel, spatial...), | ||
| /// filter shape: (group, filter, channel, kernel_spatial...) | ||
| /// | ||
| /// The loop structure mirrors the forward convolution, but with the | ||
| /// stride/dilation affine expression on the *output* indexing map: | ||
| /// | ||
| /// clang-format off | ||
| /// for n in batch: | ||
| /// for g in group: | ||
| /// for ih_0 in input_spatial_0: | ||
| /// for ih_1 in input_spatial_1: | ||
| /// // ... | ||
| /// for ih_{dim-1} in input_spatial_{dim-1}: | ||
| /// for f in filters: | ||
| /// reduction starts here | ||
| /// for c in channels: // reduction | ||
| /// for kh_0 in kernel_spatial_0: // reduction | ||
| /// for kh_1 in kernel_spatial_1: // reduction | ||
| /// // ... | ||
| /// result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] += | ||
| /// input[n,g,c,ih_0,...] * filter[g,c,f,kh_0,...] | ||
| /// clang-format on | ||
| static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, | ||
| Location loc, | ||
| RankedTensorType resultType, Value input, | ||
| Value filter, Value zero, | ||
| ArrayAttr strides, | ||
| ArrayAttr dilation) { | ||
| MLIRContext *ctx = rewriter.getContext(); | ||
| int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() - 3; | ||
| SmallVector<int64_t, 4> strideVals; | ||
| SmallVector<int64_t, 4> dilationVals; | ||
| llvm::transform(strides.getValue(), std::back_inserter(strideVals), | ||
| [](Attribute attr) { | ||
| return cast<IntegerAttr>(attr).getInt(); | ||
| }); | ||
| llvm::transform(dilation.getValue(), std::back_inserter(dilationVals), | ||
| [](Attribute attr) { | ||
| return cast<IntegerAttr>(attr).getInt(); | ||
| }); | ||
|
|
||
| // Iteration domain layout (mirrors emitGroupedConv): | ||
| // parallel: batch, group, ih_0 .. ih_{dim-1}, filter | ||
| // reduction: channel, kh_0 .. kh_{dim-1} | ||
| // See the loop structure from above to see where these constants come fron | ||
| const int64_t ihStart = 2; | ||
| const int64_t filterIdx = ihStart + spatialDim; | ||
| const int64_t channelIdx = filterIdx + 1; | ||
| const int64_t khStart = channelIdx + 1; | ||
| const int64_t totalDims = khStart + spatialDim; | ||
| const int64_t numParallel = channelIdx; | ||
|
|
||
| SmallVector<AffineExpr> d; | ||
| for (int64_t i = 0; i < totalDims; ++i) | ||
| d.push_back(getAffineDimExpr(i, ctx)); | ||
|
|
||
| AffineExpr batch = d[0], group = d[1]; | ||
| AffineExpr outChannel = d[filterIdx]; | ||
| AffineExpr inChannel = d[channelIdx]; | ||
|
|
||
| SmallVector<AffineExpr> inputExprs = {batch, group, inChannel}; | ||
| for (int64_t i = 0; i < spatialDim; ++i) | ||
| inputExprs.push_back(d[ihStart + i]); | ||
|
|
||
| SmallVector<AffineExpr> filterExprs = {group, inChannel, outChannel}; | ||
| for (int64_t i = 0; i < spatialDim; ++i) | ||
| filterExprs.push_back(d[khStart + i]); | ||
|
|
||
| SmallVector<AffineExpr> outputExprs = {batch, group, outChannel}; | ||
| for (int64_t i = 0; i < spatialDim; ++i) { | ||
| AffineExpr ih_i = d[ihStart + i]; | ||
| AffineExpr kh_i = d[khStart + i]; | ||
| outputExprs.push_back(ih_i * strideVals[i] + kh_i * dilationVals[i]); | ||
| } | ||
|
|
||
| SmallVector<AffineMap> indexingMaps = { | ||
| AffineMap::get(totalDims, /*symbolCount=*/0, inputExprs, ctx), | ||
| AffineMap::get(totalDims, /*symbolCount=*/0, filterExprs, ctx), | ||
| AffineMap::get(totalDims, /*symbolCount=*/0, outputExprs, ctx)}; | ||
|
|
||
| SmallVector<utils::IteratorType> iteratorTypes(numParallel, | ||
| utils::IteratorType::parallel); | ||
| iteratorTypes.append(totalDims - numParallel, | ||
| utils::IteratorType::reduction); | ||
|
|
||
| auto result = linalg::GenericOp::create(rewriter, loc, resultType, | ||
| ValueRange{input, filter}, zero, | ||
| indexingMaps, iteratorTypes, convBodyBuilder) | ||
| .getResult(0); | ||
| return result; | ||
| } | ||
|
|
||
| /// Given the collapsed NF* result type and the group count, return the | ||
| /// expanded NGK* result type for the grouped linalg convolution. | ||
| static RankedTensorType | ||
| expandResultForGroupedConv(RankedTensorType resultType, int64_t group) { | ||
| ArrayRef<int64_t> resultShape = resultType.getShape(); | ||
| int64_t n = resultType.getDimSize(0); | ||
| int64_t newF = resultType.getDimSize(1) / group; | ||
| assert(resultType.getDimSize(1) % group == 0 && | ||
| "output channel must be divisible by group"); | ||
|
|
||
| SmallVector<int64_t, 4> newShape; | ||
| newShape.push_back(n); | ||
| newShape.push_back(group); | ||
| newShape.push_back(newF); | ||
| newShape.insert(newShape.end(), std::next(resultShape.begin(), 2), | ||
| resultShape.end()); | ||
| return RankedTensorType::get(newShape, resultType.getElementType()); | ||
| } | ||
|
|
||
| LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, | ||
| migraphx::ConvolutionOp op, Value input, | ||
| Value filter) const { | ||
|
|
@@ -280,8 +411,8 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, | |
| Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter, | ||
| zero, strides, dilation); | ||
|
|
||
| emitConvAttributes(op, result, strides, dilation, | ||
| op.getPaddingAttr(), | ||
| emitConvAttributes(op,result, strides, dilation, op.getPaddingAttr(), | ||
| op->getAttr("perf_config"), op.getGroupAttr(), | ||
| resultConvOpName); | ||
|
|
||
| // we must reshape the operand to what the type converter expects | ||
|
|
@@ -444,6 +575,108 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor, | |
| return emitConv(rewriter, op, input, filter); | ||
| } | ||
|
|
||
| LogicalResult | ||
| BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter, | ||
| migraphx::ConvolutionBwdDataOp op, | ||
| Value input, Value filter) const { | ||
| Location loc = op.getLoc(); | ||
| int64_t group = op.getGroupAttr().getInt(); | ||
| int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() - | ||
| 3; // exclude batch (N), group (G), channel (C) | ||
| assert(spatialDim >= 1 && spatialDim <= 3 && | ||
| "this should be checked at matchAndRewrite"); | ||
| // To get the result shape, we must first add the padding | ||
| ArrayRef<Attribute> padding = op.getPaddingAttr().getValue(); | ||
| RankedTensorType originalResult = cast<RankedTensorType>(getTypeConverter()->convertType(op.getResult())); | ||
| SmallVector<int64_t, 4> resultShape (originalResult.getShape()); | ||
| SmallVector<int64_t, 4> lowPads; | ||
| SmallVector<int64_t, 4> highPads; | ||
| for(int64_t i = 0; i<spatialDim; ++i){ | ||
| int64_t lowPad = cast<IntegerAttr>(padding[i]).getInt(); | ||
| int64_t highPad = cast<IntegerAttr>(padding[i+spatialDim]).getInt(); | ||
| resultShape[2+i] += lowPad + highPad; | ||
| lowPads.push_back(lowPad); | ||
| highPads.push_back(highPad); | ||
| } | ||
| RankedTensorType resultType = RankedTensorType::get(resultShape,originalResult.getElementType()); | ||
| auto newResultType = expandResultForGroupedConv(resultType, group); | ||
| Value zero = arith::ConstantOp::create(rewriter, loc, newResultType, | ||
| rewriter.getZeroAttr(newResultType)); | ||
|
|
||
| ArrayAttr strides = op.getStride(); | ||
| ArrayAttr dilation = op.getDilation(); | ||
|
|
||
| Value result = emitGroupedBackwardConv(rewriter, loc, newResultType, input, | ||
| filter, zero, strides, dilation); | ||
| rock::LinalgConvType convType = | ||
| (spatialDim == 3) ? rock::LinalgConvType::Conv3dBWDNgchwdGckhwd | ||
| : (spatialDim == 2) ? rock::LinalgConvType::Conv2dBWDNgchwGckhw | ||
| : rock::LinalgConvType::Conv1dBWDNgchGckh; | ||
| emitConvAttributes( | ||
| op,result, strides, dilation, op.getPaddingAttr(), | ||
| op->getAttr("perf_config"), op.getGroupAttr(), | ||
| rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType)); | ||
|
|
||
| // Collapse result from NGK* back to NK* | ||
| SmallVector<ReassociationIndices, 4> reassociation{{0}, {1, 2}}; | ||
| llvm::for_each(llvm::seq<int64_t>(3, spatialDim + 3), | ||
| [&](int64_t index) { reassociation.push_back({index}); }); | ||
| auto finalResult = | ||
| tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation).getResult(); | ||
|
|
||
| bool hasPadding = llvm::any_of(lowPads, [](int64_t p) { return p != 0; }) || | ||
| llvm::any_of(highPads, [](int64_t p) { return p != 0; }); | ||
| if (hasPadding) { | ||
| int64_t rank = originalResult.getRank(); | ||
| SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); | ||
| SmallVector<OpFoldResult> sizes; | ||
| SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); | ||
| for (int64_t i = 0; i < rank; ++i) | ||
| sizes.push_back(rewriter.getIndexAttr(originalResult.getDimSize(i))); | ||
| for (int64_t i = 0; i < spatialDim; ++i) | ||
| offsets[2 + i] = rewriter.getIndexAttr(lowPads[i]); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember that conv_bwd can probably have negative padding values. Can you check and would this work for that case ?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this would work if we are having a negative padding. This is because What is the semantics for negative padding? Is it just the same as the as we apply a output padding with with the magnitude being those negative values? |
||
| finalResult = tensor::ExtractSliceOp::create( | ||
| rewriter, loc, originalResult, finalResult, offsets, sizes, | ||
| strides) | ||
| .getResult(); | ||
| } | ||
|
|
||
| rewriter.replaceOp(op, finalResult); | ||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult BackwardConvConverter::matchAndRewrite( | ||
| migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const { | ||
| // Backward convolution lowering is similar to forward convolution and is lowered in three steps: | ||
| // 1. Expand the channel dimension into (group, channel_per_group), | ||
| // introducing | ||
| // a group dimension G. Input becomes NGC* (e.g. NGCL, NGCHW, NGCDHW) and | ||
| // filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW), matching the group attr. | ||
| // 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the | ||
| // result back to the original NFHW/NFDHW shape for the type converter. | ||
| Location loc = op.getLoc(); | ||
| Value input = adaptor.getInput(); | ||
| Value filter = adaptor.getFilter(); | ||
| RankedTensorType inputType = cast<RankedTensorType>(input.getType()); | ||
| int64_t dim = inputType.getRank() - 2; | ||
| int64_t group = op.getGroupAttr().getInt(); | ||
|
|
||
| if (dim > 3 || dim < 1) { | ||
| return op.emitError(Twine(dim) + "D conv is not supported for now"); | ||
| } | ||
|
|
||
| if (inputType.getElementType() != op.getFilter().getType().getElementType() || | ||
| inputType.getElementType() != op.getResult().getType().getElementType()) { | ||
| return op.emitError( | ||
| "type casting between operands and result is unsupported for now"); | ||
| } | ||
|
|
||
| input = expandGroupDim(rewriter, loc, input, /*isFilter=*/false, group, dim); | ||
| filter = expandGroupDim(rewriter, loc, filter, /*isFilter=*/true, group, dim); | ||
|
|
||
| return emitBackwardConv(rewriter, op, input, filter); | ||
| } | ||
| // TODO: add support for scaled gemms, and migraphx::DeQuantizeLinearConverter | ||
| //===----------------------------------------------------------------------===// | ||
| // Base kernels (gemm) | ||
|
|
@@ -1134,7 +1367,8 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns( | |
| MultiBroadcastConverter, LiteralConverter, ReshapeConverter, | ||
| BooleanElementwiseConverter<migraphx::Greater>, | ||
| BooleanElementwiseConverter<migraphx::Equal>, ClipConverter, | ||
| TransposeConverter, ConvConverter>(converter, patterns.getContext()); | ||
| TransposeConverter, ConvConverter, | ||
| BackwardConvConverter>(converter, patterns.getContext()); | ||
| } | ||
|
|
||
| void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.