Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[submodule "test/googletest"]
path = test/googletest
url = https://github.com/google/googletest.git
[submodule "third_party/TinyFA"]
path = third_party/TinyFA
url = https://github.com/keith2018/TinyFA.git
8 changes: 0 additions & 8 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,3 @@ if (TINYTORCH_USE_NCCL)
target_include_directories(${PROJECT_NAME} PUBLIC ${NCCL_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PUBLIC ${NCCL_LIBRARY})
endif ()

# TinyFA
# Added as a subdirectory and linked privately into the main target, only when
# TINYTORCH_USE_CUDA is set.
if (TINYTORCH_USE_CUDA)
# Force these cache entries OFF before descending into TinyFA (presumably they gate
# TinyFA's own test/example targets -- TODO confirm against its CMakeLists).
set(TFA_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(TFA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
add_subdirectory(${THIRD_PARTY_DIR}/TinyFA TinyFA)
target_link_libraries(${PROJECT_NAME} PRIVATE TinyFA::tinyfa)
endif ()
13 changes: 13 additions & 0 deletions src/Function/FuncFused.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ class FuncSiluMul : public Function<FuncSiluMul> {
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
};

// Function wrapper that routes to op::fusedAddRmsNorm.
// forward() returns nothing; `input` and `residual` are taken by non-const reference
// and handed straight to the op (presumably both are updated in place -- TODO confirm
// against the registered backend implementations).
class FuncFusedAddRmsNorm : public Function<FuncFusedAddRmsNorm> {
public:
// `ctx` is not used by this forward pass.
static void forward(AutogradContext* ctx, Tensor& input, Tensor& residual, const Tensor& weight, float eps) {
op::fusedAddRmsNorm(input, residual, weight, eps);
}

// No gradient is implemented for this function.
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
};

// Convenience entry point: forwards `x` to FuncSiluMul::apply and returns its result.
inline Tensor siluMul(const Tensor& x) { return FuncSiluMul::apply(x); }

// Convenience entry point: forwards all arguments to FuncFusedAddRmsNorm::apply.
// Nothing is returned; `input` and `residual` are non-const references
// (presumably the results are written into them -- TODO confirm).
inline void fusedAddRmsNorm(Tensor& input, Tensor& residual, const Tensor& weight, float eps) {
FuncFusedAddRmsNorm::apply(input, residual, weight, eps);
}

} // namespace tinytorch::function
25 changes: 0 additions & 25 deletions src/Function/FuncNNLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,24 +267,6 @@ class FuncSDPAttention : public Function<FuncSDPAttention> {
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
};

// Function wrapper that routes to op::flashAttention.
class FuncFlashAttention : public Function<FuncFlashAttention> {
public:
// Returns the result of op::flashAttention(query, key, value, isCausal); `ctx` is unused.
static Tensor forward(AutogradContext* ctx, const Tensor& query, const Tensor& key, const Tensor& value,
bool isCausal) {
return op::flashAttention(query, key, value, isCausal);
}
// No gradient is implemented for this function.
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
};

// Function wrapper that routes to op::ropeApply.
class FuncRoPE : public Function<FuncRoPE> {
public:
// Returns the result of op::ropeApply(input, rope, offset, layout); `ctx` is unused.
static Tensor forward(AutogradContext* ctx, const Tensor& input, const Tensor& rope, int64_t offset,
QKVLayout layout) {
return op::ropeApply(input, rope, offset, layout);
}
// No gradient is implemented for this function.
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
};

// Convenience entry point: forwards to FuncLinear::apply.
// `bias` defaults to a default-constructed Tensor when the caller omits it.
inline Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias = {}) {
return FuncLinear::apply(input, weight, bias);
}
Expand Down Expand Up @@ -318,12 +300,5 @@ inline Tensor sdpAttention(const Tensor& query, const Tensor& key, const Tensor&
std::optional<float> scale = std::nullopt) {
return FuncSDPAttention::apply(query, key, value, isCausal, attnMask, dropoutP, scale);
}
// Convenience entry point: forwards to FuncFlashAttention::apply. `isCausal` defaults to false.
inline Tensor flashAttention(const Tensor& query, const Tensor& key, const Tensor& value, bool isCausal = false) {
return FuncFlashAttention::apply(query, key, value, isCausal);
}
// Convenience entry point: forwards to FuncRoPE::apply.
// Defaults: offset = 0, layout = QKVLayout::BHSD.
inline Tensor ropeApply(const Tensor& input, const Tensor& rope, int64_t offset = 0,
QKVLayout layout = QKVLayout::BHSD) {
return FuncRoPE::apply(input, rope, offset, layout);
}

} // namespace tinytorch::function
16 changes: 0 additions & 16 deletions src/Module/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,4 @@ std::vector<std::pair<std::string, TensorPtr>> Embedding::namedParameters_() { r

// Re-initializes the embedding weight by passing it to Initializer::normal.
void Embedding::resetParameters() { Initializer::normal(weight_); }

// Stores the configuration and immediately builds the rope cache.
RoPE::RoPE(int64_t headDim, int64_t contextLength, float thetaBase, std::optional<RopeScalingConfig> scaling,
Options options)
: headDim_(headDim), contextLength_(contextLength), thetaBase_(thetaBase), scaling_(scaling), options_(options) {
// Explicitly qualified call: runs this class's resetParameters() rather than
// relying on virtual dispatch from inside the constructor.
RoPE::resetParameters();
}

// Applies the cached rope tensor to `input` using function::ropeApply's default offset and layout.
Tensor RoPE::forward(const Tensor &input) { return function::ropeApply(input, rope_); }

// Applies the cached rope tensor to `input` with an explicit offset and layout.
Tensor RoPE::forward(const Tensor &input, int64_t offset, QKVLayout layout) {
return function::ropeApply(input, rope_, offset, layout);
}

// Rebuilds the cached rope tensor from the stored configuration via op::ropeInit.
void RoPE::resetParameters() { rope_ = op::ropeInit(headDim_, contextLength_, thetaBase_, scaling_, options_); }

// Exposes the cached rope tensor as a single named state called "rope".
std::vector<std::pair<std::string, TensorPtr>> RoPE::namedStates_() { return {{"rope", &rope_}}; }

} // namespace tinytorch::nn
30 changes: 0 additions & 30 deletions src/Module/Basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,34 +138,4 @@ class Embedding : public Module {
Tensor weight_;
};

// Module that owns a rope tensor (built by resetParameters()) and applies it to inputs.
class RoPE : public Module {
public:
// Defaults: contextLength = 4096, thetaBase = 10000, no scaling, default Options.
explicit RoPE(int64_t headDim, int64_t contextLength = 4096, float thetaBase = 10000.0f,
std::optional<RopeScalingConfig> scaling = std::nullopt, Options options = {});

// Keep the base-class forward overloads visible next to the ones declared here.
using Module::forward;
Tensor forward(const Tensor &input) override;
Tensor forward(const Tensor &input, int64_t offset, QKVLayout layout = QKVLayout::BHSD);

// Keep the base-class call operators visible; the overload below forwards to
// forward(input, offset, layout).
using Module::operator();
Tensor operator()(const Tensor &input, int64_t offset, QKVLayout layout = QKVLayout::BHSD) {
return forward(input, offset, layout);
}

void resetParameters() override;

// Mutable access to the stored rope tensor.
Tensor &cache() { return rope_; }

protected:
std::vector<std::pair<std::string, TensorPtr>> namedStates_() override;

// Configuration captured at construction time.
int64_t headDim_;
int64_t contextLength_;
float thetaBase_;
std::optional<RopeScalingConfig> scaling_;
Options options_;

// Rope tensor returned through cache().
Tensor rope_;
};

} // namespace tinytorch::nn
1 change: 1 addition & 0 deletions src/Module/Norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RMSNorm : public Module {
void resetParameters() override;

// Mutable access to the weight tensor.
Tensor &weight() { return weight_; }
// Read-only access to the stored epsilon value.
float eps() const { return eps_; }

protected:
std::vector<std::pair<std::string, TensorPtr>> namedParameters_() override;
Expand Down
5 changes: 5 additions & 0 deletions src/Operation/OpFused.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@ namespace tinytorch::op {

// Signature of a siluMul backend: takes one tensor and returns a tensor.
using SiluMulOpFn = Tensor (*)(const Tensor& self);

// Signature of a fusedAddRmsNorm backend: returns nothing; `input` and `residual` are
// non-const references (presumably both are updated in place -- TODO confirm).
using FusedAddRmsNormOpFn = void (*)(Tensor& input, Tensor& residual, const Tensor& weight, float eps);

// siluMul
DEFINE_OP(siluMul, SiluMulOpFn)

// fusedAddRmsNorm
DEFINE_OP(fusedAddRmsNorm, FusedAddRmsNormOpFn)

// CPU registration hook for the fused ops; STATIC_CALL presumably invokes it during
// static initialization -- TODO confirm against the macro definition.
void registerFusedCpu();
STATIC_CALL(registerFusedCpu);

Expand Down
3 changes: 3 additions & 0 deletions src/Operation/OpFusedCpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ namespace tinytorch::op {
// Registers the CPU implementations of the fused ops through REG_FUSED_CPU_FLT
// (presumably once per supported floating-point type -- TODO confirm against the macro).
void registerFusedCpu() {
// siluMul
REG_FUSED_CPU_FLT(siluMul, siluMulOpCpuImpl);

// fusedAddRmsNorm
REG_FUSED_CPU_FLT(fusedAddRmsNorm, fusedAddRmsNormOpCpuImpl);
}

} // namespace tinytorch::op
31 changes: 31 additions & 0 deletions src/Operation/OpFusedCpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,35 @@ Tensor siluMulOpCpuImpl(const Tensor& self) {
return ret;
}

/**
 * CPU implementation of the fused "residual add + RMS norm" op, operating in place.
 *
 * Rows are taken over the last dimension of `input`. For every row:
 *   residual <- input + residual              (sum computed in float, stored as T)
 *   input    <- residual * invRms * weight    with invRms = 1 / sqrt(mean(sum^2) + eps)
 * The mean of squares uses the float sums before they are rounded to T, while the
 * normalization step re-reads the stored (rounded) residual.
 *
 * @param input     Overwritten with the normalized, weighted result.
 * @param residual  Must have the same shape as `input`; overwritten with input + residual.
 * @param weight    Per-element scale; must hold at least input.size(-1) elements.
 * @param eps       Added to the mean of squares before taking the square root.
 *
 * NOTE(review): elements are addressed as densely packed rows through dataPtr<T>();
 * this assumes contiguous tensors -- confirm at the call sites.
 */
template <typename T>
void fusedAddRmsNormOpCpuImpl(Tensor& input, Tensor& residual, const Tensor& weight, float eps) {
  ASSERT(input.shape() == residual.shape());
  const int64_t dim = input.size(-1);

  // Empty input: nothing to normalize, and dim == 0 would divide by zero below.
  if (dim <= 0 || input.numel() == 0) {
    return;
  }

  // weight[i] is read for every i in [0, dim); reject a short weight instead of
  // reading past the end of its buffer.
  ASSERT(weight.numel() >= dim);

  const int64_t numRows = input.numel() / dim;

  T* inputPtr = input.dataPtr<T>();
  T* residualPtr = residual.dataPtr<T>();
  const T* weightPtr = weight.dataPtr<T>();

  for (int64_t row = 0; row < numRows; row++) {
    T* inputRow = inputPtr + row * dim;
    T* residualRow = residualPtr + row * dim;

    // add residual + accumulate sum-of-squares
    float sumSq = 0.f;
    for (int64_t i = 0; i < dim; i++) {
      const float r = static_cast<float>(inputRow[i]) + static_cast<float>(residualRow[i]);
      residualRow[i] = static_cast<T>(r);
      sumSq += r * r;
    }

    const float invRms = 1.f / std::sqrt(sumSq / static_cast<float>(dim) + eps);

    // normalize + affine (reads back the residual as stored in T)
    for (int64_t i = 0; i < dim; i++) {
      const auto r = static_cast<float>(residualRow[i]);
      inputRow[i] = static_cast<T>(r * invRms * static_cast<float>(weightPtr[i]));
    }
  }
}

} // namespace tinytorch::op
3 changes: 3 additions & 0 deletions src/Operation/OpFusedCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ namespace tinytorch::op {
// Registers the CUDA implementations of the fused ops through REG_FUSED_CUDA_FLT
// (presumably once per supported floating-point type -- TODO confirm against the macro).
void registerFusedCuda() {
// siluMul
REG_FUSED_CUDA_FLT(siluMul, siluMulOpCudaImpl);

// fusedAddRmsNorm
REG_FUSED_CUDA_FLT(fusedAddRmsNorm, fusedAddRmsNormOpCudaImpl);
}

} // namespace tinytorch::op
100 changes: 97 additions & 3 deletions src/Operation/OpFusedCuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,46 @@

#include "OpElemWiseCuda.cuh"
#include "OpFused.h"
#include "OpNNLayerCuda.cuh"
#include "Utils/CUDAUtils.h"

namespace tinytorch::op {

/**
 * Vectorized SiLU-and-multiply kernel: out = silu(gate) * up, using 128-bit loads/stores.
 *
 * Layout expected by the pointer arithmetic below:
 *   - each input row holds 2 * d elements: the first d are the gate, the next d the up values;
 *   - each output row holds d elements;
 *   - blockIdx.x selects the row, and the threads of the block stride over the row's vectors.
 *
 * Preconditions (not checked here, the caller must guarantee them):
 *   - d is a multiple of VEC_ELEMENTS (any remainder would be silently skipped);
 *   - outPtr and selfPtr are 16-byte aligned, because rows are accessed through int4 pointers.
 *
 * @param outPtr   Output buffer, numRows * d elements.
 * @param selfPtr  Input buffer, numRows * 2 * d elements.
 * @param d        Half of the input's last dimension.
 */
template <typename T, int VEC_ELEMENTS>
__global__ void kSiluMulVec(T* __restrict__ outPtr, const T* __restrict__ selfPtr, const int d) {
  // One vector must be exactly one int4, otherwise the reinterpret_casts below are wrong.
  static_assert(sizeof(T) * VEC_ELEMENTS == sizeof(int4), "VEC_ELEMENTS must pack exactly 16 bytes of T");

  const unsigned int row = blockIdx.x;
  const T* gatePtr = selfPtr + static_cast<int64_t>(row) * d * 2;
  const T* upPtr = gatePtr + d;
  T* rowOut = outPtr + static_cast<int64_t>(row) * d;

  const int numVecs = d / VEC_ELEMENTS;

  const int4* gateVec = reinterpret_cast<const int4*>(gatePtr);
  const int4* upVec = reinterpret_cast<const int4*>(upPtr);
  int4* outVec = reinterpret_cast<int4*>(rowOut);

  for (auto i = threadIdx.x; i < numVecs; i += blockDim.x) {
    // 128-bit load via read-only cache (__ldg)
    const int4 gv = __ldg(&gateVec[i]);
    const int4 uv = __ldg(&upVec[i]);

    const T* g = reinterpret_cast<const T*>(&gv);
    const T* u = reinterpret_cast<const T*>(&uv);

    // Staging buffer is read back as a single int4, so it needs int4 (16-byte) alignment;
    // a plain T[] is only guaranteed alignof(T).
    alignas(16) T r[VEC_ELEMENTS];

#pragma unroll
    for (int j = 0; j < VEC_ELEMENTS; ++j) {
      const auto gf = static_cast<float>(g[j]);
      const auto uf = static_cast<float>(u[j]);
      // silu(g) * u, with silu(g) = g / (1 + exp(-g)), evaluated in float
      r[j] = static_cast<T>(gf / (1.f + __expf(-gf)) * uf);
    }

    outVec[i] = *reinterpret_cast<const int4*>(r);
  }
}

template <typename T>
__global__ void kSiluMul(T* retPtr, const T* selfPtr, const int64_t halfLastDim, const int64_t n) {
__global__ void kSiluMulScalar(T* retPtr, const T* selfPtr, const int64_t halfLastDim, const int64_t n) {
const auto index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < n) {
const int64_t sliceIdx = index / halfLastDim;
Expand Down Expand Up @@ -42,9 +76,69 @@ Tensor siluMulOpCudaImpl(const Tensor& self) {
const int64_t lastDim = self.size(-1);
const int64_t halfLastDim = lastDim / 2;
const int64_t n = ret.numel();
auto params = cuda::getKernelLaunchParams(self.device().index, n);
CUDA_LAUNCH_KERNEL(kSiluMul<CudaT>, params, retPtr, selfPtr, halfLastDim, n);
const int64_t numRows = n / halfLastDim;

constexpr int kVecBytes = 16; // int4
constexpr int kVecElements = kVecBytes / static_cast<int>(sizeof(CudaT));

const bool useVec = (halfLastDim % kVecElements == 0);

if (useVec) {
const int d = static_cast<int>(halfLastDim);
dim3 grid(static_cast<unsigned>(numRows));
dim3 block(std::min(d / kVecElements, 1024));
auto stream = cuda::getCurrentCUDAStream(self.device().index).stream();
kSiluMulVec<CudaT, kVecElements><<<grid, block, 0, stream>>>(retPtr, selfPtr, d);
CUDA_KERNEL_CHECK();
} else {
auto params = cuda::getKernelLaunchParams(self.device().index, n);
CUDA_LAUNCH_KERNEL(kSiluMulScalar<CudaT>, params, retPtr, selfPtr, halfLastDim, n);
}

return ret;
}

// Fused residual-add + RMS-norm kernel, in place.
// blockIdx.x selects the row (a run of `dim` consecutive elements); the threads of the
// block stride over that row in steps of blockDim.x.
// On exit: residual holds input + residual (rounded to T), and input holds
// residual * invRms * weight, where invRms = rsqrt(mean of squared sums + eps).
// `input`, `residual` and `weight` are declared __restrict__, so they must not alias
// one another.
template <typename T>
__global__ void kFusedAddRMSNorm(T* __restrict__ input, T* __restrict__ residual, const T* __restrict__ weight,
int64_t dim, float eps) {
const auto row = blockIdx.x;
const auto tid = threadIdx.x;
// Offset of this row's first element (row * dim is evaluated in int64_t because dim is).
const auto base = row * dim;

// add residual + accumulate sum-of-squares (per-thread partial sum, in float)
float sumSq = 0.f;
for (auto i = tid; i < dim; i += blockDim.x) {
float r = static_cast<float>(input[base + i]) + static_cast<float>(residual[base + i]);
residual[base + i] = static_cast<T>(r);
sumSq += r * r;
}

// Combine the per-thread partial sums. NOTE(review): cudaBlockReduce is defined elsewhere;
// this code relies on every thread receiving the block-wide total -- TODO confirm.
sumSq = cudaBlockReduce<float, OpCudaReduceSum>(sumSq, 0.f);
float invRms = cuda::rsqrt(sumSq / static_cast<float>(dim) + eps);

// normalize + affine: each thread re-reads the residual elements it wrote above
// (same index pattern), as stored in T
for (auto i = tid; i < dim; i += blockDim.x) {
auto r = static_cast<float>(residual[base + i]);
input[base + i] = static_cast<T>(r * invRms * static_cast<float>(weight[i]));
}
}

/**
 * CUDA host-side launcher for the fused "residual add + RMS norm" op, operating in place.
 *
 * Launches kFusedAddRMSNorm with one block per row (rows are taken over the last
 * dimension of `input`) on the current stream of `input`'s device.
 *
 * @param input     Overwritten with the normalized, weighted result.
 * @param residual  Must have the same shape as `input`; overwritten with input + residual.
 * @param weight    Per-element scale; must hold at least input.size(-1) elements.
 * @param eps       Added to the mean of squares before the reciprocal square root.
 *
 * NOTE(review): the kernel addresses densely packed rows and its pointers are
 * __restrict__; this assumes contiguous, non-aliasing tensors that live on the same
 * device -- confirm at the call sites.
 */
template <typename T>
void fusedAddRmsNormOpCudaImpl(Tensor& input, Tensor& residual, const Tensor& weight, float eps) {
  ASSERT(input.shape() == residual.shape());
  const int64_t dim = input.size(-1);

  // Empty input: nothing to do. Also avoids dividing by dim == 0 and launching a
  // zero-sized grid, which CUDA rejects as an invalid configuration.
  if (dim <= 0 || input.numel() == 0) {
    return;
  }

  // The kernel reads weight[i] for every i in [0, dim).
  ASSERT(weight.numel() >= dim);

  const int64_t numRows = input.numel() / dim;

  // gridDim.x is a 32-bit value limited to 2^31 - 1; fail loudly instead of silently
  // truncating the row count in the narrowing conversion below.
  constexpr int64_t kMaxGridDimX = 2147483647;
  ASSERT(numRows <= kMaxGridDimX);

  using CudaT = typename cuda::CudaTypeCast<T>::type;
  CudaT* inputPtr = input.dataPtr<CudaT>();
  CudaT* residualPtr = residual.dataPtr<CudaT>();
  const CudaT* weightPtr = weight.dataPtr<CudaT>();

  auto stream = cuda::getCurrentCUDAStream(input.device().index).stream();
  // Block size: next power of two of dim, clamped to [32, 1024] threads.
  dim3 blockSize(std::clamp(nextPow2(dim), 32u, 1024u));
  dim3 gridSize(static_cast<unsigned int>(numRows));
  kFusedAddRMSNorm<CudaT><<<gridSize, blockSize, 0, stream>>>(inputPtr, residualPtr, weightPtr, dim, eps);
  CUDA_KERNEL_CHECK();
}

} // namespace tinytorch::op
32 changes: 0 additions & 32 deletions src/Operation/OpNNLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,6 @@

#include "Tensor/Dispatch.h"

namespace tinytorch {

/// Scaling parameters for rope construction; passed (optionally) to op::ropeInit.
struct RopeScalingConfig {
  float factor;
  float highFreqFactor;
  float lowFreqFactor;
  int64_t originalContextLength;

  /// Stores the four values verbatim, in declaration order.
  RopeScalingConfig(float scaleFactor, float highFreq, float lowFreq, int64_t originalLength)
      : factor{scaleFactor},
        highFreqFactor{highFreq},
        lowFreqFactor{lowFreq},
        originalContextLength{originalLength} {}
};

/// Dimension order of query/key/value tensors.
enum class QKVLayout {
  BHSD,  // [batch, numHead, seqLen, headDim]
  BSHD   // [batch, seqLen, numHead, headDim]
};

}  // namespace tinytorch

namespace tinytorch::op {

enum class SoftmaxType : uint8_t {
Expand Down Expand Up @@ -59,12 +40,6 @@ using LayerNormOpFn = Tensor (*)(const Tensor& self, IntArrayView normalizedShap

using RMSNormOpFn = Tensor (*)(const Tensor& self, IntArrayView normalizedShape, const Tensor& weight, float eps);

// Signature of a ropeInit backend: builds and returns a tensor from the rope configuration.
using RopeInitOpFn = Tensor (*)(int64_t headDim, int64_t contextLength, float thetaBase,
std::optional<RopeScalingConfig> scaling, Options options);
// Signature of a ropeApply backend: takes an input and a rope tensor, returns a tensor.
using RopeApplyOpFn = Tensor (*)(const Tensor& input, const Tensor& rope, int64_t offset, QKVLayout layout);

// Signature of a flashAttention backend: takes query/key/value and a causal flag, returns a tensor.
using FlashAttentionOpFn = Tensor (*)(const Tensor& query, const Tensor& key, const Tensor& value, bool isCausal);

// softmax
DEFINE_OP(softmax, SoftmaxOpFn);
DEFINE_OP(softmaxOut, SoftmaxOpOutFn);
Expand All @@ -85,13 +60,6 @@ DEFINE_OP(layerNorm, LayerNormOpFn);
// rmsNorm
DEFINE_OP(rmsNorm, RMSNormOpFn);

// rope: declares the ropeInit / ropeApply ops with the function-pointer types above
DEFINE_OP(ropeInit, RopeInitOpFn);
DEFINE_OP(ropeApply, RopeApplyOpFn);

// flashAttention: declares the flashAttention op
DEFINE_OP(flashAttention, FlashAttentionOpFn);

void registerNNLayerCpu();
STATIC_CALL(registerNNLayerCpu);

Expand Down
4 changes: 0 additions & 4 deletions src/Operation/OpNNLayerCpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ void registerNNLayerCpu() {

// rmsNorm
REG_NN_LAYER_CPU_FLT(rmsNorm, rmsNormOpCpuImpl);

// rope
REG_NN_LAYER_CPU_FLT(ropeInit, ropeInitOpCpuImpl);
REG_NN_LAYER_CPU_FLT(ropeApply, ropeApplyOpCpuImpl);
}

} // namespace tinytorch::op
Loading
Loading