diff --git a/CMakeLists.txt b/CMakeLists.txt index 693c96bb..68733a52 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ if (TRITON_BUILD_PYTHON_MODULE) LinalgExtTransforms TritonExtTransforms + LinalgExtAnalysis LinalgToLinked LinkedToHIVM diff --git a/backend/compiler.py b/backend/compiler.py index 804bdfdf..2a429b0d 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -131,6 +131,7 @@ def __init__(self, target: str) -> None: self.binary_ext = "mcfatbin" elif self.driver.target == "ascend": self.binary_ext = "npubin" + self.capability = target.arch else: raise RuntimeError(f"Target '{self.target_type}' is not supported.") @@ -249,7 +250,7 @@ def add_stages(self, stages, options, language=None): ) stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile( - src, metadata, options + src, metadata, options, self.capability ) ) else: @@ -264,7 +265,7 @@ def add_stages(self, stages, options, language=None): ) stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile( - src, metadata, options + src, metadata, options, self.capability ) ) else: diff --git a/backend/npu.py b/backend/npu.py index fc39baac..d93be3d1 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -491,6 +491,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) + dicp_triton.passes.linked_npu.add_npu_unroll_pipeline(pm) pm.run(mod) # TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。 @@ -683,7 +684,7 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): return linalg, metadata -def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): +def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt, capability): linalg, metadata = _parse_linalg_metadata(linalg, metadata) with tempfile.TemporaryDirectory() as 
tmpdir: ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") @@ -706,6 +707,8 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): _compile_option_list += [ f"--enable-auto-multi-buffer={multibuffer}", ] + if capability: + _compile_option_list += [f"--target={capability}"] if _is_ascend_sanitizer_enabled(): _compile_option_list += ["--enable-sanitizer=true"] diff --git a/compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h b/compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h new file mode 100644 index 00000000..cb16f301 --- /dev/null +++ b/compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h @@ -0,0 +1,166 @@ +#ifndef DICP_DIALECT_LINALGEXT_TRANSFORMS_DIMANALYZER_H +#define DICP_DIALECT_LINALGEXT_TRANSFORMS_DIMANALYZER_H + +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" + +#include +#include +#include + +namespace mlir { +namespace dicp { + +/// Classification of a dimension's role in the computation graph. +/// This helps determine if a dimension is safe to tile or parallelize. +enum class DimKind { + Unknown, // No specific property inferred yet. + Parallel, // Dimension implies independent iterations (safe to tile). + Reduction, // Dimension is collapsed/reduced (requires accumulation). + Broadcast, // Dimension is replicated (data invariant along this axis). + Complex // Dimension undergoes complex transformation (e.g., non-affine + // reshape). +}; + +std::string toString(DimKind k); + +/// Disjoint Set Union (DSU) for tracking dimension equivalence and properties. +/// +/// This class implements a Disjoint Set data structure (Union-Find) +/// specifically designed for Tensor/MemRef dimensions. 
It serves two main +/// purposes: +/// 1. **Equivalence Tracking**: Determines which dimensions across different +/// values +/// represent the same logical axis (e.g., the 'N' dimension in a Matmul +/// propagating through element-wise adds). +/// 2. **Property Propagation**: Merges semantic properties (DimKind) when +/// dimensions +/// are unified. For example, if a dimension is used as a Reduction iterator +/// in one operation, that property propagates to all equivalent dimensions +/// in the set. +class DimensionDisjointSet { +public: + explicit DimensionDisjointSet(size_t size = 0) { resize(size); } + + /// Allocates `n` new dimension IDs in the set. + /// \return The ID of the first allocated dimension. + int64_t allocate(size_t n = 1); + + /// Finds the representative (root) ID for the set containing dimension `i`. + /// Implements path compression for amortized constant time lookups. + int64_t find(int64_t i); + + /// Merges the sets containing dimensions `i` and `j`. + /// This also merges the `DimKind` properties of both roots using + /// `mergeKinds`. + void unionSets(int64_t i, int64_t j); + + /// Updates the DimKind property for the set containing dimension `i`. + /// The new kind is merged with the existing kind to ensure safety (e.g., + /// Reduction is sticky). + void setKind(int64_t i, DimKind k); + + /// Retrieves the DimKind property of the set containing dimension `i`. + DimKind getKind(int64_t i); + +private: + /// Resizes the internal storage to accommodate `n` dimensions. + void resize(size_t n); + + /// Defines the logic for combining two dimension kinds. + /// Hierarchy of "stickiness": Complex > Reduction > Broadcast/Parallel. + DimKind mergeKinds(DimKind a, DimKind b); + + std::vector parent; // Parent pointers for DSU. + std::vector kind; // Properties associated with each root. +}; + +/// DimAnalyzer: +/// Analyzes a specific execution stage (StageInfo) to determine tiling +/// strategies. 
+/// +/// The analyzer constructs a constraint graph where nodes are tensor dimensions +/// and edges represent data flow relationships. It uses a Breadth-First Search +/// (BFS) approach to traverse operations and propagate dimension IDs. +/// +/// Algorithm Overview: +/// 1. **Initialization**: Seeds the analysis with stage inputs (operands +/// defined outside the stage). +/// 2. **BFS Propagation**: Traverses the def-use chains. For each operation, it +/// uses specific handlers (e.g., processMatmulOp) to bind input dimensions to +/// output dimensions. +/// 3. **Anchor Heuristic**: Identifies the "Anchor" operation (typically the +/// final LinalgOp) to interpret the resulting loops. +/// 4. **Tiling Selection**: Checks the properties of the Anchor's loops in the +/// DSU to recommend outermost parallel loops for tiling. +class DimAnalyzer { +public: + explicit DimAnalyzer(const StageInfo &stage); + + /// Analyzes the stage operations and returns indices of loops recommended for + /// tiling. The indices correspond to the loop nest of the "Anchor" operation. + SmallVector analyzeAndGetTilingDims(); + +private: + const StageInfo &stage_; + // Quick lookup for ops belonging to this stage. + DenseSet stageOps_; + DimensionDisjointSet dsu_; + // Maps SSA Value -> [Dim IDs] + DenseMap> valueDims_; + + // BFS State passed to handlers to allow them to enqueue new values. + using BFSQueue = std::queue; + using VisitedSet = DenseSet; + + /// Drives the traversal of the data flow graph. + void processBFS(); + + /// Dispatches the operation to the appropriate handler. + /// \return true if the operation was handled, false otherwise. + bool processOperation(Operation *op, Value current, BFSQueue &q, + VisitedSet &v); + + /// Lazily retrieves or allocates unique IDs for the dimensions of a Value. + std::vector getOrAllocateDims(Value v); + + /// Helper to strictly bind all dimensions of v1 to v2 (1-to-1 mapping). + /// Used for Elementwise, Copy, etc. 
+ void bindDimensions(Value v1, Value v2); + + // --- Op Handlers --- + // Each handler interprets the semantics of the op to union input/output + // dimensions correctly. + + void processElementwise(Operation *op, Value current); + void processMatmulOp(linalg::MatmulOp op); + void processReduceOp(linalg::ReduceOp op); + void processTransposeOp(linalg::TransposeOp op); + void processBroadcastOp(linalg::BroadcastOp op); + void processLinalgOpGeneric(linalg::LinalgOp op); + void processReshapeOp(Operation *op); + void processConcatOp(tensor::ConcatOp op); + void processPadOp(tensor::PadOp op); + void processExtractSliceOp(tensor::ExtractSliceOp op); + void processInsertSliceOp(tensor::InsertSliceOp op); + + // Handlers that may need to continue BFS propagation explicitly + void processMemrefCopyOp(memref::CopyOp op, Value current, BFSQueue &q, + VisitedSet &v); + void processMemrefCastOp(Operation *op); + void processBufferizationToTensor(bufferization::ToTensorOp op); + void processMaterializeOp(bufferization::MaterializeInDestinationOp op, + Value current, BFSQueue &q, VisitedSet &v); +}; + +} // namespace dicp +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h b/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h new file mode 100644 index 00000000..16b508ff --- /dev/null +++ b/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h @@ -0,0 +1,97 @@ +#ifndef DICP_DIALECT_LINALGEXT_TRANSFORMS_STAGEDEPENDENCYANALYZER_H +#define DICP_DIALECT_LINALGEXT_TRANSFORMS_STAGEDEPENDENCYANALYZER_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Operation.h" + +#include "llvm/ADT/SetVector.h" + +#include +#include + +namespace mlir { +namespace dicp { + +/// Represents a single pipeline stage. +/// A stage is a sequence of operations that execute together. 
+/// Synchronization operations (SyncBlockWaitOp) typically delimit stage +/// boundaries. +struct StageInfo { + int id = -1; + std::vector ops; + // IDs of stages that this stage depends on + std::set preds; + // IDs of stages that depend on this stage + std::set succs; + bool hasSync = false; +}; + +// StageDependencyAnalyzer: +// 1. Partitioning a loop body into "stages" based on synchronization primitives +// (hivm::SyncBlockWaitOp). +// 2. Building a dependency graph between these stages considering both: +// - SSA Data Flow (Producer-Consumer relationships). +// - Memory Dependencies (Read-After-Write via +// AliasAnalysis). +// 3. Computing a topological ordering (levels) to detect cycles and determine +// a valid execution schedule. +// 4. Physically reordering the IR operations to match the valid schedule. +// +class StageDependencyAnalyzer { +public: + StageDependencyAnalyzer(scf::ForOp forOp, AliasAnalysis &aliasAnalysis) + : forOp(forOp), aliasAnalysis(aliasAnalysis) {} + + /// Runs the analysis, computes the topological sort, and physically reorders + /// the operations in the loop body. + /// Returns the ordered list of StageInfo on success, or failure if a cycle is + /// detected. + FailureOr> runAndReorder(RewriterBase &rewriter); + +private: + /// Internal node structure for the dependency graph. + struct StageNode { + int id; + StageInfo *stageInfo; + int level = 0; // Topological level (depth) + + // Memory dependencies + llvm::SetVector readValues; + llvm::SetVector writeValues; + + // SSA Value dependencies + llvm::SetVector producedValues; // Values defined in this stage + llvm::SetVector consumedValues; // Values used in this stage + }; + + scf::ForOp forOp; + AliasAnalysis &aliasAnalysis; + std::vector stages; + std::vector nodes; + + /// Scans the loop body to populate the `stages` vector. + void collectStages(); + + /// Collects SSA definitions/uses and Memory Read/Write effects for each + /// stage. 
+ void collectEffects(); + + /// Builds the directed graph edges based on SSA and Memory conflicts. + void buildDependencyGraph(); + + /// Computes the topological level of each node using DFS. + /// Returns failure if a cycle is detected. + LogicalResult computeStageLevels(); + + /// Sorts the `stages` vector based on the computed topological levels. + void reorderStagesLogical(); + + /// Moves the operations in the IR to match the logical order of `stages`. + void materializeScheduleToIR(); +}; + +} // namespace dicp +} // namespace mlir + +#endif // DICP_DIALECT_LINALGEXT_TRANSFORMS_STAGEDEPENDENCYANALYZER_H \ No newline at end of file diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h index 7ae43b6c..9ce9267e 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h @@ -23,6 +23,9 @@ std::unique_ptr> createScalarTo1DTensorPass(); std::unique_ptr> createNormalizeSliceOpsPass(); +std::unique_ptr> +createNPUUnroolPipelinePass(); + #define GEN_PASS_REGISTRATION #include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td index c486210a..bb4404bd 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td @@ -68,4 +68,16 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> { let dependentDialects = ["mlir::tensor::TensorDialect"]; } +def NPUUnroolPipeline : Pass<"npu-unrool-pipeline", "func::FuncOp"> { + let summary = "DLC Pipelines."; + let constructor = "mlir::dicp::LinalgExt::createNPUUnroolPipelinePass()"; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::tensor::TensorDialect", + 
"mlir::bufferization::BufferizationDialect", + "mlir::func::FuncDialect" + ]; +} + #endif diff --git a/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp b/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp index 3f69b378..36957c91 100644 --- a/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp +++ b/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp @@ -1,6 +1,8 @@ #include "dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h" #include "dicp/Utils/Utils.h" + #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "bishengir/Dialect/Annotation/IR/Annotation.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -680,7 +682,7 @@ void replacePtrLoopArguments(Operation *rootOp, op.getLoc(), rewriter.getI32Type(), ValueRange({})) ->getResult(0); if (auto forOp = dyn_cast(op.getOperation())) { - newOp = rewriter.create( + auto createdFor = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), constructOperands(forOp.getInitArgs(), tempVar, mapping), @@ -701,6 +703,13 @@ void replacePtrLoopArguments(Operation *rootOp, yieldOp.getLoc(), constructOperands(yieldOp.getOperands(), tempVar, mapping)); }); + + // propagate Triton-specific loop attribute if present on the old for + if (forOp->hasAttr(triton::kNumStagesAttrName)) + createdFor->setAttr(triton::kNumStagesAttrName, + forOp->getAttr(triton::kNumStagesAttrName)); + + newOp = createdFor; } else if (auto whileOp = dyn_cast(op.getOperation())) { newOp = rewriter.create( whileOp.getLoc(), constructTypes(whileOp->getResultTypes()), diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Analysis/CMakeLists.txt new file mode 100644 index 00000000..def33eb9 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/CMakeLists.txt @@ -0,0 +1,18 @@ 
+add_triton_library(LinalgExtAnalysis + DimAnalyzer.cpp + StageDependencyAnalyzer.cpp + + LINK_LIBS PUBLIC + + MLIRAffineDialect + MLIRArithDialect + MLIRDialectUtils + MLIRFuncDialect + MLIRLinalgDialect + MLIRLinalgUtils + MLIRMemRefDialect + MLIRPass + MLIRShapeDialect + MLIRTensorDialect + MLIRTensorUtils +) \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/DimAnalyzer.cpp b/compiler/lib/Dialect/LinalgExt/Analysis/DimAnalyzer.cpp new file mode 100644 index 00000000..618ed105 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/DimAnalyzer.cpp @@ -0,0 +1,646 @@ +#include "dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "npu-stage-dim-analyzer" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +std::string mlir::dicp::toString(DimKind k) { + switch (k) { + case DimKind::Unknown: + return "Unknown"; + case DimKind::Parallel: + return "Parallel"; + case DimKind::Reduction: + return "Reduction"; + case DimKind::Broadcast: + return "Broadcast"; + case DimKind::Complex: + return "Complex"; + } + return "INVALID"; +} + +//===----------------------------------------------------------------------===// +// DimensionDisjointSet Implementation +//===----------------------------------------------------------------------===// + +int64_t DimensionDisjointSet::allocate(size_t n) { + size_t start = parent.size(); + resize(start + n); + LDBG(" [DSU] Allocated " << n << " new dims. 
Range: [" << start << ", " + << start + n - 1 << "]"); + return static_cast(start); +} + +int64_t DimensionDisjointSet::find(int64_t i) { + if (i < 0 || i >= (int64_t)parent.size()) + return -1; + // Path compression: Point directly to the root to speed up future lookups. + if (parent[i] == i) + return i; + return parent[i] = find(parent[i]); +} + +void DimensionDisjointSet::unionSets(int64_t i, int64_t j) { + int64_t rootI = find(i); + int64_t rootJ = find(j); + if (rootI != -1 && rootJ != -1 && rootI != rootJ) { + DimKind kI = kind[rootI]; + DimKind kJ = kind[rootJ]; + + // Merge properties based on priority logic (e.g., Reduction takes + // precedence). + DimKind mergedKind = mergeKinds(kI, kJ); + + // Union by attaching I to J (could be optimized with rank/size). + parent[rootI] = rootJ; + kind[rootJ] = mergedKind; + + LDBG(" [DSU] Union(ID:" << i << " [" << toString(kI) << "] -> ID:" << j + << " [" << toString(kJ) + << "]) => Merged Kind: " << toString(mergedKind)); + } +} + +void DimensionDisjointSet::setKind(int64_t i, DimKind k) { + int64_t root = find(i); + if (root != -1) { + DimKind oldK = kind[root]; + // Update the kind, ensuring we don't downgrade a strong property (like + // Reduction). + kind[root] = mergeKinds(kind[root], k); + if (oldK != kind[root]) { + LDBG(" [DSU] SetKind ID:" << i << " (Root:" << root << ") changed from " + << toString(oldK) << " to " + << toString(kind[root])); + } + } +} + +DimKind DimensionDisjointSet::getKind(int64_t i) { + int64_t root = find(i); + return (root != -1) ? kind[root] : DimKind::Unknown; +} + +void DimensionDisjointSet::resize(size_t n) { + size_t oldSize = parent.size(); + if (n > oldSize) { + parent.resize(n); + // Initialize new elements to point to themselves (roots) with Unknown kind. 
+ std::iota(parent.begin() + oldSize, parent.end(), oldSize); + kind.resize(n, DimKind::Unknown); + } +} + +DimKind DimensionDisjointSet::mergeKinds(DimKind a, DimKind b) { + if (a == b) + return a; + // Complex is the strongest property: if a dimension is complex anywhere, it's + // complex everywhere. + if (a == DimKind::Complex || b == DimKind::Complex) + return DimKind::Complex; + // Reduction is stronger than Parallel/Broadcast: forces serialization/atomic + // handling. + if (a == DimKind::Reduction || b == DimKind::Reduction) + return DimKind::Reduction; + // Broadcast + Parallel is treated as Parallel for tiling purposes. + // (Tiling a broadcasted loop is valid and often efficient). + if ((a == DimKind::Broadcast && b == DimKind::Parallel) || + (a == DimKind::Parallel && b == DimKind::Broadcast)) + return DimKind::Parallel; + // If one is Unknown, take the known one. + return (a != DimKind::Unknown) ? a : b; +} + +//===----------------------------------------------------------------------===// +// DimAnalyzer Implementation +//===----------------------------------------------------------------------===// + +DimAnalyzer::DimAnalyzer(const StageInfo &stage) : stage_(stage) { + // Populate the set for fast O(1) membership checks during traversal. + for (auto *op : stage_.ops) { + stageOps_.insert(op); + } +} + +std::vector DimAnalyzer::getOrAllocateDims(Value v) { + if (valueDims_.count(v)) + return valueDims_[v]; + + auto type = dyn_cast(v.getType()); + if (!type || !type.hasRank()) { + LDBG(" [Warn] Skipping unranked/non-shaped value: " << v); + return {}; + } + + int64_t rank = type.getRank(); + int64_t startId = dsu_.allocate(rank); + std::vector dims(rank); + std::iota(dims.begin(), dims.end(), startId); + + // Default assumption: Dimensions are Parallel unless proven otherwise. + // This helps when operations (like elementwise) don't impose constraints. 
+ for (auto id : dims) + dsu_.setKind(id, DimKind::Parallel); + + valueDims_[v] = dims; + return dims; +} + +void DimAnalyzer::bindDimensions(Value v1, Value v2) { + auto d1 = getOrAllocateDims(v1); + auto d2 = getOrAllocateDims(v2); + if (d1.empty() || d2.empty()) + return; + + if (d1.size() != d2.size()) { + LDBG(" [Warn] Rank mismatch binding " << v1 << " <-> " << v2); + return; + } + // 1-to-1 binding of dimensions (e.g., for Copy, Cast, or Elementwise). + for (size_t i = 0; i < d1.size(); ++i) { + dsu_.unionSets(d1[i], d2[i]); + } +} + +SmallVector DimAnalyzer::analyzeAndGetTilingDims() { + LDBG("\n>>> [Analysis] Starting Analysis for Stage ID: " << stage_.id); + // 1. Build the constraint graph via BFS traversal. + processBFS(); + + // 2. Identify Anchor Op. + // Heuristic: The last LinalgOp in the stage is usually the "Compute" or + // "Write" op. Tiling decisions should be based on this op's loop structure. + linalg::LinalgOp anchorOp; + for (auto it = stage_.ops.rbegin(); it != stage_.ops.rend(); ++it) { + if (auto op = dyn_cast(*it)) { + anchorOp = op; + break; + } + } + + if (!anchorOp) { + LDBG(">>> [Analysis] No LinalgOp anchor found. Tiling unknown."); + return {}; + } + + LDBG(">>> [Analysis] Anchor Op: " << anchorOp->getName()); + + // 3. Map Anchor Loops to Global Dimension IDs. + SmallVector chosenLoops; + auto iterTypes = anchorOp.getIteratorTypesArray(); + auto maps = anchorOp.getIndexingMapsArray(); + std::vector loopToDSU(iterTypes.size(), -1); + + // Iterate over operands to find which Value Dimension corresponds to which + // Loop. + auto operands = anchorOp->getOperands(); + int mapIdx = 0; + for (auto val : operands) { + if (mapIdx >= (int)maps.size()) + break; + if (!isa(val.getType())) { + mapIdx++; + continue; + } + + auto valDims = getOrAllocateDims(val); + AffineMap map = maps[mapIdx++]; + + // Analyze the AffineMap: (d0, d1) -> (d0, d1) + // If result[i] is a simple DimExpr(d_k), then Loop k corresponds to Value + // Dim i. 
+ for (unsigned dimIdx = 0; dimIdx < map.getNumResults(); ++dimIdx) { + if (dimIdx >= valDims.size()) + continue; + if (auto dimExpr = dyn_cast(map.getResult(dimIdx))) { + unsigned loopPos = dimExpr.getPosition(); + if (loopPos < loopToDSU.size()) { + // Link the loop to the global DSU ID of the operand dimension. + loopToDSU[loopPos] = valDims[dimIdx]; + } + } + } + } + + // 4. Evaluate Loops for Tiling. + LDBG(">>> [Analysis] Loop Classification:"); + for (size_t i = 0; i < loopToDSU.size(); ++i) { + DimKind k = DimKind::Unknown; + if (loopToDSU[i] != -1) { + // Get the global property from DSU (propagated from all ops in the + // stage). + k = dsu_.getKind(loopToDSU[i]); + } else { + // Fallback: If loop isn't linked to any data dimension (rare), rely on + // local iterator type. + if (linalg::isReductionIterator(iterTypes[i])) + k = DimKind::Reduction; + else if (linalg::isParallelIterator(iterTypes[i])) + k = DimKind::Parallel; + } + + LDBG(" Loop " << i << ": " << toString(k)); + + // Policy: We only auto-tile global Parallel loops. + // (Future work: support Tiling Reduction if atomic updates are supported). + if (k == DimKind::Parallel) { + chosenLoops.push_back(i); + } + } + return chosenLoops; +} + +void DimAnalyzer::processBFS() { + BFSQueue bfsQueue; + VisitedSet visited; + DenseSet definedInStage; + + // Identify all values defined within the stage to find boundary inputs. + for (auto *op : stage_.ops) + for (auto res : op->getResults()) + definedInStage.insert(res); + + // 1. Seeds: Operands used in stage but defined externally (Inputs). + for (auto *op : stage_.ops) { + for (auto operand : op->getOperands()) { + if (!definedInStage.contains(operand)) { + if (visited.insert(operand).second) { + bfsQueue.push(operand); + getOrAllocateDims(operand); // Pre-allocate IDs for inputs. + } + } + } + } + + // 2. Seeds: Internal roots (Fallback). + // If the graph is fully internal or disconnected, start from the first op. 
+ if (bfsQueue.empty() && !stage_.ops.empty()) { + for (auto res : stage_.ops[0]->getResults()) { + bfsQueue.push(res); + visited.insert(res); + } + } + + // Standard BFS Traversal + while (!bfsQueue.empty()) { + Value current = bfsQueue.front(); + bfsQueue.pop(); + + for (Operation *user : current.getUsers()) { + // Only process users that are part of the current stage. + if (!stageOps_.contains(user)) + continue; + + // Dispatch processing to specific Op handler. + // This establishes constraints between 'current' and 'user's results. + processOperation(user, current, bfsQueue, visited); + + // Enqueue results for downstream propagation. + for (Value result : user->getResults()) { + if (visited.insert(result).second) { + bfsQueue.push(result); + getOrAllocateDims(result); + } + } + } + } +} + +bool DimAnalyzer::processOperation(Operation *op, Value current, BFSQueue &q, + VisitedSet &v) { + // Dispatcher: Directs operation to the specific semantic handler. + if (auto matmulOp = dyn_cast(op)) + processMatmulOp(matmulOp); + else if (auto reduceOp = dyn_cast(op)) + processReduceOp(reduceOp); + else if (auto transOp = dyn_cast(op)) + processTransposeOp(transOp); + else if (auto bcastOp = dyn_cast(op)) + processBroadcastOp(bcastOp); + else if (auto linalgOp = dyn_cast(op)) + processLinalgOpGeneric(linalgOp); + + // Tensor manipulation ops + else if (auto castOp = dyn_cast(op)) + bindDimensions(castOp.getSource(), castOp.getDest()); + else if (isa(op)) + processReshapeOp(op); + else if (auto concatOp = dyn_cast(op)) + processConcatOp(concatOp); + else if (auto padOp = dyn_cast(op)) + processPadOp(padOp); + else if (auto extSlice = dyn_cast(op)) + processExtractSliceOp(extSlice); + else if (auto insSlice = dyn_cast(op)) + processInsertSliceOp(insSlice); + + // Bufferization & MemRef ops + else if (auto copyOp = dyn_cast(op)) + processMemrefCopyOp(copyOp, current, q, v); + else if (isa(op)) + processMemrefCastOp(op); + else if (auto toTensor = dyn_cast(op)) + 
processBufferizationToTensor(toTensor); + else if (auto matOp = dyn_cast(op)) + processMaterializeOp(matOp, current, q, v); + + // Elementwise ops (Arith, Math) + else if (isa(op->getDialect())) + processElementwise(op, current); + else { + // Default fallback: assume 1-to-1 preservation if results exist. + if (op->getNumResults() > 0) + bindDimensions(current, op->getResult(0)); + } + return true; +} + +//===----------------------------------------------------------------------===// +// Specific Handlers +//===----------------------------------------------------------------------===// + +void DimAnalyzer::processMemrefCopyOp(memref::CopyOp op, Value current, + BFSQueue &q, VisitedSet &v) { + LDBG(" [Op] Processing MemRef Copy"); + Value src = op.getSource(); + Value dst = op.getTarget(); + bindDimensions(src, dst); + + // Special Case: Copy sends data to 'dst', which is an operand (outs), not a + // result. We must explicitly enqueue 'dst' to continue BFS. + if (current == src) { + if (v.insert(dst).second) { + q.push(dst); + getOrAllocateDims(dst); + LDBG(" -> Enqueued Copy Destination: " << dst); + } + } +} + +void DimAnalyzer::processMaterializeOp( + bufferization::MaterializeInDestinationOp op, Value current, BFSQueue &q, + VisitedSet &v) { + LDBG(" [Op] Processing MaterializeInDestination"); + Value src = op.getSource(); + Value dst = op.getDest(); + bindDimensions(src, dst); + + // Similar to Copy: Propagate to destination buffer. 
+ if (current == src) { + if (v.insert(dst).second) { + q.push(dst); + getOrAllocateDims(dst); + LDBG(" -> Enqueued Materialize Destination: " << dst); + } + } +} + +void DimAnalyzer::processMemrefCastOp(Operation *op) { + LDBG(" [Op] Processing MemRef Cast/Reinterpret"); + Value src = op->getOperand(0); + Value dst = op->getResult(0); + + auto srcType = dyn_cast(src.getType()); + auto dstType = dyn_cast(dst.getType()); + + if (srcType && srcType.hasRank() && dstType && dstType.hasRank()) { + if (srcType.getRank() == dstType.getRank()) { + bindDimensions(src, dst); + } else { + // Rank changing casts (e.g. collapse/expand via reinterpret) break strict + // 1-to-1 binding. We treat dst dims as new/separate. + LDBG(" Rank change detected, breaking strict binding."); + getOrAllocateDims(dst); + } + } else { + getOrAllocateDims(dst); + } +} + +void DimAnalyzer::processBufferizationToTensor(bufferization::ToTensorOp op) { + LDBG(" [Op] Processing ToTensor"); + // Converts MemRef to Tensor. Dimensions are strictly preserved. 
+ Value memrefValue = op.getOperand(); + Value tensorResult = op.getResult(); + bindDimensions(memrefValue, tensorResult); +} + +void DimAnalyzer::processTransposeOp(linalg::TransposeOp op) { + LDBG(" [Op] Processing TransposeOp"); + Value input = op.getInput(); + Value result = op.getResult()[0]; + auto perm = op.getPermutation(); + + auto inputDims = getOrAllocateDims(input); + auto resDims = getOrAllocateDims(result); + + if (inputDims.empty() || resDims.empty()) + return; + + // Bind Input[Perm[i]] <-> Result[i] + for (size_t i = 0; i < perm.size(); ++i) { + int64_t srcIdx = perm[i]; + if (srcIdx < (int)inputDims.size() && i < resDims.size()) { + dsu_.unionSets(inputDims[srcIdx], resDims[i]); + } + } +} + +void DimAnalyzer::processMatmulOp(linalg::MatmulOp op) { + LDBG(" [Op] Processing MatmulOp"); + // Standard Matmul: [M, K] * [K, N] -> [M, N] + Value lhs = op.getInputs()[0]; + Value rhs = op.getInputs()[1]; + Value out = op.getResults()[0]; + + auto lhsDims = getOrAllocateDims(lhs); + auto rhsDims = getOrAllocateDims(rhs); + auto outDims = getOrAllocateDims(out); + + // Allocate implicit loops for M, N, K and set their properties. + int64_t loopM = dsu_.allocate(1); + int64_t loopN = dsu_.allocate(1); + int64_t loopK = dsu_.allocate(1); + dsu_.setKind(loopM, DimKind::Parallel); + dsu_.setKind(loopN, DimKind::Parallel); + dsu_.setKind(loopK, DimKind::Reduction); + + // Bind operand dimensions to these loops. 
+ // Assumes standard layout: LHS=[..., M, K], RHS=[..., K, N], Out=[..., M, N] + if (lhsDims.size() >= 2 && rhsDims.size() >= 2 && outDims.size() >= 2) { + dsu_.unionSets(lhsDims[lhsDims.size() - 2], loopM); + dsu_.unionSets(lhsDims[lhsDims.size() - 1], loopK); + dsu_.unionSets(rhsDims[rhsDims.size() - 2], loopK); + dsu_.unionSets(rhsDims[rhsDims.size() - 1], loopN); + dsu_.unionSets(outDims[outDims.size() - 2], loopM); + dsu_.unionSets(outDims[outDims.size() - 1], loopN); + } +} + +void DimAnalyzer::processReduceOp(linalg::ReduceOp op) { + LDBG(" [Op] Processing ReduceOp"); + Value input = op.getInputs()[0]; + Value output = op.getResults()[0]; + auto inputDims = getOrAllocateDims(input); + auto outputDims = getOrAllocateDims(output); + auto reduceIndices = op.getDimensions(); + std::set reduceSet(reduceIndices.begin(), reduceIndices.end()); + + int outIdx = 0; + for (size_t i = 0; i < inputDims.size(); ++i) { + if (reduceSet.count(i)) { + // Input dimension is being reduced -> Mark as Reduction. + dsu_.setKind(inputDims[i], DimKind::Reduction); + } else if (outIdx < (int)outputDims.size()) { + // Input dimension is preserved -> Bind to Output dimension. + dsu_.unionSets(inputDims[i], outputDims[outIdx++]); + } + } +} + +void DimAnalyzer::processBroadcastOp(linalg::BroadcastOp op) { + LDBG(" [Op] Processing BroadcastOp"); + auto inDims = getOrAllocateDims(op.getInput()); + auto resDims = getOrAllocateDims(op.getResult()[0]); + auto broadcastIndices = op.getDimensions(); + std::set bcastSet(broadcastIndices.begin(), broadcastIndices.end()); + + int inIdx = 0; + for (size_t i = 0; i < resDims.size(); ++i) { + if (bcastSet.count(i)) { + // New dimension added by broadcast -> Mark as Broadcast. + dsu_.setKind(resDims[i], DimKind::Broadcast); + } else if (inIdx < (int)inDims.size()) { + // Existing dimension -> Bind to input. 
+ dsu_.unionSets(resDims[i], inDims[inIdx++]); + } + } +} + +void DimAnalyzer::processReshapeOp(Operation *op) { + LDBG(" [Op] Processing Reshape"); + bool isExpand = isa(op); + auto srcDims = getOrAllocateDims(op->getOperand(0)); + auto dstDims = getOrAllocateDims(op->getResult(0)); + + SmallVector indices; + if (isExpand) { + indices = cast(op).getReassociationIndices(); + } else { + indices = cast(op).getReassociationIndices(); + } + + // Map between Collapsed (1 dim) and Expanded (N dims). + auto &collapsed = isExpand ? srcDims : dstDims; + auto &expanded = isExpand ? dstDims : srcDims; + + if (indices.size() != collapsed.size()) + return; + + // Bind the single collapsed dimension to ALL corresponding expanded + // dimensions. This is a conservative approach: it effectively groups them all + // into one equivalence class. + for (size_t i = 0; i < indices.size(); ++i) { + int64_t colID = collapsed[i]; + for (int64_t expIdx : indices[i]) { + if (expIdx < (int64_t)expanded.size()) + dsu_.unionSets(colID, expanded[expIdx]); + } + } +} + +void DimAnalyzer::processElementwise(Operation *op, Value current) { + LDBG(" [Op] Processing Elementwise"); + // Elementwise ops (Add, Sub, etc.) strictly preserve shape. + // Bind input dimensions to result dimensions 1-to-1. + if (op->getNumResults() > 0) + bindDimensions(current, op->getResult(0)); +} + +void DimAnalyzer::processConcatOp(tensor::ConcatOp op) { + // Concat preserves all dimensions except the concatenation axis. + // Even on the concat axis, the logical meaning of the dimension usually + // matches (e.g., stacking Batches). We bind inputs to output 1-to-1. + Value result = op.getResult(); + for (Value operand : op.getOperands()) + bindDimensions(operand, result); +} + +void DimAnalyzer::processPadOp(tensor::PadOp op) { + // Padding extends the size but preserves the logical axis. 
+ bindDimensions(op.getSource(), op.getResult()); +} + +void DimAnalyzer::processExtractSliceOp(tensor::ExtractSliceOp op) { + auto srcDims = getOrAllocateDims(op.getSource()); + auto dstDims = getOrAllocateDims(op.getResult()); + auto dropped = op.getDroppedDims(); + int dstIdx = 0; + for (size_t i = 0; i < srcDims.size(); ++i) { + // If dimension is NOT dropped (rank-reduced), bind it to the next output + // dimension. + if (!dropped.test(i) && dstIdx < (int)dstDims.size()) { + dsu_.unionSets(srcDims[i], dstDims[dstIdx++]); + } + // Dropped dimensions are effectively ignored for tiling propagation of the + // result. + } +} + +void DimAnalyzer::processInsertSliceOp(tensor::InsertSliceOp op) { + // InsertSlice modifies 'Dest'. The Result shape matches 'Dest'. + bindDimensions(op.getDest(), op.getResult()); +} + +void DimAnalyzer::processLinalgOpGeneric(linalg::LinalgOp op) { + LDBG(" [Op] Processing Generic: " << op->getName()); + auto maps = op.getIndexingMapsArray(); + auto iterTypes = op.getIteratorTypesArray(); + + // Allocate IDs for the op's loop iterators. + int64_t loopStart = dsu_.allocate(op.getNumLoops()); + + // Set properties based on iterator types (Parallel vs Reduction). + for (int i = 0; i < (int)iterTypes.size(); ++i) { + DimKind k = linalg::isReductionIterator(iterTypes[i]) ? DimKind::Reduction + : DimKind::Parallel; + dsu_.setKind(loopStart + i, k); + } + + // Bind Operands to Loops using AffineMaps. + auto operands = op->getOperands(); + int mapIdx = 0; + for (auto val : operands) { + if (mapIdx >= (int)maps.size()) + break; + if (!isa(val.getType())) { + mapIdx++; + continue; + } + + AffineMap map = maps[mapIdx++]; + auto valDims = getOrAllocateDims(val); + + // If map is (d0, d1) -> (d0, d1), bind ValDim[0] to Loop[0], etc. 
+ for (unsigned d = 0; d < map.getNumResults(); ++d) { + if (d >= valDims.size()) + continue; + if (auto dimExpr = dyn_cast(map.getResult(d))) { + dsu_.unionSets(valDims[d], loopStart + dimExpr.getPosition()); + } + } + } +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.cpp b/compiler/lib/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.cpp new file mode 100644 index 00000000..c4d5cf2c --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.cpp @@ -0,0 +1,248 @@ + +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" + +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "npu-stage-dep-analyzer" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; + +void StageDependencyAnalyzer::collectStages() { + LDBG(">>> [Analysis] Collecting Stages..."); + stages.clear(); + StageInfo currentStage; + currentStage.id = 0; + + Block *body = forOp.getBody(); + for (Operation &op : body->without_terminator()) { + // If the current operation is a SyncBlockWaitOp, it marks the start of a + // new stage. We complete the current stage (if it's not empty) and start a + // new one. The SyncBlockWaitOp will become the first operation of the new + // stage. 
+ if (isa(op)) { + if (!currentStage.ops.empty()) { + stages.push_back(currentStage); + currentStage = StageInfo(); + currentStage.id = stages.size(); + } + } + + currentStage.ops.push_back(&op); + + // Mark the stage if it contains a sync wait operation + if (isa(op)) { + currentStage.hasSync = true; + } + } + + if (!currentStage.ops.empty()) { + stages.push_back(currentStage); + } + + nodes.resize(stages.size()); + for (size_t i = 0; i < stages.size(); ++i) { + nodes[i].id = i; + nodes[i].stageInfo = &stages[i]; + } + + LDBG("Collected " << stages.size() << " stages."); + + // Debug: dump the ops contained in each stage (print full op IR). + LLVM_DEBUG({ + llvm::dbgs() << "[" DEBUG_TYPE "] Detailed stage contents:\n"; + for (const auto &stage : stages) { + llvm::dbgs() << "[" DEBUG_TYPE << "] Stage " << stage.id + << (stage.hasSync ? " (hasSync)" : "") + << " - ops: " << stage.ops.size() << "\n"; + for (Operation *op : stage.ops) { + llvm::dbgs() << " - "; + op->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + } + } + }); +} + +void StageDependencyAnalyzer::collectEffects() { + for (auto &node : nodes) { + for (Operation *op : node.stageInfo->ops) { + // 1. SSA Def-Use (Produced Values) + for (Value res : op->getResults()) { + node.producedValues.insert(res); + } + // 1. SSA Def-Use (Consumed Values) + for (Value operand : op->getOperands()) { + // We only care about operands defined within the loop (not block args + // or invariant) + if (auto defOp = operand.getDefiningOp()) { + if (defOp->getParentOp() == forOp) { + node.consumedValues.insert(operand); + } + } + } + + // 2. 
Memory Effects + if (auto memEffect = dyn_cast(op)) { + SmallVector> effects; + memEffect.getEffects(effects); + for (auto &effect : effects) { + Value val = effect.getValue(); + if (!val) + continue; + if (isa(effect.getEffect())) + node.writeValues.insert(val); + else if (isa(effect.getEffect())) + node.readValues.insert(val); + } + continue; + } + // Explicit handling for ops not implementing MemoryEffects but having + // semantics + if (auto matOp = + dyn_cast(op)) { + node.readValues.insert(matOp.getSource()); + node.writeValues.insert(matOp.getDest()); + } else if (auto copyOp = dyn_cast(op)) { + node.readValues.insert(copyOp.getSource()); + node.writeValues.insert(copyOp.getTarget()); + } + } + } +} + +void StageDependencyAnalyzer::buildDependencyGraph() { + LDBG(">>> [Analysis] Building Dependency Graph..."); + for (int i = 0; i < nodes.size(); ++i) { + for (int j = 0; j < nodes.size(); ++j) { + if (i == j) + continue; + bool hasDependency = false; + + // 1. Check SSA Dependencies (Direct Data Flow) + // If Stage J consumes a value produced by Stage I, J depends on I. + for (Value consumed : nodes[j].consumedValues) { + if (nodes[i].producedValues.count(consumed)) { + hasDependency = true; + break; + } + } + + // 2. 
Check Memory Dependencies + if (!hasDependency) { + for (Value writeVal : nodes[i].writeValues) { + for (Value readVal : nodes[j].readValues) { + AliasResult result = aliasAnalysis.alias(writeVal, readVal); + if (result.isMust() || result.isPartial()) { + hasDependency = true; + LDBG(" MEM DEPENDENCY: Stage " << i << " -> Stage " << j); + break; + } + } + if (hasDependency) + break; + } + } + + if (hasDependency) { + nodes[i].stageInfo->succs.insert(j); + nodes[j].stageInfo->preds.insert(i); + } + } + } +} + +LogicalResult StageDependencyAnalyzer::computeStageLevels() { + std::vector visitState(nodes.size(), + 0); // 0: unvisited, 1: visiting, 2: visited + std::function dfs = [&](int u) -> LogicalResult { + visitState[u] = 1; + int maxPredLevel = -1; + for (int v : nodes[u].stageInfo->preds) { + if (visitState[v] == 1) { + llvm::errs() << "Error: Cycle detected involving stages " << u + << " and " << v << "\n"; + return failure(); + } + if (visitState[v] == 0) { + if (failed(dfs(v))) + return failure(); + } + if (nodes[v].level > maxPredLevel) + maxPredLevel = nodes[v].level; + } + nodes[u].level = maxPredLevel + 1; + visitState[u] = 2; + return success(); + }; + + for (int i = 0; i < nodes.size(); ++i) { + if (visitState[i] == 0) { + if (failed(dfs(i))) + return failure(); + } + } + return success(); +} + +void StageDependencyAnalyzer::reorderStagesLogical() { + std::vector sortedNodes = nodes; + std::stable_sort(sortedNodes.begin(), sortedNodes.end(), + [](const StageNode &a, const StageNode &b) { + if (a.level != b.level) + return a.level < b.level; + return a.id < b.id; + }); + std::vector newStages; + newStages.reserve(stages.size()); + LDBG(">>> [Analysis] Reordered Stages (Logical Order):"); + for (const auto &node : sortedNodes) { + LDBG(" Stage ID: " << node.id << ", Level: " << node.level); + newStages.push_back(*node.stageInfo); + } + stages = std::move(newStages); +} + +void StageDependencyAnalyzer::materializeScheduleToIR() { + LDBG(">>> [Analysis] 
Materializing Schedule to IR (Physical Move)..."); + Operation *terminator = forOp.getBody()->getTerminator(); + for (const auto &stage : stages) { + for (Operation *op : stage.ops) { + if (op == terminator) + continue; + op->moveBefore(terminator); + } + } +} + +FailureOr> +StageDependencyAnalyzer::runAndReorder(RewriterBase &rewriter) { + collectStages(); + collectEffects(); + buildDependencyGraph(); + if (failed(computeStageLevels())) + return failure(); + reorderStagesLogical(); + LDBG(">>> [Result] Final Stage Dependency Summary:"); + LLVM_DEBUG(for (const auto &stage + : stages) { + llvm::dbgs() << "[" DEBUG_TYPE "] Stage " << stage.id << ":\n"; + llvm::dbgs() << " Predecessors (Depends on): { stage: "; + for (int p : stage.preds) + llvm::dbgs() << p << " "; + llvm::dbgs() << "}\n"; + llvm::dbgs() << " Successors (Depended by): { stage: "; + for (int s : stage.succs) + llvm::dbgs() << s << " "; + llvm::dbgs() << "}\n"; + }); + materializeScheduleToIR(); + return stages; +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 0b28548a..dab1048a 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -4,9 +4,11 @@ add_triton_library(LinalgExtTransforms ScalarTo1DTensorPass.cpp RemoveSingleIterationLoop.cpp TensorTransform.cpp + NPUUnroolPipeline.cpp DEPENDS LinalgExtTransformsIncGen + LinalgExtAnalysis LINK_LIBS PUBLIC TritonTilingExtIR @@ -18,6 +20,10 @@ add_triton_library(LinalgExtTransforms MLIRTensorDialect MLIRTransforms MLIRSupport + MLIRAnalysis + MLIRSCFUtils + MLIRSCFTransforms + TritonAnalysis TritonIR TritonTransforms diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/NPUUnroolPipeline.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/NPUUnroolPipeline.cpp new file mode 100644 index 00000000..d25d1f27 --- /dev/null +++ 
b/compiler/lib/Dialect/LinalgExt/Transforms/NPUUnroolPipeline.cpp @@ -0,0 +1,479 @@ +#include "dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h" +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + +#include + +#define DEBUG_TYPE "npu-unroll-pipeline" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; +using namespace LinalgExt; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_NPUUNROOLPIPELINE +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +namespace { + +LogicalResult verifyLoopForPipelining(scf::ForOp forOp) { + auto lbOpt = getConstantIntValue(forOp.getLowerBound()); + auto ubOpt = getConstantIntValue(forOp.getUpperBound()); + auto stepOpt = getConstantIntValue(forOp.getStep()); + + if (!lbOpt.has_value() || !ubOpt.has_value() || !stepOpt.has_value()) { + LDBG("Verification FAILED: Loop bounds or step are dynamic."); + return failure(); + } + + int64_t step = stepOpt.value(); + if (step == 0) { + LDBG("Verification FAILED: Infinite loop (step = 0)."); + return failure(); + } + + int64_t lb = lbOpt.value(); + int64_t ub = ubOpt.value(); + if (step > 0 && lb >= ub) { + LDBG("Verification FAILED: Loop body is never executed."); + 
return failure(); + } + + int64_t tripCount = (ub - lb + step - 1) / step; + LDBG("Verification PASSED. Static Trip Count: " << tripCount); + return success(); +} + +// Marks operations that define yielded values for tensor/memref iter_args +// This allows us to track loop-carried dependencies across unrolled iterations. +static LogicalResult markYieldSources(scf::ForOp forOp) { + auto yieldOp = cast(forOp.getBody()->getTerminator()); + + for (auto [idx, iterArg] : llvm::enumerate(forOp.getRegionIterArgs())) { + Value yieldVal = yieldOp.getOperand(idx); + + // Only strictly necessary for SSA values (tensors/scalars), but harmless + // for others. + if (auto defOp = yieldVal.getDefiningOp()) { + std::string attrName = "dicp.yield_for_iter_arg." + std::to_string(idx); + // We assume one op might feed multiple yield args, though rare. + // Ideally we check if attr exists, but simple overwrite is okay for 1:1. + defOp->setAttr( + attrName, + IntegerAttr::get(IntegerType::get(forOp.getContext(), 32), idx)); + LDBG(" Marked op '" << defOp->getName() + << "' as yield source for iter_arg " << idx); + } + } + return success(); +} + +static Operation *getYieldSourceForIterArg(scf::ForOp forOp, int iterArgIdx) { + // Linear scan is acceptable for loop bodies which are typically small-ish + for (Operation &op : forOp.getBody()->without_terminator()) { + std::string attrName = + "dicp.yield_for_iter_arg." 
+ std::to_string(iterArgIdx); + if (op.hasAttr(attrName)) { + return &op; + } + } + return nullptr; +} + +class NPUUnrollPipeline { +public: + NPUUnrollPipeline(scf::ForOp forOp, int unrollFactor, + const std::vector &orderedStages) + : forOp(forOp), unrollFactor(unrollFactor), stages(orderedStages) {} + + LogicalResult run(RewriterBase &rewriter); + +private: + scf::ForOp forOp; + int unrollFactor; + const std::vector &stages; + int maxFlagPerIter = 0; + + // Map: OriginalValue -> Vector of Unrolled Values (one per iteration) + DenseMap> valueMapping; + // Map: OriginalOp -> Vector of Unrolled Ops (one per iteration) + // Needed to find cloned yield sources. + DenseMap> opMapping; + + void calculateMaxFlagStride(); + void prepareInitialMappings(RewriterBase &rewriter); + void updateHivmFlag(Operation *op, int iterIdx, RewriterBase &rewriter); + + Value getUnrolledValue(Value originalVal, int iterIdx); + Operation *cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + int iterIdx); +}; + +void NPUUnrollPipeline::calculateMaxFlagStride() { + int maxFlag = -1; + for (const auto &stage : stages) { + for (Operation *op : stage.ops) { + if (auto syncSetOp = dyn_cast(op)) { + int flag = getConstantIntValue(syncSetOp.getFlagId()).value_or(-1); + if (flag > maxFlag) + maxFlag = flag; + } else if (auto syncWaitOp = dyn_cast(op)) { + int flag = getConstantIntValue(syncWaitOp.getFlagId()).value_or(-1); + if (flag > maxFlag) + maxFlag = flag; + } + } + } + this->maxFlagPerIter = (maxFlag < 0) ? 
0 : (maxFlag + 1); + LDBG("Flag Stride calculated: " << maxFlagPerIter); +} + +void NPUUnrollPipeline::prepareInitialMappings(RewriterBase &rewriter) { + LDBG(">>> [Unroll] Preparing Initial Mappings (Constants & IVs)..."); + Location loc = forOp.getLoc(); + Value lb = forOp.getLowerBound(); + Value step = forOp.getStep(); + Value iv = forOp.getInductionVar(); + Type ivType = iv.getType(); + + valueMapping[iv].resize(unrollFactor); + auto iterArgs = forOp.getRegionIterArgs(); + for (Value arg : iterArgs) { + valueMapping[arg].resize(unrollFactor, nullptr); + } + + for (int i = 0; i < unrollFactor; ++i) { + // 1. IV Calculation + Value idxVal = rewriter.create(loc, i); + Value idxValTyped = idxVal; + if (ivType != idxVal.getType()) + idxValTyped = rewriter.create(loc, ivType, idxVal); + + Value stepOffset = rewriter.create(loc, step, idxValTyped); + Value currentIV = rewriter.create(loc, lb, stepOffset); + valueMapping[iv][i] = currentIV; + + // 2. Simple IterArg Calculation (e.g. arithmetic induction) + auto yieldOp = cast(forOp.getBody()->getTerminator()); + for (auto it : llvm::enumerate(iterArgs)) { + Value iterArg = it.value(); + Value yieldVal = yieldOp.getOperand(it.index()); + + Operation *defOp = yieldVal.getDefiningOp(); + bool isSimpleIV = false; + int64_t stepConst = 0; + + if (auto addOp = dyn_cast_or_null(defOp)) { + Value lhs = addOp.getLhs(); + Value rhs = addOp.getRhs(); + Value constOp = nullptr; + if (lhs == iterArg) + constOp = rhs; + else if (rhs == iterArg) + constOp = lhs; + + if (constOp) { + if (auto cst = constOp.getDefiningOp()) { + stepConst = cst.value(); + isSimpleIV = true; + } else if (auto cst = constOp.getDefiningOp()) { + stepConst = cst.value(); + isSimpleIV = true; + } + } + } + + if (isSimpleIV) { + Value initVal = forOp.getInitArgs()[it.index()]; + Value kVal = rewriter.create(loc, i); + Value kValTyped = kVal; + if (iterArg.getType() != kVal.getType()) + kValTyped = + rewriter.create(loc, iterArg.getType(), kVal); + + Value 
stepVal; + if (iterArg.getType().isIndex()) + stepVal = rewriter.create(loc, stepConst); + else + stepVal = rewriter.create( + loc, iterArg.getType(), stepConst); + + Value offset = rewriter.create(loc, kValTyped, stepVal); + Value currVal = rewriter.create(loc, initVal, offset); + valueMapping[iterArg][i] = currVal; + } + } + } +} + +Value NPUUnrollPipeline::getUnrolledValue(Value originalVal, int iterIdx) { + // 1. Check existing mapping (simple IVs or previously cloned ops) + if (valueMapping.count(originalVal)) { + if (iterIdx >= 0 && iterIdx < valueMapping[originalVal].size()) { + Value mapped = valueMapping[originalVal][iterIdx]; + if (mapped) + return mapped; + } + } + + // 2. Handle BlockArguments (IterArgs) + if (auto arg = dyn_cast(originalVal)) { + if (arg.getOwner() == forOp.getBody()) { + // IV is handled in prepareInitialMappings (Slot 0 of args) + if (arg.getArgNumber() == 0) + return nullptr; + + // IterArgs start at index 1 + int iterArgIdx = arg.getArgNumber() - 1; + + // Case 2a: Iteration 0 uses the Loop Init Args (Full unroll) + if (iterIdx == 0) { + return forOp.getInitArgs()[iterArgIdx]; + } + + // Case 2b: Iteration K > 0 uses Yield result from K-1 + // Strategy: Find the op marked as yield source and look up its clone. 
+ Operation *yieldSourceOp = getYieldSourceForIterArg(forOp, iterArgIdx); + if (yieldSourceOp) { + // The YieldOp operand tells us which result of the source op is used + auto yieldOp = cast(forOp.getBody()->getTerminator()); + Value yieldOperand = yieldOp.getOperand(iterArgIdx); + + if (auto res = dyn_cast(yieldOperand)) { + // If yield operand is a direct result of the marked op + if (res.getOwner() == yieldSourceOp) { + int resIdx = res.getResultNumber(); + // Check if the source op for the previous iteration was cloned + if (opMapping.count(yieldSourceOp) && + iterIdx - 1 < opMapping[yieldSourceOp].size()) { + Operation *prevClone = opMapping[yieldSourceOp][iterIdx - 1]; + if (prevClone) { + return prevClone->getResult(resIdx); + } else { + LDBG(" WARNING: Yield source clone missing for iter " + << iterIdx - 1); + } + } + } + } else if (auto argOperand = dyn_cast(yieldOperand)) { + // The yield operand is an IterArg itself (Pass-through) + // Recursively resolve it + return getUnrolledValue(argOperand, iterIdx - 1); + } + } + + // Fallback: If no complex logic found, try recursive lookup on yield + // operand (This handles cases where the yield val is invariant or defined + // elsewhere) + Value directYieldVal = + cast(forOp.getBody()->getTerminator()) + .getOperand(iterArgIdx); + return getUnrolledValue(directYieldVal, iterIdx - 1); + } + } + + // 3. 
Invariant or Global values + return originalVal; +} + +Operation *NPUUnrollPipeline::cloneAndUpdateOperands(RewriterBase &rewriter, + Operation *op, + int iterIdx) { + Operation *clone = rewriter.clone(*op); + + for (OpOperand &operand : clone->getOpOperands()) { + Value originalVal = op->getOperand(operand.getOperandNumber()); + Value replacement = getUnrolledValue(originalVal, iterIdx); + + if (replacement && replacement != originalVal) { + operand.set(replacement); + } else if (isa(originalVal) && + cast(originalVal).getOwner() == forOp.getBody()) { + // If we failed to resolve a loop arg, it's a critical error for valid IR + LDBG(" CRITICAL WARNING: Failed to resolve loop argument " + << originalVal << " at Iter " << iterIdx); + } + } + + // Record the cloned op in the mapping for future lookups (Yield Source + // resolution) + if (opMapping[op].size() <= iterIdx) + opMapping[op].resize(unrollFactor); + opMapping[op][iterIdx] = clone; + + return clone; +} + +void NPUUnrollPipeline::updateHivmFlag(Operation *op, int iterIdx, + RewriterBase &rewriter) { + if (maxFlagPerIter == 0) + return; + auto update = [&](auto syncOp) { + if (auto attr = syncOp.getStaticFlagIdAttr()) { + int64_t newFlag = attr.getInt() + iterIdx * maxFlagPerIter; + syncOp.setStaticFlagIdAttr(rewriter.getI64IntegerAttr(newFlag)); + } + }; + if (auto setOp = dyn_cast(op)) + update(setOp); + else if (auto waitOp = dyn_cast(op)) + update(waitOp); +} + +LogicalResult NPUUnrollPipeline::run(RewriterBase &rewriter) { + calculateMaxFlagStride(); + + // Resize mappings + for (Operation &op : forOp.getBody()->without_terminator()) { + opMapping[&op].resize(unrollFactor, nullptr); + for (Value res : op.getResults()) { + valueMapping[res].resize(unrollFactor, nullptr); + } + } + + rewriter.setInsertionPoint(forOp); + prepareInitialMappings(rewriter); + + LDBG(">>> [Unroll] Starting Clone (Stage-Major Order)..."); + + for (const auto &stage : stages) { + LDBG(" Processing Stage " << stage.id); + for (int 
iterIdx = 0; iterIdx < unrollFactor; ++iterIdx) { + for (Operation *op : stage.ops) { + if (isa(op)) + continue; + + Operation *clonedOp = cloneAndUpdateOperands(rewriter, op, iterIdx); + + LLVM_DEBUG({ + llvm::dbgs() << "[" DEBUG_TYPE "] [Stg " << stage.id << "][Iter " + << iterIdx << "] Cloned Op: "; + clonedOp->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + updateHivmFlag(clonedOp, iterIdx, rewriter); + + // Update value mapping for results + for (auto it : llvm::zip(op->getResults(), clonedOp->getResults())) { + Value originalRes = std::get<0>(it); + Value newRes = std::get<1>(it); + if (iterIdx < valueMapping[originalRes].size()) + valueMapping[originalRes][iterIdx] = newRes; + } + } + } + } + + LDBG(">>> [Unroll] Replacing Loop Results..."); + Operation *terminator = forOp.getBody()->getTerminator(); + SmallVector finalResults; + + // The final results correspond to the yield values of the LAST iteration + int lastIter = unrollFactor - 1; + + for (Value operand : terminator->getOperands()) { + Value remapped = getUnrolledValue(operand, lastIter); + if (!remapped) + remapped = operand; + finalResults.push_back(remapped); + } + + if (forOp.getNumResults() != finalResults.size()) { + return forOp.emitError("Unroll result count mismatch"); + } + + rewriter.replaceOp(forOp, finalResults); + LDBG("<<< Pass Complete."); + return success(); +} + +struct NPUUnroolPipelinePass + : public mlir::dicp::LinalgExt::impl::NPUUnroolPipelineBase< + NPUUnroolPipelinePass> { + NPUUnroolPipelinePass() = default; + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + + SmallVector loops; + func.walk([&](scf::ForOp loop) { + if (loop->hasAttr(mlir::triton::kNumStagesAttrName)) + loops.push_back(loop); + }); + + if (loops.size() != 1) { + LDBG("The number of candidate loops is not one."); + return; + } + + scf::ForOp targetLoop = loops[0]; + if (failed(verifyLoopForPipelining(targetLoop))) { + LDBG("Loop verification failed, skipping."); + return; 
+ } + + int numStages = mlir::cast( + targetLoop->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + + LDBG("Processing Loop with num_stages = " << numStages); + + mlir::IRRewriter rewriter(func.getContext()); + AliasAnalysis &aa = getAnalysis(); + + // 1. Analyze and Reorder Stages (Topological Sort) + StageDependencyAnalyzer analyzer(targetLoop, aa); + auto orderedStagesOrFailure = analyzer.runAndReorder(rewriter); + + if (failed(orderedStagesOrFailure)) { + LDBG("Failed to reorder stages (cyclic dependency detected)."); + signalPassFailure(); + return; + } + // 2. Mark Yield Sources for complex iter_args + if (failed(markYieldSources(targetLoop))) { + signalPassFailure(); + return; + } + + // 3. Execute Unroll (Stage-Major) + NPUUnrollPipeline unroller(targetLoop, numStages, + orderedStagesOrFailure.value()); + if (failed(unroller.run(rewriter))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createNPUUnroolPipelinePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/test/ascend/passed_tests/test_cv_unroll_pipleine.py b/test/ascend/passed_tests/test_cv_unroll_pipleine.py new file mode 100644 index 00000000..fb281581 --- /dev/null +++ b/test/ascend/passed_tests/test_cv_unroll_pipleine.py @@ -0,0 +1,756 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.language.extra.deeplink as dl +import torch_npu +import triton.testing + +DEVICE = "npu" + + +def require_npu(): + try: + torch.empty(1, device=DEVICE) + except Exception: + pytest.skip("npu device not available") + + +# ------------------- Triton kernels (kept same functional implementation) ------------------- +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + qk_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: 
tl.constexpr, + N_CTX: tl.constexpr, + fp8_v: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + else: + lo, hi = 0, N_CTX + + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load(K_block_ptr) + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + + p = tl.math.exp(qk) + p_cast = p.to(tl.float16) + v = tl.load(V_block_ptr) + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + pv + + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + M, + Out, + sm_scale: tl.constexpr, + stride_qz: tl.constexpr, + stride_qh: tl.constexpr, + stride_qm: tl.constexpr, + stride_qk: tl.constexpr, + stride_kz: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kk: tl.constexpr, + stride_vz: tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vk: tl.constexpr, + stride_oz: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + stride_on: tl.constexpr, + Z: tl.constexpr, + H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + NUM_BLOCKS_PER_CORE: tl.constexpr, + NUM_BLOCKS: tl.constexpr, + NUM_BLOCKS_M: tl.constexpr, +): + pid = tl.program_id(0) + for block_idx in range(pid, 
NUM_BLOCKS, 24): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + q = tl.load(Q_block_ptr) + + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + task_m_idx, + sm_scale, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + N_CTX, + V.dtype.element_ty == tl.float8e5, + ) + + if STAGE & 2: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + task_m_idx, + sm_scale, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 2, + offs_m, + offs_n, + N_CTX, + V.dtype.element_ty == tl.float8e5, + ) + + m_i += tl.math.log(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, 
acc.to(Out.type.element_ty))


# Flash-attention forward kernel with the matmul ("cube") and softmax
# ("vector") work split into two dl.async_task regions that exchange tiles
# through global-memory workspaces, ordered by dl cross flags.
# NOTE(review): the scope names suggest the Ascend cube/vector units — confirm
# against the dl runtime documentation.
@triton.jit
def _attn_fwd_split_cv(
    Q,
    K,
    V,
    M,  # per-row log-sum-exp output, written in the epilogue
    Out,
    acc,  # NOTE(review): accepted but never read/written in this kernel — confirm it is needed
    sm_scale,
    workspace_1,  # cube -> vector: raw QK^T tiles (fp32)
    workspace_2,  # vector -> cube: softmax numerator tiles cast to Q's dtype
    workspace_3,  # cube -> vector: P @ V partial-output tiles (fp32)
    stride_qz: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qm: tl.constexpr,
    stride_qk: tl.constexpr,
    stride_kz: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_kn: tl.constexpr,
    stride_kk: tl.constexpr,
    stride_vz: tl.constexpr,
    stride_vh: tl.constexpr,
    stride_vn: tl.constexpr,
    stride_vk: tl.constexpr,
    stride_oz: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_om: tl.constexpr,
    stride_on: tl.constexpr,
    w1_stride_nb: tl.constexpr,
    w1_stride_bm: tl.constexpr,
    w1_stride_bn: tl.constexpr,
    w2_stride_nb: tl.constexpr,
    w2_stride_bm: tl.constexpr,
    w2_stride_bn: tl.constexpr,
    w3_stride_nb: tl.constexpr,
    w3_stride_bm: tl.constexpr,
    w3_stride_dm: tl.constexpr,
    Z: tl.constexpr,
    H: tl.constexpr,
    N_CTX: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CORES: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    # Flatten (z, h, m-block) into one task index; NUM_CORES programs stride
    # over all tasks.
    NUM_BLOCKS_M = N_CTX // BLOCK_M  # NOTE(review): assumes N_CTX % BLOCK_M == 0
    NUM_BLOCKS = NUM_BLOCKS_M * Z * H
    pid = tl.program_id(0)
    for block_idx in tl.range(pid, NUM_BLOCKS, NUM_CORES):
        task_hz_idx = block_idx // NUM_BLOCKS_M
        task_m_idx = block_idx % NUM_BLOCKS_M
        off_z = task_hz_idx // H
        off_h = task_hz_idx % H
        qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

        Q_block_ptr = tl.make_block_ptr(
            base=Q + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_qm, stride_qk),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        K_block_ptr = tl.make_block_ptr(
            base=K + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_kn, stride_kk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        O_block_ptr = tl.make_block_ptr(
            base=Out + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_om, stride_on),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )

        q = tl.load(Q_block_ptr)
        K_block_ptr = tl.advance(K_block_ptr, (0, 0))
        V_block_ptr = tl.advance(V_block_ptr, (0, 0))
        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
        # Online-softmax running state: partial output, row max, row sum.
        acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0

        lo, hi = 0, N_CTX
        # NOTE(review): the inner loop always performs NUM_STAGES iterations,
        # so N_CTX is assumed to be a multiple of BLOCK_N * NUM_STAGES — confirm.
        for start_n in range(lo, hi, BLOCK_N * NUM_STAGES):
            for i in tl.range(0, NUM_STAGES, num_stages=NUM_STAGES):
                # Per-(block, stage) slices of the shared staging buffers.
                workspace_1_ptr = tl.make_block_ptr(
                    base=workspace_1
                    + (NUM_STAGES * block_idx.to(tl.int64) + i) * w1_stride_nb,
                    shape=(BLOCK_M, BLOCK_N),
                    strides=(w1_stride_bm, w1_stride_bn),
                    offsets=(0, 0),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                workspace_2_ptr = tl.make_block_ptr(
                    base=workspace_2
                    + (NUM_STAGES * block_idx.to(tl.int64) + i) * w2_stride_nb,
                    shape=(BLOCK_M, BLOCK_N),
                    strides=(w2_stride_bm, w2_stride_bn),
                    offsets=(0, 0),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                workspace_3_ptr = tl.make_block_ptr(
                    base=workspace_3
                    + (NUM_STAGES * block_idx.to(tl.int64) + i) * w3_stride_nb,
                    shape=(BLOCK_M, HEAD_DIM),
                    strides=(w3_stride_bm, w3_stride_dm),
                    offsets=(0, 0),
                    block_shape=(BLOCK_M, HEAD_DIM),
                    order=(1, 0),
                )
                with dl.async_task(scope=dl.async_task.cube):
                    # qk = Q @ K^T for this BLOCK_N slice; publish it to the
                    # vector side via workspace_1 and flag (C2V, 0).
                    k = tl.load(K_block_ptr)
                    trans_k = tl.trans(k)
                    qk = tl.dot(q, trans_k)
                    tl.store(workspace_1_ptr, qk)

                    dl.set_cross_flag(dl.SyncFlag.C2V, 0)
                    # Wait for the vector side to publish the cast softmax
                    # numerator into workspace_2 (flag V2C, 1).
                    dl.wait_cross_flag(dl.SyncFlag.V2C, 1)

                    p_cast = tl.load(workspace_2_ptr)
                    v = tl.load(V_block_ptr)
                    acc_l0c = tl.dot(p_cast, v)
                    tl.store(workspace_3_ptr, acc_l0c)
                    dl.set_cross_flag(dl.SyncFlag.C2V, 2)
                    V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
                    K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))

                with dl.async_task(scope=dl.async_task.vector):
                    dl.wait_cross_flag(dl.SyncFlag.C2V, 0)

                    # Online-softmax update on the QK^T tile.
                    qk = tl.load(workspace_1_ptr)
                    qk = qk * sm_scale
                    m_ij = tl.maximum(m_i, tl.max(qk, 1))
                    qk = qk - m_ij[:, None]
                    p = tl.math.exp(qk)
                    p_cast = p.to(Q.type.element_ty)
                    tl.store(workspace_2_ptr, p_cast)
                    dl.set_cross_flag(dl.SyncFlag.V2C, 1)
                    l_ij = tl.sum(p, 1)
                    alpha = tl.math.exp(m_i - m_ij)
                    l_i = l_i * alpha + l_ij
                    # Wait for P @ V from the cube side (flag C2V, 2), then
                    # rescale the running accumulator and fold the tile in.
                    dl.wait_cross_flag(dl.SyncFlag.C2V, 2)
                    acc_ptr = acc_ptr * alpha[:, None]
                    acc_o_ub = tl.load(workspace_3_ptr)
                    acc_ptr = acc_ptr + acc_o_ub
                    m_i = m_ij

        # Epilogue: finalize log-sum-exp and normalize the output block.
        m_i += tl.math.log(l_i)
        accumulator = acc_ptr / l_i[:, None]
        m_ptrs = M + task_hz_idx * N_CTX + offs_m

        tl.store(m_ptrs, m_i)
        tl.store(O_block_ptr, accumulator.to(Out.type.element_ty))


# ------------------- Python wrappers and Function classes -------------------
class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, sm_scale, BM, BN, causal=False):
        """Launch the fused _attn_fwd kernel and return the attention output."""
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        o = torch.empty_like(q)
        stage = 3 if causal else 1
        num_cores = 24  # NOTE(review): hard-coded core count — confirm per target
        NUM_BLOCKS_M = triton.cdiv(q.shape[2], BM)
        NUM_BLOCKS = NUM_BLOCKS_M * q.shape[0] * q.shape[1]
        NUM_BLOCKS_PER_CORE = triton.cdiv(NUM_BLOCKS, num_cores)

        # M holds per-row log-sum-exp, saved alongside the output.
        M = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )
        _attn_fwd[(num_cores,)](
            q,
            k,
            v,
            M,
            o,
            sm_scale,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            q.shape[0],
            q.shape[1],
            N_CTX=q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BM,
            BLOCK_N=BN,
            STAGE=stage,
            NUM_BLOCKS_PER_CORE=NUM_BLOCKS_PER_CORE,
            NUM_BLOCKS=NUM_BLOCKS,
            NUM_BLOCKS_M=NUM_BLOCKS_M,
            multibuffer=True,
            unit_flag=True,
            debug=False,
        )
        # Save state for a backward pass.
        # NOTE(review): no backward is defined in this chunk — confirm one
        # exists elsewhere or that these saves are intentional.
        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o


# Thin @triton.jit wrapper that forwards every argument unchanged to the
# cube/vector split kernel _attn_fwd_split_cv.
@triton.jit
def _attn_fwd_split_cv_launcher(
    Q,
    K,
    V,
    M,
    o,
    acc,
    sm_scale,
    workspace_1,
    workspace_2,
    workspace_3,
    q_stride0,
    q_stride1,
    q_stride2,
    q_stride3,
    k_stride0,
    k_stride1,
    k_stride2,
    k_stride3,
    v_stride0,
    v_stride1,
    v_stride2,
    v_stride3,
    o_stride0,
    o_stride1,
    o_stride2,
    o_stride3,
    w1_nb,
    w1_bm,
    w1_bn,
    w2_nb,
    w2_bm,
    w2_bn,
    w3_nb,
    w3_bm,
    w3_dm,
    Z: tl.constexpr,
    H: tl.constexpr,
    N_CTX: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CORES: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    _attn_fwd_split_cv(
        Q,
        K,
        V,
        M,
        o,
        acc,
        sm_scale,
        workspace_1,
        workspace_2,
        workspace_3,
        q_stride0,
        q_stride1,
        q_stride2,
        q_stride3,
        k_stride0,
        k_stride1,
        k_stride2,
        k_stride3,
        v_stride0,
        v_stride1,
        v_stride2,
        v_stride3,
        o_stride0,
        o_stride1,
        o_stride2,
        o_stride3,
        w1_nb,
        w1_bm,
        w1_bn,
        w2_nb,
        w2_bm,
        w2_bn,
        w3_nb,
        w3_bm,
        w3_dm,
        Z=Z,
        H=H,
        N_CTX=N_CTX,
        HEAD_DIM=HEAD_DIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        NUM_CORES=NUM_CORES,
        NUM_STAGES=NUM_STAGES,
    )


class AttentionSplitCV(torch.autograd.Function):
    """autograd.Function wrapper around the cube/vector split attention kernel."""

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, BM, BN, causal=False):
        """Allocate staging workspaces and launch _attn_fwd_split_cv; returns o."""
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        extra_kern_args = {}

        o = torch.empty_like(q)
        N_CTX = q.shape[2]
        Z, H = q.shape[0], q.shape[1]
        # NOTE(review): floor division — if N_CTX % BM != 0 the tail rows are
        # silently skipped (the base path uses triton.cdiv); confirm inputs.
        NUM_BLOCKS_M = N_CTX // BM
        NUM_BLOCKS = NUM_BLOCKS_M * Z * H
        DIM = q.shape[-1]
        NUM_CORES = 24
        NUM_STAGES = 4
        # NOTE(review): acc is passed to the kernel but never used there; this
        # is a large fp32 allocation — confirm whether it can be dropped.
        acc = torch.zeros(
            (q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K),
            dtype=torch.float32,
            device=q.device,
        )
        # Per-row log-sum-exp output.
        M = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )
        # Cross-pipeline staging buffers, one slice per (block, stage):
        # workspace_1: QK^T tiles, workspace_2: cast softmax numerators,
        # workspace_3: P @ V partial outputs.
        workspace_1 = torch.empty(
            (NUM_STAGES, NUM_BLOCKS, BM, BN), device=q.device, dtype=torch.float32
        )
        workspace_2 = torch.empty(
            (NUM_STAGES, NUM_BLOCKS, BM, BN), device=q.device, dtype=q.dtype
        )
        workspace_3 = torch.empty(
            (NUM_STAGES, NUM_BLOCKS, BM, DIM), device=q.device, dtype=torch.float32
        )

        _attn_fwd_split_cv_launcher[(NUM_CORES,)](
            q,
            k,
            v,
            M,
            o,
            acc,
            sm_scale,
            workspace_1,
            workspace_2,
            workspace_3,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            workspace_1.stride(1),
            workspace_1.stride(2),
            workspace_1.stride(3),
            workspace_2.stride(1),
            workspace_2.stride(2),
            workspace_2.stride(3),
            workspace_3.stride(1),
            workspace_3.stride(2),
            workspace_3.stride(3),
            q.shape[0],
            q.shape[1],
            N_CTX=q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BM,
            BLOCK_N=BN,
            NUM_CORES=NUM_CORES,
            NUM_STAGES=NUM_STAGES,
            disable_auto_inject_block_sync=True,
            disable_auto_cv_work_space_manage=True,
            **extra_kern_args,
        )

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        return o


# Public entry points.
attention_base = _attention.apply
attention_split_cv = AttentionSplitCV.apply

# ------------------- Tests (expanded to include all original test_op cases) -------------------
ALL_CASES = [
    (1, 2, 2048, 128, 64, 128, False),
    (4, 32, 1024, 64, 64, 256, False),
    # Ultra-long sequence
    (1, 1, 1024 * 32, 128, 64, 128, False),
    # Medium scale
    (4, 4, 512, 128, 16, 128, False),
    (8, 32, 512, 256, 16, 128, False),
    # Small sequences / tile diversity
    (32, 32, 64, 64, 64, 16, False),
    (32, 32, 128, 128, 64, 32, False),
    (32, 32, 256,
128, 64, 64, False), + # 常见 LLM 配置 + (1, 8, 1024, 64, 64, 128, False), + (8, 12, 512, 64, 128, 128, False), + # 长上下文 + (1, 16, 2048, 128, 64, 128, False), + (1, 32, 4096, 128, 64, 128, False), +] + + +@pytest.mark.xdist_group(name="attention_ref_group") +@pytest.mark.parametrize("Z,H,N_CTX,HEAD_DIM,BM,BN,causal", ALL_CASES) +def test_attention_matches_reference_all(Z, H, N_CTX, HEAD_DIM, BM, BN, causal): + require_npu() + torch.manual_seed(20) + dtype = torch.float16 + + q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( + mean=0.0, std=0.5 + ) + k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( + mean=0.0, std=0.5 + ) + v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( + mean=0.0, std=0.5 + ) + + sm_scale = 0.5 + + ref_out = torch_npu.npu_fusion_attention( + q, + k, + v, + H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + tri_out_base = attention_base(q, k, v, sm_scale, BM, BN, causal).half() + tri_out_cv = attention_split_cv(q, k, v, sm_scale, BM, BN, causal).half() + + atol = 1e-3 + rtol = 0.0 + + assert torch.allclose(ref_out, tri_out_base, atol=atol, rtol=rtol) + assert torch.allclose(ref_out, tri_out_cv, atol=atol, rtol=rtol) + + +# # Performance test for the ultra-long sequence; asserts torch_time < tri_time* 0.30 +# @pytest.mark.xdist_group(name="test_perf_long_sequence") +# def test_perf_long_sequence(): +# require_npu() +# torch.manual_seed(20) +# Z, H, N_CTX, HEAD_DIM, BM, BN, causal = (1, 1, 1024 * 64, 128, 64, 256, False) +# dtype = torch.float16 + +# q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( +# mean=0.0, std=0.5 +# ) +# k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( +# mean=0.0, std=0.5 +# ) +# v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( +# mean=0.0, 
std=0.5 +# ) + +# sm_scale = 0.5 + +# # Warmup & rep counts: keep modest to avoid extremely long CI runs; adjust as needed. +# warmup = 50 +# rep = 50 + +# # measure torch_npu fused attention +# torch_fn = lambda: torch_npu.npu_fusion_attention( +# q, +# k, +# v, +# H, +# padding_mask=None, +# atten_mask=None, +# scale=sm_scale, +# keep_prob=1.0, +# input_layout="BNSD", +# pre_tockens=65535, +# next_tockens=65535, +# sparse_mode=0, +# )[0] + +# tri_fn = lambda: attention_split_cv(q, k, v, sm_scale, BM, BN, causal) + +# # triton.testing.do_bench returns ms +# torch_ms = triton.testing.do_bench(torch_fn, warmup=warmup, rep=rep) +# tri_ms = triton.testing.do_bench(tri_fn, warmup=warmup, rep=rep) + +# # print for visibility when running tests +# print( +# f"torch_npu fusion ms: {torch_ms:.3f} ms; triton split_cv ms: {tri_ms:.3f} ms" +# ) + +# # require triton to be faster than 30% of torch time +# assert ( +# torch_ms > tri_ms * 0.30 +# ), f"triton({torch_ms :.3f}ms) must be < 30% of triton({tri_ms:.3f}ms)" diff --git a/tools/dicp_triton_opt/CMakeLists.txt b/tools/dicp_triton_opt/CMakeLists.txt index bd00d88a..6e03085e 100644 --- a/tools/dicp_triton_opt/CMakeLists.txt +++ b/tools/dicp_triton_opt/CMakeLists.txt @@ -8,6 +8,7 @@ target_link_libraries(dicp_opt PRIVATE TritonAnalysis TritonTransforms TritonGPUTransforms + TritonNvidiaGPUTransforms TritonSharedAnalysis ${dialect_libs} ${translation_libs} @@ -23,6 +24,8 @@ target_link_libraries(dicp_opt PRIVATE DICPLinalgExt DiscreteMaskAccessConversion + BiShengIRHIVMDialect + LinalgExtAnalysis TritonToLinalg TritonTilingExtIR TritonToLinalgNPUCoversion diff --git a/tools/dicp_triton_opt/dicp_triton_opt.cpp b/tools/dicp_triton_opt/dicp_triton_opt.cpp index dccd884e..4ed6c24e 100644 --- a/tools/dicp_triton_opt/dicp_triton_opt.cpp +++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp @@ -83,6 +83,8 @@ #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h.inc" #include 
"triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + using namespace mlir; inline void registerDICPDialects(mlir::DialectRegistry ®istry) { @@ -105,6 +107,7 @@ inline void registerDICPDialects(mlir::DialectRegistry ®istry) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerNPUUnroolPipelinePass(); registry.insert(); + ttx::TritonTilingExtDialect, mlir::hivm::HIVMDialect>(); } int main(int argc, char **argv) { diff --git a/triton_dicp_triton.cc b/triton_dicp_triton.cc index 979e0071..9074c342 100644 --- a/triton_dicp_triton.cc +++ b/triton_dicp_triton.cc @@ -70,6 +70,11 @@ void init_triton_dicp_triton_pass_linked_npu(py::module &&m) { pm.addNestedPass( dicp::LinalgExt::createScalarTo1DTensorPass()); }); + m.def("add_npu_unroll_pipeline", [](mlir::PassManager &pm) { + pm.addNestedPass( + dicp::LinalgExt::createNPUUnroolPipelinePass()); + }); + m.def("add_linalg_to_linked", [](mlir::PassManager &pm, bool globalKernel, bool namedOps) { pm.addPass(mlir::dicp::linked::createLinalgToLinkedPass(globalKernel, @@ -107,6 +112,7 @@ void init_triton_dicp_triton(py::module &&m) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerNPUUnroolPipelinePass(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects();