diff --git a/.gitmodules b/.gitmodules index 6889aed..a4db716 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "test/googletest"] path = test/googletest url = https://github.com/google/googletest.git +[submodule "third_party/TinyFA"] + path = third_party/TinyFA + url = https://github.com/keith2018/TinyFA.git diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0f616e8..76ea47c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,6 +4,8 @@ project(TinyTorch_lib) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../third_party) + if (TINYTORCH_USE_CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda --expt-relaxed-constexpr") find_package(CUDAToolkit REQUIRED) @@ -40,7 +42,7 @@ else () add_library(${PROJECT_NAME} ${TinyTorch_SRC_CPP}) endif () -target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party) +target_include_directories(${PROJECT_NAME} PRIVATE ${THIRD_PARTY_DIR}) # disable exceptions target_compile_options(${PROJECT_NAME} PRIVATE @@ -92,3 +94,11 @@ if (TINYTORCH_USE_NCCL) target_include_directories(${PROJECT_NAME} PUBLIC ${NCCL_INCLUDE_DIRS}) target_link_libraries(${PROJECT_NAME} PUBLIC ${NCCL_LIBRARY}) endif () + +# TinyFA +if (TINYTORCH_USE_CUDA) + set(TFA_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(TFA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) + add_subdirectory(${THIRD_PARTY_DIR}/TinyFA TinyFA) + target_link_libraries(${PROJECT_NAME} PRIVATE TinyFA::tinyfa) +endif () diff --git a/src/Operation/FlashAtten/config.cuh b/src/Operation/FlashAtten/config.cuh deleted file mode 100644 index f34c262..0000000 --- a/src/Operation/FlashAtten/config.cuh +++ /dev/null @@ -1,79 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include -#include - -namespace tfa { - -template -struct KernelConfig { - using DType = DType_; - - static constexpr int kHeadDim = HeadDim_; - static constexpr int kBr = Br_; - static constexpr int kBc = Bc_; - static constexpr int kNumWarps = NumWarps_; - - static constexpr int kWarpSize = 32; - static constexpr int kNumThreads = kNumWarps * kWarpSize; - - static constexpr bool kUseSwizzle = true; - static constexpr int kBytesPerVecLoad = 16; // uint4 - static constexpr int kElemsPerVecLoad = kBytesPerVecLoad / sizeof(DType); - - static_assert(kBr % kNumWarps == 0, "kBr must be divisible by kNumWarps"); - static_assert(kBc % kWarpSize == 0, "kBc must be divisible by kWarpSize"); - static_assert(kHeadDim % kWarpSize == 0, "kHeadDim must be divisible by kWarpSize"); - - static constexpr int kRowsPerWarp = kBr / kNumWarps; - static constexpr int kColsPerLane = kBc / kWarpSize; - static constexpr int kDimsPerLane = kHeadDim / kWarpSize; -}; - -#define TFA_HEAD_DIM_CASE(N, HEAD_DIM_VAR, ...) \ - case N: { \ - constexpr int HEAD_DIM_VAR = N; \ - __VA_ARGS__ \ - break; \ - } - -#define TFA_DISPATCH_HEAD_DIM(headDim, HEAD_DIM_VAR, ...) \ - [&] { \ - switch (headDim) { \ - TFA_HEAD_DIM_CASE(32, HEAD_DIM_VAR, __VA_ARGS__) \ - TFA_HEAD_DIM_CASE(64, HEAD_DIM_VAR, __VA_ARGS__) \ - TFA_HEAD_DIM_CASE(128, HEAD_DIM_VAR, __VA_ARGS__) \ - TFA_HEAD_DIM_CASE(256, HEAD_DIM_VAR, __VA_ARGS__) \ - default: \ - printf("Unsupported headDim: %d\n", headDim); \ - break; \ - } \ - }() - -// default configuration -template -struct ConfigForHeadDim { - using Config = KernelConfig; -}; - -#define TFA_DEFINE_CONFIG(DTYPE, HEADDIM, BR, BC, WARPS) \ - template <> \ - struct ConfigForHeadDim { \ - using Config = KernelConfig; \ - } - -using FP32 = float; -using FP16 = __half; -using BF16 = __nv_bfloat16; - -TFA_DEFINE_CONFIG(FP32, 128, 64, 64, 8); -TFA_DEFINE_CONFIG(FP16, 128, 128, 64, 8); -TFA_DEFINE_CONFIG(BF16, 128, 128, 64, 8); - -} // namespace tfa diff --git a/src/Operation/FlashAtten/gemm.cuh b/src/Operation/FlashAtten/gemm.cuh deleted file mode 100644 index 333b015..0000000 --- a/src/Operation/FlashAtten/gemm.cuh +++ /dev/null @@ -1,196 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include "config.cuh" -#include "tile.cuh" -#include "utils.cuh" - -namespace tfa { - -template -struct GemmOp { - using DType = typename Config::DType; - - static constexpr int kDimsPerLane = Config::kDimsPerLane; - static constexpr int kRowsPerWarp = Config::kRowsPerWarp; - static constexpr int kColsPerLane = Config::kColsPerLane; - static constexpr int kElemsPerVec = Config::kElemsPerVecLoad; - static constexpr int kWarpSize = Config::kWarpSize; - - static_assert(Config::kHeadDim % kElemsPerVec == 0, "kHeadDim must be divisible by kElemsPerVec"); - static constexpr int kNumVecs = Config::kHeadDim / kElemsPerVec; - - // S = Q @ K^T - template - __device__ static void computeScore(S& s, const QTile& qTile, const KTile& kTile, Context& ctx) { - bool needsCausal = ctx.template needsCausalMask(); - bool isPartial = ctx.isPartialTile(); - - if (needsCausal) { - isPartial ? computeScoreImpl(s, qTile, kTile, ctx) - : computeScoreImpl(s, qTile, kTile, ctx); - } else { - isPartial ? computeScoreImpl(s, qTile, kTile, ctx) - : computeScoreImpl(s, qTile, kTile, ctx); - } - } - - private: - template - __device__ static void computeScoreImpl(S& s, const QTile& qTile, const KTile& kTile, Context& ctx) { - DType qReg[kRowsPerWarp][kElemsPerVec]; - DType kReg[kColsPerLane][kElemsPerVec]; - - initScores(s); - accumDotProducts(s, qReg, kReg, qTile, kTile, ctx); - applyScaleAndMask(s, ctx); - } - template - __device__ static __forceinline__ void initScores(S& s) { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { -#pragma unroll - for (int n = 0; n < kColsPerLane; n++) { - s[m][n] = 0.f; - } - } - } - - template - __device__ static __forceinline__ void accumDotProducts(S& s, DType qReg[][kElemsPerVec], DType kReg[][kElemsPerVec], - const QTile& qTile, const KTile& kTile, - Context& ctx) { -#pragma unroll - for (int v = 0; v < kNumVecs; v++) { - loadQ(qReg, qTile, ctx, v); - loadK(kReg, kTile, ctx, v); - accumulate(s, qReg, kReg); - } - } - - template - __device__ static __forceinline__ void loadQ(DType qReg[][kElemsPerVec], const QTile& qTile, - const Context& ctx, int v) { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - qTile.sm2reg(&qReg[m][0], ctx.warpRowOffset + m, v); - } - } - - template - __device__ static __forceinline__ void loadK(DType kReg[][kElemsPerVec], const KTile& kTile, - const Context& ctx, int v) { -#pragma unroll - for (int n = 0; n < kColsPerLane; n++) { - int colN = ctx.laneId + n * kWarpSize; - if constexpr (kBoundaryMask) { - if (colN < ctx.curTileSizeKV) { - kTile.sm2reg(&kReg[n][0], colN, v); - } else { -#pragma unroll - for (int e = 0; e < kElemsPerVec; e++) { - kReg[n][e] = fromFloat(0.f); - } - } - } else { - kTile.sm2reg(&kReg[n][0], colN, v); - } - } - } - - template - __device__ static __forceinline__ void accumulate(S& s, const DType qReg[][kElemsPerVec], - const DType kReg[][kElemsPerVec]) { -#pragma unroll - for (int e = 0; e < kElemsPerVec; e++) { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - float q = toFloat(qReg[m][e]); -#pragma unroll - for (int n = 0; n < kColsPerLane; n++) { - s[m][n] += q * toFloat(kReg[n][e]); - } - } - } - } - - template - __device__ static __forceinline__ void applyScaleAndMask(S& s, const Context& ctx) { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - int globalQ = ctx.globalRowQ(m); - int globalKV = ctx.globalKV(0); -#pragma unroll - for (int n = 0; n < kColsPerLane; n++) { - bool masked = false; - if constexpr (kBoundaryMask) { - masked |= (ctx.laneId + n * kWarpSize >= ctx.curTileSizeKV); - } - if constexpr (kCausalMask) { - masked |= (globalKV > globalQ); - globalKV += kWarpSize; - } - s[m][n] = masked ? -INFINITY : (s[m][n] * ctx.kAttnScale); - } - } - } - - public: - // O += P @ V - template - __device__ static void computeOutput(OTile& oTile, const P& prob, const VTile& vTile, - const Context& ctx) { - constexpr int kVecsPerThread = kDimsPerLane / kElemsPerVec; - const int vecBase = ctx.laneId * kVecsPerThread; - DType vReg[kDimsPerLane]; - -#pragma unroll - for (int n = 0; n < Config::kBc; n++) { - loadV(vReg, vTile, ctx, n, vecBase); - accumPV(oTile.acc, prob, vReg, n); - } - } - - private: - template - __device__ static __forceinline__ void loadV(DType vReg[], const VTile& vTile, const Context& ctx, int n, - int vecBase) { - constexpr int kRemainElems = kDimsPerLane % kElemsPerVec; - -#pragma unroll - for (int vi = 0; vi < kVecsPerThread; vi++) { - vTile.sm2reg(&vReg[vi * kElemsPerVec], n, vecBase + vi); - } - - // remaining elements - if constexpr (kRemainElems > 0) { - int remainOffset = ctx.laneId * kDimsPerLane + kVecsPerThread * kElemsPerVec; -#pragma unroll - for (int ri = 0; ri < kRemainElems; ri++) { - vReg[kVecsPerThread * kElemsPerVec + ri] = vTile.at(n, remainOffset + ri); - } - } - } - - template - __device__ static __forceinline__ void accumPV(AccO& accO, const P& prob, const DType vReg[], int n) { - int srcLane = n % kWarpSize; - int colIdx = n / kWarpSize; - -#pragma unroll - for (int k = 0; k < kDimsPerLane; k++) { - float v = toFloat(vReg[k]); -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - accO[m][k] += __shfl_sync(0xffffffff, prob[m][colIdx], srcLane) * v; - } - } - } -}; - -} // namespace tfa diff --git a/src/Operation/FlashAtten/kernel.cuh b/src/Operation/FlashAtten/kernel.cuh deleted file mode 100644 index 449d1f9..0000000 --- a/src/Operation/FlashAtten/kernel.cuh +++ /dev/null @@ -1,188 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include "gemm.cuh" -#include "softmax.cuh" -#include "tile.cuh" - -namespace tfa { - -template -struct ThreadInfo { - static constexpr int kWarpSize = Config::kWarpSize; - static constexpr int kRowsPerWarp = Config::kRowsPerWarp; - static constexpr int kDimsPerLane = Config::kDimsPerLane; - - const int threadId = threadIdx.x; - const int blockSize = blockDim.x; - const int warpId = threadId / kWarpSize; - const int laneId = threadId % kWarpSize; - const int warpRowOffset = warpId * kRowsPerWarp; - - __device__ ThreadInfo() = default; -}; - -template -struct BlockInfo { - static constexpr int kTileQ = Config::kBr; // tile size for Q - static constexpr int kTileKV = Config::kBc; // tile size for KV - - const int batchIdx = blockIdx.z; - const int headIdx = blockIdx.y; - const int tileIdxQ = blockIdx.x; - const int tileQ = tileIdxQ * kTileQ; - const int headIdxKV; - const int seqLenQ; // sequence length for Q - const int seqLenKV; // sequence length for KV - const int tileSizeQ; - - __device__ explicit BlockInfo(const Params& params) - : headIdxKV(params.getKVHead(headIdx)), - seqLenQ(params.getSeqLenQ(batchIdx)), - seqLenKV(params.getSeqLenKV(batchIdx)), - tileSizeQ(min(kTileQ, seqLenQ - tileQ)) {} -}; - -template -struct KernelContext : ThreadInfo, BlockInfo { - using DType = typename Config::DType; - using Thread = ThreadInfo; - using Block = BlockInfo; - - static constexpr int kRowsPerWarp = Config::kRowsPerWarp; - static constexpr int kColsPerLane = Config::kColsPerLane; - static constexpr int kDimsPerLane = Config::kDimsPerLane; - static constexpr float kAttnScale = AttentionScale::value; - - const Params& params; - - int tileKV = 0; - int curTileSizeKV = 0; - - __device__ explicit KernelContext(const Params& p) : Thread(), Block(p), params(p) {} - - using Block::batchIdx; - using Block::headIdx; - using Block::headIdxKV; - using Block::kTileKV; - using Block::seqLenKV; - using Block::seqLenQ; - using Block::tileQ; - using Block::tileSizeQ; - - using Thread::blockSize; - using Thread::laneId; - using Thread::threadId; - using Thread::warpRowOffset; - - __device__ __forceinline__ bool isValidTile() const { return tileQ < seqLenQ; } - - __device__ __forceinline__ int validWarpRows() const { - int warpStartQ = tileQ + warpRowOffset; - return (warpStartQ < seqLenQ) ? min(kRowsPerWarp, seqLenQ - warpStartQ) : 0; - } - - __device__ __forceinline__ int globalRowQ(int localRow) const { return tileQ + warpRowOffset + localRow; } - - __device__ __forceinline__ int globalKV(int n) const { return tileKV + laneId + n * Config::kWarpSize; } - - template - __device__ __forceinline__ void setTileKV(int tileIdx) { - tileKV = tileIdx * kTileKV; - curTileSizeKV = min(kTileKV, seqLenKV - tileKV); - if constexpr (kIsCausal) { - curTileSizeKV = min(curTileSizeKV, tileQ + tileSizeQ - tileKV); - } - } - - template - __device__ __forceinline__ int numTilesKV() const { - int tiles = params.getKVTiles(seqLenKV, kTileKV); - if constexpr (kIsCausal) { - tiles = min(tiles, ceilDiv(tileQ + tileSizeQ, kTileKV)); - } - return tiles; - } - - template - __device__ __forceinline__ bool needsCausalMask() const { - return kIsCausal && (tileKV + kTileKV > tileQ); - } - - // check if current KV tile is partial - __device__ __forceinline__ bool isPartialTile() const { return curTileSizeKV < kTileKV; } - - __device__ __forceinline__ const DType* qPtr() const { return params.qPtr(*this); } - __device__ __forceinline__ const DType* kPtr() const { return params.kPtr(*this, tileKV); } - __device__ __forceinline__ const DType* vPtr() const { return params.vPtr(*this, tileKV); } - __device__ __forceinline__ DType* oPtr() const { return params.oPtr(*this); } - __device__ __forceinline__ int seqDimQ() const { return params.seqDimQ; } - __device__ __forceinline__ int seqDimKV() const { return params.seqDimKV; } -}; - -template -__global__ void flashAttentionKernel(Params params) { - using DType = typename Config::DType; - using Context = KernelContext; - - constexpr int kRowsPerWarp = Config::kRowsPerWarp; - constexpr int kColsPerLane = Config::kColsPerLane; - - Context ctx(params); - if (!ctx.isValidTile()) return; - - extern __shared__ char smemBuf[]; - auto* smem = reinterpret_cast(smemBuf); - - QTile qTile(smem); - OTile oTile(smem); - KTile kTile(smem + qTile.numElems()); - VTile vTile(smem + qTile.numElems()); - - Softmax softmax; - softmax.init(); - - // load Q tile - qTile.gm2sm(ctx.qPtr(), ctx.seqDimQ(), ctx.tileSizeQ, ctx); - __syncthreads(); - - // main loop over KV tiles - const int numTilesKV = ctx.template numTilesKV(); - for (int tileIdx = 0; tileIdx < numTilesKV; tileIdx++) { - ctx.template setTileKV(tileIdx); - - // load K - kTile.gm2sm(ctx.kPtr(), ctx.seqDimKV(), ctx.curTileSizeKV, ctx); - __syncthreads(); - - // S = Q @ K^T - float score[kRowsPerWarp][kColsPerLane]; - GemmOp::template computeScore(score, qTile, kTile, ctx); - - // softmax - softmax.update(score, score, oTile.acc); - __syncthreads(); - - // load V - vTile.gm2sm(ctx.vPtr(), ctx.seqDimKV(), ctx.curTileSizeKV, ctx); - __syncthreads(); - - // O += P @ V - GemmOp::computeOutput(oTile, score, vTile, ctx); - __syncthreads(); - } - - // normalize - oTile.normalize(ctx, softmax); - __syncthreads(); - - // store O - oTile.store(ctx.oPtr(), ctx.seqDimQ(), ctx); -} - -} // namespace tfa diff --git a/src/Operation/FlashAtten/launcher.cuh b/src/Operation/FlashAtten/launcher.cuh deleted file mode 100644 index 2bfd678..0000000 --- a/src/Operation/FlashAtten/launcher.cuh +++ /dev/null @@ -1,102 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include - -#include "kernel.cuh" -#include "params.cuh" -#include "utils.cuh" - -namespace tfa { - -namespace detail { - -template -void launchKernel(const Params& params, int gridX, int numHeads, int batchSize, bool isCausal, cudaStream_t stream) { - size_t smemSize = QTile::smemSize() + KTile::smemSize(); - - dim3 grid(gridX, numHeads, batchSize); - dim3 block(Config::kNumThreads); - - auto kernel = isCausal ? flashAttentionKernel : flashAttentionKernel; - - bool sharedMemOk = setDynamicSharedMemory(kernel, smemSize); - assert(sharedMemOk && "error: shared memory not fit"); - - kernel<<>>(params); -} - -template -void flashAttnImpl(const DType* Q, const DType* K, const DType* V, DType* O, int batchSize, int seqLenQ, int seqLenKV, - int numHeadsQ, int numHeadsKV, bool isCausal, cudaStream_t stream) { - using ConfigHelper = ConfigForHeadDim; - using Config = typename ConfigHelper::Config; - using Params = FixLenParams; - - Params params; - params.Q = Q; - params.K = K; - params.V = V; - params.O = O; - params.seqLenQ = seqLenQ; - params.seqLenKV = seqLenKV; - params.numKVTiles = ceilDiv(seqLenKV, Config::kBc); - params.seqDimQ = numHeadsQ * kHeadDim; - params.seqDimKV = numHeadsKV * kHeadDim; - params.groupSize = numHeadsQ / numHeadsKV; - - launchKernel(params, ceilDiv(seqLenQ, Config::kBr), numHeadsQ, batchSize, isCausal, stream); -} - -template -void flashAttnVarLenImpl(const DType* Q, const DType* K, const DType* V, DType* O, const int* cuSeqLensQ, - const int* cuSeqLensKV, int batchSize, int maxSeqLenQ, int maxSeqLenKV, int numHeadsQ, - int numHeadsKV, bool isCausal, cudaStream_t stream) { - using ConfigHelper = ConfigForHeadDim; - using Config = typename ConfigHelper::Config; - using Params = VarLenParams; - - Params params; - params.Q = Q; - params.K = K; - params.V = V; - params.O = O; - params.cuSeqLensQ = cuSeqLensQ; - params.cuSeqLensKV = cuSeqLensKV; - params.maxSeqLenQ = maxSeqLenQ; - params.maxSeqLenKV = maxSeqLenKV; - params.maxKVTiles = ceilDiv(maxSeqLenKV, Config::kBc); - params.seqDimQ = numHeadsQ * kHeadDim; - params.seqDimKV = numHeadsKV * kHeadDim; - params.groupSize = numHeadsQ / numHeadsKV; - - launchKernel(params, ceilDiv(maxSeqLenQ, Config::kBr), numHeadsQ, batchSize, isCausal, stream); -} - -} // namespace detail - -template -void flashAttn(const DType* Q, const DType* K, const DType* V, DType* O, int batchSize, int seqLenQ, int seqLenKV, - int numHeadsQ, int numHeadsKV, int headDim, bool isCausal = false, cudaStream_t stream = nullptr) { - TFA_DISPATCH_HEAD_DIM(headDim, kHeadDim, { - detail::flashAttnImpl(Q, K, V, O, batchSize, seqLenQ, seqLenKV, numHeadsQ, numHeadsKV, isCausal, - stream); - }); -} - -template -void flashAttnVarLen(const DType* Q, const DType* K, const DType* V, DType* O, const int* cuSeqLensQ, - const int* cuSeqLensKV, int batchSize, int maxSeqLenQ, int maxSeqLenKV, int numHeadsQ, - int numHeadsKV, int headDim, bool isCausal = false, cudaStream_t stream = nullptr) { - TFA_DISPATCH_HEAD_DIM(headDim, kHeadDim, { - detail::flashAttnVarLenImpl(Q, K, V, O, cuSeqLensQ, cuSeqLensKV, batchSize, maxSeqLenQ, - maxSeqLenKV, numHeadsQ, numHeadsKV, isCausal, stream); - }); -} - -} // namespace tfa diff --git a/src/Operation/FlashAtten/layout.cuh b/src/Operation/FlashAtten/layout.cuh deleted file mode 100644 index 46efbb2..0000000 --- a/src/Operation/FlashAtten/layout.cuh +++ /dev/null @@ -1,62 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -namespace tfa { - -struct LayoutIdentity { - __device__ __forceinline__ static int map(int row, int col, int stride) { return row * stride + col; } -}; - -// Ref https://leimao.github.io/blog/CuTe-Swizzle/. -template -struct CuteSwizzle { - static constexpr int mbase = MBase; - static constexpr int mask_bits = BBits; - static constexpr int mask_shift = SShift; - - static constexpr int bit_mask = (1 << mask_bits) - 1; - static constexpr int yy_mask = bit_mask << (mbase + mask_shift); - static constexpr int yy_mask_lowest_bit = yy_mask & -yy_mask; - - __device__ __forceinline__ constexpr static int apply(int offset) { - const int row_shifted = (offset & yy_mask) >> mask_shift; - return offset ^ row_shifted; - } -}; - -template -struct LayoutSwizzle { - static constexpr int kVecBytes = 16; - static constexpr int kDTypeBytes = sizeof(DType); - static constexpr int kVecElem = kVecBytes / kDTypeBytes; - - static constexpr int MBase = (kVecElem == 8) ? 3 : (kVecElem == 4) ? 2 : 0; - - static constexpr int kHeadDimBits = (HeadDim == 256) ? 8 - : (HeadDim == 128) ? 7 - : (HeadDim == 64) ? 6 - : (HeadDim == 32) ? 5 - : 0; - - static_assert(kHeadDimBits > 0, "Unsupported HeadDim"); - static constexpr int kSShift = kHeadDimBits - MBase; - - using Swizzle = CuteSwizzle<3, MBase, kSShift>; - - __device__ __forceinline__ static int map(int row, int col, int stride) { - int offset = row * stride + col; - return Swizzle::apply(offset); - } -}; - -template -using TileLayout = - typename std::conditional, - LayoutIdentity>::type; - -} // namespace tfa diff --git a/src/Operation/FlashAtten/memory.cuh b/src/Operation/FlashAtten/memory.cuh deleted file mode 100644 index b524e47..0000000 --- a/src/Operation/FlashAtten/memory.cuh +++ /dev/null @@ -1,62 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include "layout.cuh" - -namespace tfa { - -template -struct MemLoader { - using DType = typename Config::DType; - using Layout = TileLayout; - using VecT = int4; // 128-bit - - static constexpr int kElemsPerVec = Config::kElemsPerVecLoad; - static constexpr int kHeadDim = Config::kHeadDim; - static constexpr int kNumVecs = kHeadDim / kElemsPerVec; - - template - __device__ __forceinline__ static void copy(DType* smem, DType* gmem, int stride, int validRows, const Context& ctx) { - static_assert(kHeadDim % kElemsPerVec == 0, "kHeadDim must be divisible by kElemsPerVec"); - - for (int idx = ctx.threadId; idx < Rows * kNumVecs; idx += ctx.blockSize) { - int row = idx / kNumVecs; - int col = (idx % kNumVecs) * kElemsPerVec; - int smemIdx = Layout::map(row, col, kHeadDim); - - VecT* sPtr = reinterpret_cast(&smem[smemIdx]); - VecT* gPtr = reinterpret_cast(gmem + row * stride + col); - - if (IsLoad) { - // global -> shared - *sPtr = (row < validRows) ? *gPtr : make_int4(0, 0, 0, 0); - } else if (row < validRows) { - // shared -> global - *gPtr = *sPtr; - } - } - } - - template - __device__ __forceinline__ static void gm2sm(DType* __restrict__ smem, const DType* __restrict__ gmem, int stride, - int validRows, const Context& ctx) { - copy(smem, const_cast(gmem), stride, validRows, ctx); - } - - template - __device__ __forceinline__ static void sm2gm(DType* __restrict__ gmem, const DType* __restrict__ smem, int stride, - int validRows, const Context& ctx) { - copy(const_cast(smem), gmem, stride, validRows, ctx); - } - - __device__ __forceinline__ static void sm2reg(DType* __restrict__ reg, const DType* __restrict__ smem) { - *reinterpret_cast(reg) = *reinterpret_cast(smem); - } -}; - -} // namespace tfa diff --git a/src/Operation/FlashAtten/params.cuh b/src/Operation/FlashAtten/params.cuh deleted file mode 100644 index dcb1a57..0000000 --- a/src/Operation/FlashAtten/params.cuh +++ /dev/null @@ -1,110 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -namespace tfa { - -template -struct FixLenParams { - using DType = DType_; - static constexpr int kHeadDim = kHeadDim_; - - const DType* __restrict__ Q; - const DType* __restrict__ K; - const DType* __restrict__ V; - DType* __restrict__ O; - - int seqLenQ; - int seqLenKV; - - int seqDimQ; // numHeadsQ * kHeadDim - int seqDimKV; // numHeadsKV * kHeadDim - - int numKVTiles; // ceilDiv(seqLenKV, kTileKV) - int groupSize; // numHeadsQ / numHeadsKV (GQA) - - __device__ __forceinline__ int getSeqLenQ(int batchIdx) const { return seqLenQ; } - __device__ __forceinline__ int getSeqLenKV(int batchIdx) const { return seqLenKV; } - __device__ __forceinline__ int getKVHead(int headIdx) const { return headIdx / groupSize; } - __device__ __forceinline__ int getKVTiles(int seqLen, int tileKV) const { return numKVTiles; } - - // Layout: [batch, seq, head, dim] - template - __device__ __forceinline__ const DType* qPtr(const Context& ctx) const { - return Q + (ctx.batchIdx * seqLenQ + ctx.tileQ) * seqDimQ + ctx.headIdx * kHeadDim; - } - - template - __device__ __forceinline__ const DType* kPtr(const Context& ctx, int seqIdx) const { - return K + (ctx.batchIdx * seqLenKV + seqIdx) * seqDimKV + ctx.headIdxKV * kHeadDim; - } - - template - __device__ __forceinline__ const DType* vPtr(const Context& ctx, int seqIdx) const { - return V + (ctx.batchIdx * seqLenKV + seqIdx) * seqDimKV + ctx.headIdxKV * kHeadDim; - } - - template - __device__ __forceinline__ DType* oPtr(const Context& ctx) const { - return O + (ctx.batchIdx * seqLenQ + ctx.tileQ) * seqDimQ + ctx.headIdx * kHeadDim; - } -}; - -template -struct VarLenParams { - using DType = DType_; - static constexpr int kHeadDim = kHeadDim_; - - const DType* __restrict__ Q; // [totalQ, numHeadsQ, headDim] - const DType* __restrict__ K; // [totalKV, numHeadsKV, headDim] - const DType* __restrict__ V; // [totalKV, numHeadsKV, headDim] - DType* __restrict__ O; // [totalQ, numHeadsQ, headDim] - - const int* __restrict__ cuSeqLensQ; // [batch + 1], cumulative sequence lengths for Q - const int* __restrict__ cuSeqLensKV; // [batch + 1], cumulative sequence lengths for KV - - int maxSeqLenQ; - int maxSeqLenKV; - - int seqDimQ; // numHeadsQ * headDim - int seqDimKV; // numHeadsKV * headDim - - int maxKVTiles; // ceilDiv(maxSeqLenKV, kTileKV) - int groupSize; // numHeadsQ / numHeadsKV (GQA) - - __device__ __forceinline__ int getSeqLenQ(int batchIdx) const { - return cuSeqLensQ[batchIdx + 1] - cuSeqLensQ[batchIdx]; - } - __device__ __forceinline__ int getSeqLenKV(int batchIdx) const { - return cuSeqLensKV[batchIdx + 1] - cuSeqLensKV[batchIdx]; - } - __device__ __forceinline__ int getKVHead(int headIdx) const { return headIdx / groupSize; } - __device__ __forceinline__ int getKVTiles(int seqLen, int tileKV) const { return ceilDiv(seqLen, tileKV); } - - // Layout: [totalSeq, numHeads, headDim] (packed, no padding) - template - __device__ __forceinline__ const DType* qPtr(const Context& ctx) const { - return Q + (cuSeqLensQ[ctx.batchIdx] + ctx.tileQ) * seqDimQ + ctx.headIdx * kHeadDim; - } - - template - __device__ __forceinline__ const DType* kPtr(const Context& ctx, int seqIdx) const { - return K + (cuSeqLensKV[ctx.batchIdx] + seqIdx) * seqDimKV + ctx.headIdxKV * kHeadDim; - } - - template - __device__ __forceinline__ const DType* vPtr(const Context& ctx, int seqIdx) const { - return V + (cuSeqLensKV[ctx.batchIdx] + seqIdx) * seqDimKV + ctx.headIdxKV * kHeadDim; - } - - template - __device__ __forceinline__ DType* oPtr(const Context& ctx) const { - return O + (cuSeqLensQ[ctx.batchIdx] + ctx.tileQ) * seqDimQ + ctx.headIdx * kHeadDim; - } -}; - -} // namespace tfa diff --git a/src/Operation/FlashAtten/softmax.cuh b/src/Operation/FlashAtten/softmax.cuh deleted file mode 100644 index 528ca2d..0000000 --- a/src/Operation/FlashAtten/softmax.cuh +++ /dev/null @@ -1,102 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -namespace tfa { - -template -struct Softmax { - static constexpr int kWarpSize = Config::kWarpSize; - static constexpr int kRowsPerWarp = Config::kRowsPerWarp; - static constexpr int kColsPerLane = Config::kColsPerLane; - static constexpr int kDimsPerLane = Config::kDimsPerLane; - - float rowMax[kRowsPerWarp]; - float rowSum[kRowsPerWarp]; - - __device__ void init() { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - rowMax[m] = -INFINITY; - rowSum[m] = 0.f; - } - } - - template - __device__ __forceinline__ void update(P& prob, const S& score, AccO& accO) { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - float newMax = computeRowMax(score[m]); - float correction = rescaleState(m, newMax); - float localSum = computeExpAndSum(prob[m], score[m], rowMax[m]); - updateRowSum(m, correction, localSum); - rescaleOutput(accO[m], correction); - } - } - - __device__ __forceinline__ float getNorm(int m) const { return (rowSum[m] > 0.f) ? (1.f / rowSum[m]) : 0.f; } - - private: - template - __device__ __forceinline__ float computeRowMax(const Row& row) const { - float localMax = row[0]; -#pragma unroll - for (int n = 1; n < kColsPerLane; n++) { - localMax = fmaxf(localMax, row[n]); - } - return warpReduceMax(localMax); - } - - __device__ __forceinline__ float rescaleState(int m, float newMax) { - float prevMax = rowMax[m]; - rowMax[m] = fmaxf(prevMax, newMax); - return fastExp(prevMax - rowMax[m]); - } - - template - __device__ __forceinline__ float computeExpAndSum(PRow& prob, const SRow& score, float maxVal) const { - float localSum = 0.f; -#pragma unroll - for (int n = 0; n < kColsPerLane; n++) { - float p = fastExp(score[n] - maxVal); - prob[n] = p; - localSum += p; - } - return localSum; - } - - __device__ __forceinline__ void updateRowSum(int m, float correction, float localSum) { - float warpSum = warpReduceSum(localSum); - rowSum[m] = rowSum[m] * correction + warpSum; - } - - template - __device__ __forceinline__ void rescaleOutput(AccRow& accRow, float correction) const { -#pragma unroll - for (int k = 0; k < kDimsPerLane; k++) { - accRow[k] *= correction; - } - } - - __device__ __forceinline__ float warpReduceMax(float val) const { -#pragma unroll - for (int delta = kWarpSize / 2; delta > 0; delta >>= 1) { - val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, delta)); - } - return val; - } - - __device__ __forceinline__ float warpReduceSum(float val) const { -#pragma unroll - for (int delta = kWarpSize / 2; delta > 0; delta >>= 1) { - val += __shfl_xor_sync(0xffffffff, val, delta); - } - return val; - } -}; - -} // namespace tfa diff --git a/src/Operation/FlashAtten/tile.cuh b/src/Operation/FlashAtten/tile.cuh deleted file mode 100644 index a9d35ed..0000000 --- a/src/Operation/FlashAtten/tile.cuh +++ /dev/null @@ -1,98 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include "layout.cuh" -#include "memory.cuh" -#include "utils.cuh" - -namespace tfa { - -template -struct Tile { - using DType = typename Config::DType; - using Layout = TileLayout; - static constexpr int kHeadDim = Config::kHeadDim; - static constexpr int kElemsPerVec = Config::kElemsPerVecLoad; - - DType* smem; - - __device__ explicit Tile(DType* smemBase) : smem(smemBase) {} - - static constexpr size_t numElems() { return NumRows * kHeadDim; } - static constexpr size_t smemSize() { return numElems() * sizeof(DType); } - - template - __device__ __forceinline__ void gm2sm(const DType* __restrict__ globalPtr, int stride, int validRows, - const Context& ctx) { - MemLoader::template gm2sm(smem, globalPtr, stride, validRows, ctx); - } - - __device__ __forceinline__ DType at(int row, int col) const { return smem[Layout::map(row, col, kHeadDim)]; } - - __device__ __forceinline__ void sm2reg(DType* __restrict__ regPtr, int row, int vecIdx) const { - int smemIdx = Layout::map(row, vecIdx * kElemsPerVec, kHeadDim); - MemLoader::sm2reg(regPtr, &smem[smemIdx]); - } -}; - -template -using QTile = Tile; - -template -using KTile = Tile; - -template -using VTile = Tile; - -template -struct OTile { - using DType = typename Config::DType; - using Layout = TileLayout; - - static constexpr int kHeadDim = Config::kHeadDim; - static constexpr int kRowsPerWarp = Config::kRowsPerWarp; - static constexpr int kDimsPerLane = Config::kDimsPerLane; - - float acc[kRowsPerWarp][kDimsPerLane]{}; - DType* smemPtr; - - __device__ explicit OTile(DType* smemBase) : smemPtr(smemBase) { -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { -#pragma unroll - for (int k = 0; k < kDimsPerLane; k++) { - acc[m][k] = 0.f; - } - } - } - - template - __device__ __forceinline__ void normalize(const Context& ctx, const SoftmaxT& softmax) { - const int validRows = ctx.validWarpRows(); -#pragma unroll - for (int m = 0; m < kRowsPerWarp; m++) { - if (m < validRows) { - float norm = softmax.getNorm(m); - int row = ctx.warpRowOffset + m; -#pragma unroll - for (int k = 0; k < kDimsPerLane; k++) { - int col = ctx.laneId * kDimsPerLane + k; - int smemIdx = Layout::map(row, col, kHeadDim); - smemPtr[smemIdx] = fromFloat(acc[m][k] * norm); - } - } - } - } - - template - __device__ __forceinline__ void store(DType* __restrict__ globalPtr, int stride, const Context& ctx) { - MemLoader::template sm2gm(globalPtr, smemPtr, stride, ctx.tileSizeQ, ctx); - } -}; - -} // namespace tfa diff --git a/src/Operation/FlashAtten/utils.cuh b/src/Operation/FlashAtten/utils.cuh deleted file mode 100644 index 17765f2..0000000 --- a/src/Operation/FlashAtten/utils.cuh +++ /dev/null @@ -1,97 +0,0 @@ -/* - * TinyFA - * @author : keith@robot9.me - * - */ - -#pragma once - -#include -#include - -#include - -namespace tfa { - -#define TFA_CUDA_CHECK(call) \ - do { \ - cudaError_t err = call; \ - if (err != cudaSuccess) { \ - fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -template -__host__ __device__ __forceinline__ constexpr T ceilDiv(T m, T n) { - return (m + n - 1) / n; -} - -template -__host__ __device__ __forceinline__ float toFloat(T val) { - return static_cast(val); -} - -template <> -__host__ __device__ __forceinline__ float toFloat(__half val) { - return __half2float(val); -} - -template <> -__host__ __device__ __forceinline__ float toFloat(__nv_bfloat16 val) { - return __bfloat162float(val); -} - -template -__host__ __device__ __forceinline__ T fromFloat(float val) { - return static_cast(val); -} - -template <> -__host__ __device__ __forceinline__ __half fromFloat(float val) { - return __float2half(val); -} - -template <> -__host__ __device__ __forceinline__ __nv_bfloat16 fromFloat(float val) { - return __float2bfloat16(val); -} - -__device__ __forceinline__ float fastExp(float x) { return __expf(x); } - -template -struct AttentionScale { - // 1/sqrt(kHeadDim) - static constexpr float value = (kHeadDim == 32) ? 0.17677669529f - : (kHeadDim == 64) ? 0.125f - : (kHeadDim == 128) ? 0.08838834764f - : (kHeadDim == 256) ? 0.0625f - : 0.f; - static_assert(value != 0.f, "Unsupported kHeadDim for AttentionScale"); -}; - -template -bool setDynamicSharedMemory(KernelFunc kernel, size_t requestedSize) { - int device; - TFA_CUDA_CHECK(cudaGetDevice(&device)); - - int maxOptin = 0; - TFA_CUDA_CHECK(cudaDeviceGetAttribute(&maxOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - - if (requestedSize > static_cast(maxOptin)) { - fprintf(stderr, - "Error: requested shared memory %zu exceeds " - "cudaDevAttrMaxSharedMemoryPerBlockOptin (%d bytes)\n", - requestedSize, maxOptin); - return false; - } - - int maxPerBlock = 0; - TFA_CUDA_CHECK(cudaDeviceGetAttribute(&maxPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, device)); - if (requestedSize > static_cast(maxPerBlock)) { - TFA_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, requestedSize)); - } - return true; -} - -} // namespace tfa diff --git a/src/Operation/OpNNLayerCuda.cu b/src/Operation/OpNNLayerCuda.cu index 9151a01..523edbc 100644 --- a/src/Operation/OpNNLayerCuda.cu +++ b/src/Operation/OpNNLayerCuda.cu @@ -4,7 +4,7 @@ * */ -#include "FlashAtten/launcher.cuh" +#include "flash_attn/flash_api.cuh" #include "OpNNLayerCuda.cuh" namespace tinytorch::op { diff --git a/third_party/TinyFA b/third_party/TinyFA new file mode 160000 index 0000000..ffc2647 --- /dev/null +++ b/third_party/TinyFA @@ -0,0 +1 @@ +Subproject commit ffc264708d49e63f167b067b6d42339340469ca1