From 4b4a794ec444ecb82be746244d03a420b02fe5c4 Mon Sep 17 00:00:00 2001 From: keith2018 Date: Fri, 3 Apr 2026 21:49:18 +0800 Subject: [PATCH 1/5] feat: fuse matmul with bias --- src/Function/FuncLinalg.h | 2 +- src/Function/FuncNNLayer.h | 26 +++++------------ src/Operation/OpLinalg.cpp | 25 +++++++++++----- src/Operation/OpLinalg.h | 6 ++-- src/Operation/OpLinalgCpu.h | 29 +++++++++++++------ src/Operation/OpLinalgCuda.cuh | 53 ++++++++++++++++++++++++++-------- test/test_operation.cpp | 42 +++++++++++++-------------- 7 files changed, 112 insertions(+), 71 deletions(-) diff --git a/src/Function/FuncLinalg.h b/src/Function/FuncLinalg.h index 4324922..a048355 100644 --- a/src/Function/FuncLinalg.h +++ b/src/Function/FuncLinalg.h @@ -14,7 +14,7 @@ namespace tinytorch::function { class FuncMatmul : public Function { public: static Tensor forward(AutogradContext* ctx, const Tensor& a, const Tensor& b, bool transA, bool transB) { - return op::matmul(a, b, transA, transB); + return op::matmul(a, b, transA, transB, Tensor{}); } static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); } diff --git a/src/Function/FuncNNLayer.h b/src/Function/FuncNNLayer.h index 6190960..f9ac33d 100644 --- a/src/Function/FuncNNLayer.h +++ b/src/Function/FuncNNLayer.h @@ -14,12 +14,7 @@ namespace tinytorch::function { class FuncLinear : public Function { public: static Tensor forward(AutogradContext* ctx, const Tensor& input, const Tensor& weight, const Tensor& bias) { - auto output = op::matmul(input, weight, false, true); - if (bias.defined()) { - op::addInplace(output, bias, 1); - } - // TODO fuse - return output; + return op::matmul(input, weight, false, true, bias); } static void backward(AutogradContext* ctx, const Tensor& grad) { @@ -28,10 +23,10 @@ class FuncLinear : public Function { auto& bias = ctx->savedInputs[2]; if (input.requiresGrad()) { - input.addGrad(op::matmul(grad, weight, false, false)); + input.addGrad(op::matmul(grad, weight, false, false, Tensor{})); } if (weight.requiresGrad()) { - weight.addGrad(op::matmul(grad, input, true, false)); + weight.addGrad(op::matmul(grad, input, true, false, Tensor{})); } if (bias.defined() && bias.requiresGrad()) { bias.addGrad(op::sumOnDim(grad, 0, false)); @@ -161,12 +156,7 @@ class FuncConv2D : public Function { auto col = op::im2col(input, kernel, stride, padding); auto colW = op::reshape(weight, IntArrayView{outChannels, -1}); - auto ret = op::matmul(col, colW, false, true); - if (bias.defined()) { - ASSERT(bias.dim() == 1); - ASSERT(bias.shape()[0] == outChannels); - op::addInplace(ret, bias, 1); - } + auto ret = op::matmul(col, colW, false, true, bias); ret.reshape_({batch, outChannels, outH, outW}); if (ctx) { @@ -192,13 +182,13 @@ class FuncConv2D : public Function { auto colW = op::reshape(weight, IntArrayView{outChannels, -1}); if (input.requiresGrad()) { - auto gradCol = op::matmul(gradW, colW, false, false); + auto gradCol = op::matmul(gradW, colW, false, false, Tensor{}); auto inputGrad = op::col2im(gradCol, input.shape(), kernel, stride, padding); input.addGrad(std::move(inputGrad)); } if (weight.requiresGrad()) { auto col = ctx->popData().toTensor(); - auto gradColW = op::matmul(col, gradW, true, false); + auto gradColW = op::matmul(col, gradW, true, false, Tensor{}); auto weightGrad = op::reshape(gradColW.permute(), weight.shape()); weight.addGrad(std::move(weightGrad)); } @@ -244,7 +234,7 @@ class FuncSDPAttention : public Function { auto S = key.size(-2); float scaleFactor = scale.has_value() ? scale.value() : (1.f / std::sqrt(static_cast(query.size(-1)))); - auto attnWeight = op::matmul(query, op::transpose(key, -2, -1), false, false); + auto attnWeight = op::matmul(query, op::transpose(key, -2, -1), false, false, Tensor{}); op::mulInplace(attnWeight, Tensor::scalar(scaleFactor, attnWeight.options())); Tensor attnBias; @@ -272,7 +262,7 @@ class FuncSDPAttention : public Function { if (dropoutP > 0.f) { attnWeight = op::dropout(attnWeight, dropoutP); } - return op::matmul(attnWeight, value, false, false); + return op::matmul(attnWeight, value, false, false, Tensor{}); } static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); } }; diff --git a/src/Operation/OpLinalg.cpp b/src/Operation/OpLinalg.cpp index 6d956dc..90e5b61 100644 --- a/src/Operation/OpLinalg.cpp +++ b/src/Operation/OpLinalg.cpp @@ -6,6 +6,8 @@ #include "OpLinalg.h" +#include "OpElemWise.h" + namespace tinytorch::op { inline SizeVector makePaddedStrides(const IntArrayView &strides, int64_t targetDim) { @@ -216,10 +218,11 @@ Tensor matmulOpImplDetail(const Tensor &a, const Tensor &b, bool transA = false, } } - gemm(retPtr + batch * m * n, selfPtr + aOffset, otherPtr + bOffset, m, k, n, transA, transB, a.device().index); + gemm(retPtr + batch * m * n, selfPtr + aOffset, otherPtr + bOffset, m, k, n, transA, transB, a.device().index, + nullptr); } } else { - gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index); + gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index, nullptr); if (prependA) { retTensor.reshape_({n}); } @@ -238,8 +241,8 @@ Tensor matmulOpImplDetail(const Tensor &a, const Tensor &b, bool transA = false, } template -Tensor matmulOpImpl(const Tensor &a, const Tensor &b, bool transA, bool transB) { - // fast path +Tensor matmulOpImpl(const Tensor &a, const Tensor &b, bool transA, bool transB, const Tensor &bias) { + // 2D fast path if (a.dim() == 2 && b.dim() == 2) { // a[m, k], b[k, n] -> [m, n] int64_t m = a.shape(transA ? 1 : 0); @@ -255,14 +258,22 @@ Tensor matmulOpImpl(const Tensor &a, const Tensor &b, bool transA, bool transB) const T *otherPtr = b.dataPtr(); T *retPtr = retTensor.dataPtr(); + if (bias.defined()) { + ASSERT(bias.dim() == 1 && bias.shape(0) == n); + } auto gemm = getGemmFunc(a.device().type); ASSERT(gemm != nullptr); - gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index); + gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index, + bias.defined() ? bias.dataPtr() : nullptr); return retTensor; } - // slow path - return matmulOpImplDetail(a, b, transA, transB); + // batched / broadcast path + auto result = matmulOpImplDetail(a, b, transA, transB); + if (bias.defined()) { + addInplace(result, bias, 1); + } + return result; } void registerLinalgCommon() { diff --git a/src/Operation/OpLinalg.h b/src/Operation/OpLinalg.h index bb58132..d0ba255 100644 --- a/src/Operation/OpLinalg.h +++ b/src/Operation/OpLinalg.h @@ -13,7 +13,7 @@ namespace tinytorch::op { SizeVector broadcastShape(IntArrayView t0, IntArrayView t1, int64_t skipLast); template -void gemmImpl(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex); +void gemmImpl(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex, const T* = nullptr); template void gemmStridedBatchedImpl(T*, const T*, const T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool, @@ -23,7 +23,7 @@ template void gemmBatchedImpl(T**, const T**, const T**, int64_t, int64_t, int64_t, int64_t, bool, bool, DeviceIndex); template -using GemmFunc = void (*)(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex); +using GemmFunc = void (*)(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex, const T*); template using GemmStridedBatchedFunc = void (*)(T*, const T*, const T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, @@ -74,7 +74,7 @@ GemmBatchedFunc getGemmBatchedFunc(DeviceType deviceType) { using DotOpFn = Tensor (*)(const Tensor& self, const Tensor& other); using Im2ColOpFn = Tensor (*)(const Tensor& self, Dim2D kernel, Dim2D stride, Dim2D padding); using Col2ImOpFn = Tensor (*)(const Tensor& self, IntArrayView shape, Dim2D kernel, Dim2D stride, Dim2D padding); -using MatmulOpFn = Tensor (*)(const Tensor& a, const Tensor& b, bool transA, bool transB); +using MatmulOpFn = Tensor (*)(const Tensor& a, const Tensor& b, bool transA, bool transB, const Tensor& bias); // dot DEFINE_OP(dot, DotOpFn) diff --git a/src/Operation/OpLinalgCpu.h b/src/Operation/OpLinalgCpu.h index 5a83028..371378b 100644 --- a/src/Operation/OpLinalgCpu.h +++ b/src/Operation/OpLinalgCpu.h @@ -116,19 +116,29 @@ Tensor col2imOpCpuImpl(const Tensor& self, const IntArrayView shape, Dim2D kerne } template -void gemmCpuImpl(T* c, const T* a, const T* b, int64_t m, int64_t k, int64_t n, bool transA, bool transB) { +void gemmCpuImpl(T* c, const T* a, const T* b, int64_t m, int64_t k, int64_t n, bool transA, bool transB, + const T* bias = nullptr) { + if (bias) { + // broadcast bias into C: c[i][j] = bias[j] + for (int64_t i = 0; i < m; i++) { + std::memcpy(c + i * n, bias, n * sizeof(T)); + } + } // blas #if defined(__APPLE__) || defined(__BLAS__) if constexpr (std::is_same_v) { CBLAS_TRANSPOSE ta = transA ? CblasTrans : CblasNoTrans; CBLAS_TRANSPOSE tb = transB ? CblasTrans : CblasNoTrans; + float betaVal = bias ? 1.0f : 0.0f; cblas_sgemm(CblasRowMajor, ta, tb, (int)m, (int)n, (int)k, 1.0f, a, transA ? (int)m : (int)k, b, - transB ? (int)k : (int)n, 0.0f, c, (int)n); + transB ? (int)k : (int)n, betaVal, c, (int)n); return; } #endif // basic - std::memset(c, 0, m * n * sizeof(T)); + if (!bias) { + std::memset(c, 0, m * n * sizeof(T)); + } for (int64_t i = 0; i < m; i++) { for (int64_t p = 0; p < k; p++) { T aVal = transA ? a[p * m + i] : a[i * k + p]; @@ -142,20 +152,21 @@ void gemmCpuImpl(T* c, const T* a, const T* b, int64_t m, int64_t k, int64_t n, template <> void gemmImpl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n, - bool transA, bool transB, DeviceIndex device) { - gemmCpuImpl(c, a, b, m, k, n, transA, transB); + bool transA, bool transB, DeviceIndex device, const float* bias) { + gemmCpuImpl(c, a, b, m, k, n, transA, transB, bias); } template <> void gemmImpl(Half* c, const Half* a, const Half* b, int64_t m, int64_t k, int64_t n, - bool transA, bool transB, DeviceIndex device) { - gemmCpuImpl(c, a, b, m, k, n, transA, transB); + bool transA, bool transB, DeviceIndex device, const Half* bias) { + gemmCpuImpl(c, a, b, m, k, n, transA, transB, bias); } template <> void gemmImpl(BFloat16* c, const BFloat16* a, const BFloat16* b, int64_t m, int64_t k, - int64_t n, bool transA, bool transB, DeviceIndex device) { - gemmCpuImpl(c, a, b, m, k, n, transA, transB); + int64_t n, bool transA, bool transB, DeviceIndex device, + const BFloat16* bias) { + gemmCpuImpl(c, a, b, m, k, n, transA, transB, bias); } } // namespace tinytorch::op diff --git a/src/Operation/OpLinalgCuda.cuh b/src/Operation/OpLinalgCuda.cuh index 117e53e..41b14dc 100644 --- a/src/Operation/OpLinalgCuda.cuh +++ b/src/Operation/OpLinalgCuda.cuh @@ -137,8 +137,23 @@ Tensor col2imOpCudaImpl(const Tensor& self, const IntArrayView shape, Dim2D kern template Tensor dotOpCudaImpl(const Tensor& self, const Tensor& other); +template +__global__ void kBroadcastBias(T* c, const T* bias, int64_t m, int64_t n) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < m * n) { + c[idx] = bias[idx % n]; + } +} + inline void gemmCudaF32Impl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n, bool transA, - bool transB, DeviceIndex device) { + bool transB, DeviceIndex device, const float* bias = nullptr) { + float beta = 0.f; + if (bias) { + auto params = cuda::getKernelLaunchParams(device, m * n); + CUDA_LAUNCH_KERNEL(kBroadcastBias, params, c, bias, m, n); + beta = 1.f; + } + cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -147,14 +162,20 @@ inline void gemmCudaF32Impl(float* c, const float* a, const float* b, int64_t m, int ldc = static_cast(n); constexpr float alpha = 1.f; - constexpr float beta = 0.f; auto handle = cuda::getCublasHandle(device); CUBLAS_CHECK(cublasSgemm(handle, opB, opA, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); } inline void gemmCudaF16Impl(__half* c, const __half* a, const __half* b, int64_t m, int64_t k, int64_t n, bool transA, - bool transB, DeviceIndex device) { + bool transB, DeviceIndex device, const __half* bias = nullptr) { + float beta = 0.f; + if (bias) { + auto params = cuda::getKernelLaunchParams(device, m * n); + CUDA_LAUNCH_KERNEL(kBroadcastBias<__half>, params, c, bias, m, n); + beta = 1.f; + } + cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -163,7 +184,6 @@ inline void gemmCudaF16Impl(__half* c, const __half* a, const __half* b, int64_t int ldc = static_cast(n); constexpr float alpha = 1.f; - constexpr float beta = 0.f; auto handle = cuda::getCublasHandle(device); CUBLAS_CHECK(cublasGemmEx(handle, opB, opA, n, m, k, &alpha, b, CUDA_R_16F, ldb, a, CUDA_R_16F, lda, &beta, c, @@ -171,7 +191,15 @@ inline void gemmCudaF16Impl(__half* c, const __half* a, const __half* b, int64_t } inline void gemmCudaBF16Impl(__nv_bfloat16* c, const __nv_bfloat16* a, const __nv_bfloat16* b, int64_t m, int64_t k, - int64_t n, bool transA, bool transB, DeviceIndex device) { + int64_t n, bool transA, bool transB, DeviceIndex device, + const __nv_bfloat16* bias = nullptr) { + float beta = 0.f; + if (bias) { + auto params = cuda::getKernelLaunchParams(device, m * n); + CUDA_LAUNCH_KERNEL(kBroadcastBias<__nv_bfloat16>, params, c, bias, m, n); + beta = 1.f; + } + cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -180,7 +208,6 @@ inline void gemmCudaBF16Impl(__nv_bfloat16* c, const __nv_bfloat16* a, const __n int ldc = static_cast(n); constexpr float alpha = 1.f; - constexpr float beta = 0.f; auto handle = cuda::getCublasHandle(device); CUBLAS_CHECK(cublasGemmEx(handle, opB, opA, n, m, k, &alpha, b, CUDA_R_16BF, ldb, a, CUDA_R_16BF, lda, &beta, c, @@ -189,22 +216,24 @@ inline void gemmCudaBF16Impl(__nv_bfloat16* c, const __nv_bfloat16* a, const __n template <> void gemmImpl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n, - bool transA, bool transB, DeviceIndex device) { - gemmCudaF32Impl(c, a, b, m, k, n, transA, transB, device); + bool transA, bool transB, DeviceIndex device, const float* bias) { + gemmCudaF32Impl(c, a, b, m, k, n, transA, transB, device, bias); } template <> void gemmImpl(Half* c, const Half* a, const Half* b, int64_t m, int64_t k, int64_t n, - bool transA, bool transB, DeviceIndex device) { + bool transA, bool transB, DeviceIndex device, const Half* bias) { gemmCudaF16Impl(reinterpret_cast<__half*>(c), reinterpret_cast(a), reinterpret_cast(b), - m, k, n, transA, transB, device); + m, k, n, transA, transB, device, reinterpret_cast(bias)); } template <> void gemmImpl(BFloat16* c, const BFloat16* a, const BFloat16* b, int64_t m, int64_t k, - int64_t n, bool transA, bool transB, DeviceIndex device) { + int64_t n, bool transA, bool transB, DeviceIndex device, + const BFloat16* bias) { gemmCudaBF16Impl(reinterpret_cast<__nv_bfloat16*>(c), reinterpret_cast(a), - reinterpret_cast(b), m, k, n, transA, transB, device); + reinterpret_cast(b), m, k, n, transA, transB, device, + reinterpret_cast(bias)); } inline void gemmStridedBatchedCudaF32Impl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n, diff --git a/test/test_operation.cpp b/test/test_operation.cpp index cc47f10..6cde665 100644 --- a/test/test_operation.cpp +++ b/test/test_operation.cpp @@ -1223,28 +1223,28 @@ TEST(TEST_Operation, basic_im2col_col2im) { TEST(TEST_Operation, math_matmul_01) { Array2d d1 = {{1, 2}, {3, 4}}; Array2d d2 = {{2, 3}, {4, 5}}; - auto y = op::matmul(Tensor(d1), Tensor(d2), false, false); + auto y = op::matmul(Tensor(d1), Tensor(d2), false, false, Tensor{}); EXPECT_THAT(y.shape(), ElementsAre(2, 2)); EXPECT_THAT(y.toList(), ElementsAre(10, 13, 22, 29)); Array2d d3 = {{1, 2, 3}, {4, 5, 6}}; Array2d d4 = {{2, 3}, {4, 5}, {6, 7}}; - y = op::matmul(Tensor(d3), Tensor(d4), false, false); + y = op::matmul(Tensor(d3), Tensor(d4), false, false, Tensor{}); EXPECT_THAT(y.shape(), ElementsAre(2, 2)); EXPECT_THAT(y.toList(), ElementsAre(28, 34, 64, 79)); Array2d d5 = {{1, 0}, {0, 1}}; Array1d d6 = {1, 2}; - y = op::matmul(Tensor(d5), Tensor(d6), false, false); + y = op::matmul(Tensor(d5), Tensor(d6), false, false, Tensor{}); EXPECT_THAT(y.shape(), ElementsAre(2)); EXPECT_THAT(y.toList(), ElementsAre(1, 2)); - y = op::matmul(Tensor(d6), Tensor(d5), false, false); + y = op::matmul(Tensor(d6), Tensor(d5), false, false, Tensor{}); EXPECT_THAT(y.shape(), ElementsAre(2)); EXPECT_THAT(y.toList(), ElementsAre(1, 2)); Array1d d7 = {2}; - y = op::matmul(Tensor(d7), Tensor(d7), false, false); + y = op::matmul(Tensor(d7), Tensor(d7), false, false, Tensor{}); EXPECT_TRUE(y.dim() == 0); EXPECT_THAT(y.toList(), ElementsAre(4)); @@ -1255,8 +1255,8 @@ TEST(TEST_Operation, math_matmul_01) { b.reshape_({1, 2, 4, 2}); auto c = Tensor::arange(0, 1 * 2 * 4); c.reshape_({1, 4, 2}); - auto d = op::matmul(a, b, false, false); - auto e = op::matmul(a, c, false, false); + auto d = op::matmul(a, b, false, false, Tensor{}); + auto e = op::matmul(a, c, false, false, Tensor{}); EXPECT_THAT(d.shape(), ElementsAre(1, 2, 2, 2)); EXPECT_THAT(d.toList(), ElementsAre(28, 34, 76, 98, 428, 466, 604, 658)); @@ -1268,13 +1268,13 @@ TEST(TEST_Operation, math_matmul_01) { TEST(TEST_Operation, math_matmul_02) { Array2d d1 = {{1, 2}, {3, 4}}; Array2d d2 = {{2, 3}, {4, 5}}; - auto y = op::matmul(Tensor(d1), Tensor(d2), false, true); + auto y = op::matmul(Tensor(d1), Tensor(d2), false, true, Tensor{}); EXPECT_THAT(y.shape(), ElementsAre(2, 2)); EXPECT_THAT(y.toList(), ElementsAre(8, 14, 18, 32)); Array2d d3 = {{1, 2, 3}, {4, 5, 6}}; Array2d d4 = {{2, 4, 6}, {3, 5, 7}}; - y = op::matmul(Tensor(d3), Tensor(d4), false, true); + y = op::matmul(Tensor(d3), Tensor(d4), false, true, Tensor{}); EXPECT_THAT(y.shape(), ElementsAre(2, 2)); EXPECT_THAT(y.toList(), ElementsAre(28, 34, 64, 79)); } @@ -1285,7 +1285,7 @@ TEST(TEST_Operation, math_matmul_03) { a1.reshape_({2, 3, 4}); auto b1 = Tensor::arange(0, 2 * 4 * 2); b1.reshape_({2, 4, 2}); - auto c1 = op::matmul(a1, b1, false, false); + auto c1 = op::matmul(a1, b1, false, false, Tensor{}); EXPECT_THAT(c1.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c1.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 604, 658, 780, 850, 956, 1042)); @@ -1294,7 +1294,7 @@ TEST(TEST_Operation, math_matmul_03) { a3.reshape_({2, 3, 4}); auto b3 = Tensor::arange(0, 4 * 2); b3.reshape_({4, 2}); - auto c3 = op::matmul(a3, b3, false, false); + auto c3 = op::matmul(a3, b3, false, false, Tensor{}); EXPECT_THAT(c3.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c3.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 172, 226, 220, 290, 268, 354)); @@ -1303,7 +1303,7 @@ TEST(TEST_Operation, math_matmul_03) { a4.reshape_({3, 4}); auto b4 = Tensor::arange(0, 2 * 4 * 2); b4.reshape_({2, 4, 2}); - auto c4 = op::matmul(a4, b4, false, false); + auto c4 = op::matmul(a4, b4, false, false, Tensor{}); EXPECT_THAT(c4.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c4.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 76, 82, 252, 274, 428, 466)); @@ -1312,7 +1312,7 @@ TEST(TEST_Operation, math_matmul_03) { a5.reshape_({1, 3, 4}); auto b5 = Tensor::arange(0, 2 * 4 * 2); b5.reshape_({2, 4, 2}); - auto c5 = op::matmul(a5, b5, false, false); + auto c5 = op::matmul(a5, b5, false, false, Tensor{}); EXPECT_THAT(c5.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c5.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 76, 82, 252, 274, 428, 466)); @@ -1321,7 +1321,7 @@ TEST(TEST_Operation, math_matmul_03) { a6.reshape_({2, 3, 4}); auto b6 = Tensor::arange(0, 1 * 4 * 2); b6.reshape_({1, 4, 2}); - auto c6 = op::matmul(a6, b6, false, false); + auto c6 = op::matmul(a6, b6, false, false, Tensor{}); EXPECT_THAT(c6.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c6.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 172, 226, 220, 290, 268, 354)); @@ -1330,7 +1330,7 @@ TEST(TEST_Operation, math_matmul_03) { a8.reshape_({2, 3, 4}); auto b8 = Tensor::arange(0, 2 * 2 * 4); b8.reshape_({2, 2, 4}); - auto c8 = op::matmul(a8, b8, false, true); + auto c8 = op::matmul(a8, b8, false, true, Tensor{}); EXPECT_THAT(c8.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c8.toList(), ElementsAre(14, 38, 38, 126, 62, 214, 518, 734, 670, 950, 822, 1166)); @@ -1339,7 +1339,7 @@ TEST(TEST_Operation, math_matmul_03) { a9.reshape_({2, 4, 3}); auto b9 = Tensor::arange(0, 2 * 4 * 2); b9.reshape_({2, 4, 2}); - auto c9 = op::matmul(a9, b9, true, false); + auto c9 = op::matmul(a9, b9, true, false, Tensor{}); EXPECT_THAT(c9.shape(), ElementsAre(2, 3, 2)); EXPECT_THAT(c9.toList(), ElementsAre(84, 102, 96, 118, 108, 134, 756, 822, 800, 870, 844, 918)); @@ -1348,7 +1348,7 @@ TEST(TEST_Operation, math_matmul_03) { a10.reshape_({1, 3, 4}); auto b10 = Tensor::arange(0, 1 * 4 * 2); b10.reshape_({1, 4, 2}); - auto c10 = op::matmul(a10, b10, false, false); + auto c10 = op::matmul(a10, b10, false, false, Tensor{}); EXPECT_THAT(c10.shape(), ElementsAre(1, 3, 2)); EXPECT_THAT(c10.toList(), ElementsAre(28, 34, 76, 98, 124, 162)); @@ -1357,7 +1357,7 @@ TEST(TEST_Operation, math_matmul_03) { a12.reshape_({2, 1, 3, 4}); auto b12 = Tensor::arange(0, 2 * 1 * 4 * 2); b12.reshape_({2, 1, 4, 2}); - auto c12 = op::matmul(a12, b12, false, false); + auto c12 = op::matmul(a12, b12, false, false, Tensor{}); EXPECT_THAT(c12.shape(), ElementsAre(2, 1, 3, 2)); EXPECT_THAT(c12.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 604, 658, 780, 850, 956, 1042)); } @@ -1368,7 +1368,7 @@ TEST(TEST_Operation, math_matmul_04) { a1.reshape_({1, 2, 3}); auto b1 = Tensor::arange(0, 6 * 3 * 2); b1.reshape_({6, 3, 2}); - auto c1 = op::matmul(a1, b1, false, false); + auto c1 = op::matmul(a1, b1, false, false, Tensor{}); EXPECT_THAT(c1.shape(), ElementsAre(6, 2, 2)); EXPECT_THAT(c1.toList(), ElementsAre(10, 13, 28, 40, 28, 31, 100, 112, 46, 49, 172, 184, 64, 67, 244, 256, 82, 85, 316, 328, 100, 103, 388, 400)); @@ -1378,7 +1378,7 @@ TEST(TEST_Operation, math_matmul_04) { a2.reshape_({8, 2, 3}); auto b2 = Tensor::arange(0, 1 * 3 * 2); b2.reshape_({1, 3, 2}); - auto c2 = op::matmul(a2, b2, false, false); + auto c2 = op::matmul(a2, b2, false, false, Tensor{}); EXPECT_THAT(c2.shape(), ElementsAre(8, 2, 2)); EXPECT_THAT(c2.toList(), ElementsAre(10, 13, 28, 40, 46, 67, 64, 94, 82, 121, 100, 148, 118, 175, 136, 202, 154, 229, 172, 256, @@ -1389,7 +1389,7 @@ TEST(TEST_Operation, math_matmul_04) { a3.reshape_({2, 1, 2, 3}); auto b3 = Tensor::arange(0, 2 * 3 * 3 * 2); b3.reshape_({2, 3, 3, 2}); - auto c3 = op::matmul(a3, b3, false, false); + auto c3 = op::matmul(a3, b3, false, false, Tensor{}); EXPECT_THAT(c3.shape(), ElementsAre(2, 3, 2, 2)); EXPECT_THAT(c3.toList(), ElementsAre(10, 13, 28, 40, 28, 31, 100, 112, 46, 49, 172, 184, 424, 445, 604, 634, 550, 571, 784, 814, 676, 697, 964, 994)); From 3c8ea9d52088cab79e4093c4082b844d35f7b8d6 Mon Sep 17 00:00:00 2001 From: keith2018 Date: Fri, 3 Apr 2026 21:50:19 +0800 Subject: [PATCH 2/5] feat: enable copyOnDevice with secondary stream --- src/Tensor/Storage.cpp | 25 ++++++++++++++++--------- src/Tensor/Storage.h | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/Tensor/Storage.cpp b/src/Tensor/Storage.cpp index 53f2d4f..5c3b5d8 100644 --- a/src/Tensor/Storage.cpp +++ b/src/Tensor/Storage.cpp @@ -31,8 +31,12 @@ std::shared_ptr Storage::clone() const { return newStorage; } -void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, - int64_t nbytes) { +void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes +#ifdef USE_CUDA + , + const cuda::CUDAStream* stream +#endif +) { if (nbytes == 0) { return; } @@ -47,25 +51,28 @@ void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, // CUDA -> CUDA if (dstDevice.isCuda() && srcDevice.isCuda()) { cuda::CudaDeviceGuard guard(dstDevice.index); - auto& stream = cuda::getCurrentCUDAStream(dstDevice.index); - CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, stream.stream())); + const auto& s = stream ? *stream : cuda::getCurrentCUDAStream(dstDevice.index); + CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, s.stream())); return; } // CPU -> CUDA if (dstDevice.isCuda() && srcDevice.isCpu()) { cuda::CudaDeviceGuard guard(dstDevice.index); - auto& stream = cuda::getCurrentCUDAStream(dstDevice.index); - CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, stream.stream())); + const auto& s = stream ? *stream : cuda::getCurrentCUDAStream(dstDevice.index); + CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, s.stream())); return; } // CUDA -> CPU if (dstDevice.isCpu() && srcDevice.isCuda()) { cuda::CudaDeviceGuard guard(srcDevice.index); - auto& stream = cuda::getCurrentCUDAStream(srcDevice.index); - CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToHost, stream.stream())); - stream.synchronize(); + const auto& s = stream ? *stream : cuda::getCurrentCUDAStream(srcDevice.index); + CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToHost, s.stream())); + // synchronization when use default stream + if (!stream) { + s.synchronize(); + } return; } #endif diff --git a/src/Tensor/Storage.h b/src/Tensor/Storage.h index 94b85c9..67238d9 100644 --- a/src/Tensor/Storage.h +++ b/src/Tensor/Storage.h @@ -14,6 +14,12 @@ namespace tinytorch { +#ifdef USE_CUDA +namespace cuda { +struct CUDAStream; +} +#endif + class Storage { public: Storage(int64_t nbytes, Device device, Allocator* allocator = nullptr); @@ -35,8 +41,12 @@ class Storage { int64_t size() const { return nbytes_; } Device device() const { return device_; } - static void copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, - int64_t nbytes); + static void copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes +#ifdef USE_CUDA + , + const cuda::CUDAStream* stream = nullptr +#endif + ); static void copyOnDevice(void* dst, const void* src, int64_t nbytes, const Device& device); private: From a6a8f80619a36f7a46b9b24121fd985c88b36987 Mon Sep 17 00:00:00 2001 From: keith2018 Date: Fri, 3 Apr 2026 21:55:06 +0800 Subject: [PATCH 3/5] deps: update TinyFA --- third_party/TinyFA | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/TinyFA b/third_party/TinyFA index ffc2647..4e18516 160000 --- a/third_party/TinyFA +++ b/third_party/TinyFA @@ -1 +1 @@ -Subproject commit ffc264708d49e63f167b067b6d42339340469ca1 +Subproject commit 4e18516165acb029076b2ecc7d733c9ebf4d552a From 8d1d638b8f6b38a9f3443bf201eb242c930d7fe6 Mon Sep 17 00:00:00 2001 From: keith2018 Date: Fri, 3 Apr 2026 22:17:51 +0800 Subject: [PATCH 4/5] feat: refactor copyOnDevice --- src/Tensor/Storage.cpp | 17 ++++++++--------- src/Tensor/Storage.h | 14 ++------------ 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/src/Tensor/Storage.cpp b/src/Tensor/Storage.cpp index 5c3b5d8..2bffe03 100644 --- a/src/Tensor/Storage.cpp +++ b/src/Tensor/Storage.cpp @@ -31,12 +31,8 @@ std::shared_ptr Storage::clone() const { return newStorage; } -void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes -#ifdef USE_CUDA - , - const cuda::CUDAStream* stream -#endif -) { +void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes, + const void* stream) { if (nbytes == 0) { return; } @@ -51,7 +47,8 @@ void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, // CUDA -> CUDA if (dstDevice.isCuda() && srcDevice.isCuda()) { cuda::CudaDeviceGuard guard(dstDevice.index); - const auto& s = stream ? *stream : cuda::getCurrentCUDAStream(dstDevice.index); + const auto& s = + stream ? *static_cast(stream) : cuda::getCurrentCUDAStream(dstDevice.index); CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, s.stream())); return; } @@ -59,7 +56,8 @@ void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, // CPU -> CUDA if (dstDevice.isCuda() && srcDevice.isCpu()) { cuda::CudaDeviceGuard guard(dstDevice.index); - const auto& s = stream ? *stream : cuda::getCurrentCUDAStream(dstDevice.index); + const auto& s = + stream ? *static_cast(stream) : cuda::getCurrentCUDAStream(dstDevice.index); CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, s.stream())); return; } @@ -67,7 +65,8 @@ void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, // CUDA -> CPU if (dstDevice.isCpu() && srcDevice.isCuda()) { cuda::CudaDeviceGuard guard(srcDevice.index); - const auto& s = stream ? *stream : cuda::getCurrentCUDAStream(srcDevice.index); + const auto& s = + stream ? *static_cast(stream) : cuda::getCurrentCUDAStream(srcDevice.index); CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToHost, s.stream())); // synchronization when use default stream if (!stream) { diff --git a/src/Tensor/Storage.h b/src/Tensor/Storage.h index 67238d9..9b71c73 100644 --- a/src/Tensor/Storage.h +++ b/src/Tensor/Storage.h @@ -14,12 +14,6 @@ namespace tinytorch { -#ifdef USE_CUDA -namespace cuda { -struct CUDAStream; -} -#endif - class Storage { public: Storage(int64_t nbytes, Device device, Allocator* allocator = nullptr); @@ -41,12 +35,8 @@ class Storage { int64_t size() const { return nbytes_; } Device device() const { return device_; } - static void copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes -#ifdef USE_CUDA - , - const cuda::CUDAStream* stream = nullptr -#endif - ); + static void copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes, + const void* stream = nullptr); static void copyOnDevice(void* dst, const void* src, int64_t nbytes, const Device& device); private: From ab62352e2efe5db54b7a2dd0f89c72b8f16d2c7d Mon Sep 17 00:00:00 2001 From: keith2018 Date: Sat, 4 Apr 2026 21:53:52 +0800 Subject: [PATCH 5/5] refactor: update project structure --- .github/workflows/cmake_linux.yml | 8 +- .github/workflows/cmake_macos.yml | 4 +- .github/workflows/cmake_windows.yml | 8 +- .gitignore | 2 +- CMakeLists.txt | 8 +- README.md | 124 ++++++++++-------- demo/demo.h | 17 --- demo/main.cpp | 21 --- examples/CMakeLists.txt | 9 ++ examples/autograd/CMakeLists.txt | 16 +++ .../autograd/main.cpp | 6 +- examples/ddp/CMakeLists.txt | 26 ++++ demo/demo_ddp.cpp => examples/ddp/main.cpp | 12 +- {demo => examples/mnist}/CMakeLists.txt | 24 +--- .../mnist}/data/t10k-images-idx3-ubyte | Bin .../mnist}/data/t10k-labels-idx1-ubyte | Bin .../mnist}/data/train-images-idx3-ubyte | Bin .../mnist}/data/train-labels-idx1-ubyte | Bin .../demo_mnist.cpp => examples/mnist/main.cpp | 8 +- examples/module/CMakeLists.txt | 16 +++ .../module/main.cpp | 6 +- examples/nccl/CMakeLists.txt | 16 +++ demo/demo_nccl.cpp => examples/nccl/main.cpp | 16 ++- examples/optimizer/CMakeLists.txt | 16 +++ .../optimizer/main.cpp | 6 +- 25 files changed, 217 insertions(+), 152 deletions(-) delete mode 100644 demo/demo.h delete mode 100644 demo/main.cpp create mode 100644 examples/CMakeLists.txt create mode 100644 examples/autograd/CMakeLists.txt rename demo/demo_autograd.cpp => examples/autograd/main.cpp (95%) create mode 100644 examples/ddp/CMakeLists.txt rename demo/demo_ddp.cpp => examples/ddp/main.cpp (98%) rename {demo => examples/mnist}/CMakeLists.txt (57%) rename {demo => examples/mnist}/data/t10k-images-idx3-ubyte (100%) rename {demo => examples/mnist}/data/t10k-labels-idx1-ubyte (100%) rename {demo => examples/mnist}/data/train-images-idx3-ubyte (100%) rename {demo => examples/mnist}/data/train-labels-idx1-ubyte (100%) rename demo/demo_mnist.cpp => examples/mnist/main.cpp (98%) create mode 100644 examples/module/CMakeLists.txt rename demo/demo_module.cpp => examples/module/main.cpp (96%) create mode 100644 examples/nccl/CMakeLists.txt rename demo/demo_nccl.cpp => examples/nccl/main.cpp (82%) create mode 100644 examples/optimizer/CMakeLists.txt rename demo/demo_optim.cpp => examples/optimizer/main.cpp (95%) diff --git a/.github/workflows/cmake_linux.yml b/.github/workflows/cmake_linux.yml index 301ed0d..17c4041 100644 --- a/.github/workflows/cmake_linux.yml +++ b/.github/workflows/cmake_linux.yml @@ -28,8 +28,8 @@ jobs: - name: Test run: cd ${{github.workspace}}/build && ctest - # - name: Demo - # run: cd ${{github.workspace}}/demo/bin && ./TinyTorch_demo + # - name: Run MNIST Example + # run: cd ${{github.workspace}}/examples/mnist/bin && ./tinytorch_example_mnist # build_linux_gpu: # name: build_linux_gpu @@ -58,5 +58,5 @@ jobs: # - name: Test # run: cd ${{github.workspace}}/build && ctest - # - name: Demo - # run: cd ${{github.workspace}}/demo/bin && ./TinyTorch_demo + # - name: Run MNIST Example + # run: cd ${{github.workspace}}/examples/mnist/bin && ./tinytorch_example_mnist diff --git a/.github/workflows/cmake_macos.yml b/.github/workflows/cmake_macos.yml index 7e91ffe..7e380f5 100644 --- a/.github/workflows/cmake_macos.yml +++ b/.github/workflows/cmake_macos.yml @@ -28,5 +28,5 @@ jobs: - name: Test run: cd ${{github.workspace}}/build && ctest - - name: Demo - run: cd ${{github.workspace}}/demo/bin && ./TinyTorch_demo + - name: Run MNIST Example + run: cd ${{github.workspace}}/examples/mnist/bin && ./tinytorch_example_mnist diff --git a/.github/workflows/cmake_windows.yml b/.github/workflows/cmake_windows.yml index c826c18..41d70c4 100644 --- a/.github/workflows/cmake_windows.yml +++ b/.github/workflows/cmake_windows.yml @@ -28,8 +28,8 @@ jobs: - name: Test run: cd ${{github.workspace}}/build && ctest - # - name: Demo - # run: cd ${{github.workspace}}/demo/bin/${{env.BUILD_TYPE}} && ./TinyTorch_demo.exe + # - name: Run MNIST Example + # run: cd ${{github.workspace}}/examples/mnist/bin/${{env.BUILD_TYPE}} && ./tinytorch_example_mnist.exe # build_windows_gpu: # name: build_windows_gpu @@ -59,5 +59,5 @@ jobs: # - name: Test # run: cd ${{github.workspace}}/build && ctest - # - name: Demo - # run: cd ${{github.workspace}}/demo/bin/${{env.BUILD_TYPE}} && ./TinyTorch_demo.exe + # - name: Run MNIST Example + # run: cd ${{github.workspace}}/examples/mnist/bin/${{env.BUILD_TYPE}} && ./tinytorch_example_mnist.exe diff --git a/.gitignore b/.gitignore index 502a905..1a0d23f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store .idea/ .vs/ -/demo/bin +/examples/*/bin out build cmake-build-*/ diff --git a/CMakeLists.txt b/CMakeLists.txt index ae48010..db45cd5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.10) project(TinyTorch) -option(TINYTORCH_BUILD_DEMO "Whether or not to build demo" ON) +option(TINYTORCH_BUILD_EXAMPLES "Whether or not to build examples" ON) option(TINYTORCH_BUILD_TEST "Whether or not to build the tests" OFF) option(TINYTORCH_USE_CUDA "Use CUDA" ON) @@ -15,7 +15,7 @@ if (NOT TINYTORCH_USE_CUDA OR APPLE OR MSVC) set(TINYTORCH_USE_NCCL OFF) endif () -message(STATUS "TINYTORCH_BUILD_DEMO ${TINYTORCH_BUILD_DEMO}") +message(STATUS "TINYTORCH_BUILD_EXAMPLES ${TINYTORCH_BUILD_EXAMPLES}") message(STATUS "TINYTORCH_BUILD_TEST ${TINYTORCH_BUILD_TEST}") message(STATUS "TINYTORCH_USE_CUDA ${TINYTORCH_USE_CUDA}") message(STATUS "TINYTORCH_USE_NCCL ${TINYTORCH_USE_NCCL}") @@ -30,8 +30,8 @@ endif () add_subdirectory(src) -if (TINYTORCH_BUILD_DEMO) - add_subdirectory(demo) +if (TINYTORCH_BUILD_EXAMPLES) + add_subdirectory(examples) endif () if (TINYTORCH_BUILD_TEST) diff --git a/README.md b/README.md index 97d2b73..bd09856 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # TinyTorch -**TinyTorch** is a lightweight deep learning training framework implemented from scratch in C++. +A lightweight deep learning training framework implemented from scratch in C++, featuring a PyTorch-style API. -For more details, please refer to my blog post: [Write a nn training framework from scratch](https://robot9.me/write-nn-framework-from-scratch-tinytorch/) +For more details, please refer to the blog post: [Write a nn training framework from scratch](https://robot9.me/write-nn-framework-from-scratch-tinytorch/) [![CMake Linux](https://github.com/keith2018/TinyTorch/actions/workflows/cmake_linux.yml/badge.svg)](https://github.com/keith2018/TinyTorch/actions/workflows/cmake_linux.yml) [![CMake MacOS](https://github.com/keith2018/TinyTorch/actions/workflows/cmake_macos.yml/badge.svg)](https://github.com/keith2018/TinyTorch/actions/workflows/cmake_macos.yml) @@ -10,88 +10,96 @@ For more details, please refer to my blog post: [Write a nn training framework f ## Key Features -* **PyTorch-Style API**: Similar naming conventions as PyTorch (`Tensor`, `Functions`, `nn.Module`, `Optimizer`). -* **Pure C++ Implementation**: No dependency on external deep learning libraries. -* **CPU & CUDA Support**: Runs on both CPU and CUDA-enabled GPUs. -* **Mixed Precision**: Supports FP16, FP32, BF16. -* **Distributed**: Multi-machine, multi-GPU training & inference. -* **LLM Inference**: Supports inference for llama/qwen/mistral models: [https://github.com/keith2018/TinyGPT](https://github.com/keith2018/TinyGPT) +- **PyTorch-style API** — Familiar naming conventions (`Tensor`, `nn.Module`, `Optimizer`, `DataLoader`). +- **Pure C++ implementation** — No dependency on external deep learning libraries, C++17 only. +- **CPU & CUDA** — Runs on both CPU (with BLAS acceleration) and CUDA-enabled GPUs. +- **Mixed precision** — Supports FP16, FP32 and BF16. +- **Distributed training** — Multi-machine, multi-GPU training & inference via NCCL. +- **LLM inference** — Supports inference for LLaMA / Qwen / Mistral models: [TinyGPT](https://github.com/keith2018/TinyGPT). -## Implemented Operators and Components +## Architecture -### Activation Functions -* `relu`, `gelu`, `silu` -* `softmax`, `logSoftmax` +TinyTorch implements automatic differentiation by building a dynamic computation graph. Each operation on a `Tensor` creates a `Function` node that records both the forward computation and the backward gradient rule. These nodes are linked via `nextFunctions`, forming a DAG. Calling `backward()` traverses this graph in reverse topological order, propagating gradients via the chain rule. -### Mathematical Operations -* `add`, `sub`, `mul`, `div`, `matmul` -* `sin`, `cos`, `sqrt`, `pow` -* `maximum`, `minimum` + -### Comparison and Logical Operations -* `lt`, `le`, `gt`, `ge`, `eq`, `ne` -* `logicNot`, `logicAnd`, `logicOr` +## Project Structure -### Statistical and Reduction Operations -* `min`, `argmin`, `max`, `argmax` -* `sum`, `mean`, `var` +``` +TinyTorch/ +├── src/ # Core library (Tensor, Function, nn.Module, Optimizer, ...) +├── examples/ # Standalone example programs +│ ├── autograd/ # Automatic differentiation basics +│ ├── module/ # Building models with nn.Module +│ ├── optimizer/ # Using built-in optimizers +│ ├── mnist/ # Full MNIST training pipeline +│ ├── nccl/ # NCCL collective communication +│ └── ddp/ # Distributed data-parallel training +├── test/ # Unit tests +└── third_party/ # Third-party dependencies +``` -### Tensor Shape and Indexing Operations -* `reshape`, `view`, `permute`, `transpose` -* `flatten`, `unflatten`, `squeeze`, `unsqueeze` -* `split`, `concat`, `stack`, `hstack`, `vstack`, `narrow` -* `topk`, `sort`, `cumsum` -* `gather`, `scatter` +## Getting Started -### Neural Network Layers and Loss Functions -* `linear` -* `dropout` -* `maxPool2d` -* `conv2d` -* `embedding` -* `layerNorm` -* `rmsNorm` -* `sdpAttention` -* `mseLoss` -* `nllLoss` +### Prerequisites -### Optimizers -* `SGD`, `Adagrad`, `RMSprop`, `AdaDelta`, `Adam`, `AdamW` +- CMake 3.10+ +- C++17 compatible compiler +- CUDA Toolkit 11.0+ *(optional, for GPU support)* +- NCCL *(optional, for distributed training)* -### Other -* `Dataset`, `DataLoader`, `data.Transform` +### Build -## Automatic differentiation +```bash +mkdir build +cmake -B ./build -DCMAKE_BUILD_TYPE=Release +cmake --build ./build --config Release +``` -TinyTorch's automatic differentiation (AD) is implemented by building a computation graph. Each operation on a `Tensor` is represented by a `Function` object, which is responsible for both the forward and backward passes. The `Function` nodes are connected via a `nextFunctions` field, creating the dependency graph. During the `backward()` call, the framework traverses this graph in reverse order, computing and propagating gradients using the chain rule. +#### CMake Options - +| Option | Default | Description | +|--------|---------|-------------| +| `TINYTORCH_BUILD_EXAMPLES` | `ON` | Build example programs | +| `TINYTORCH_BUILD_TEST` | `OFF` | Build unit tests | +| `TINYTORCH_USE_CUDA` | `ON` | Enable CUDA support | +| `TINYTORCH_USE_NCCL` | `ON` | Enable NCCL support | -## Getting Started +### Run Examples -### Prerequisites -* CMake -* C++17 or a more recent compiler -* CUDA Toolkit 11.0+ (optional) +Each example is an independent executable: -### Build ```bash -mkdir build -cmake -B ./build -DCMAKE_BUILD_TYPE=Release -cmake --build ./build --config Release +# Autograd basics +cd examples/autograd/bin && ./tinytorch_example_autograd + +# nn.Module usage +cd examples/module/bin && ./tinytorch_example_module + +# Optimizer usage +cd examples/optimizer/bin && ./tinytorch_example_optimizer + +# MNIST training +cd examples/mnist/bin && ./tinytorch_example_mnist ``` -### Run `MNIST` Demo +For distributed examples (requires NCCL and multiple GPUs): + ```bash -cd demo/bin -./TinyTorch_demo +# NCCL all-reduce +cd examples/nccl/bin && ./tinytorch_example_nccl + +# Distributed data-parallel training +cd examples/ddp/bin && ./tinytorch_example_ddp ``` ### Run Tests + ```bash cd build ctest ``` ## License + This code is licensed under the MIT License (see [LICENSE](LICENSE)). diff --git a/demo/demo.h b/demo/demo.h deleted file mode 100644 index 961e4f8..0000000 --- a/demo/demo.h +++ /dev/null @@ -1,17 +0,0 @@ -/* - * TinyTorch - * @author : keith@robot9.me - * - */ - -#pragma once - -void demo_autograd(); -void demo_module(); -void demo_optim(); -void demo_mnist(); - -#ifdef USE_NCCL -void demo_nccl(int argc, char **argv); -void demo_ddp(int argc, char **argv); -#endif diff --git a/demo/main.cpp b/demo/main.cpp deleted file mode 100644 index f60e37c..0000000 --- a/demo/main.cpp +++ /dev/null @@ -1,21 +0,0 @@ -/* - * TinyTorch - * @author : keith@robot9.me - * - */ - -#include "demo.h" - -int main(int argc, char **argv) { - demo_autograd(); - demo_module(); - demo_optim(); - demo_mnist(); - -#ifdef USE_NCCL - demo_nccl(argc, argv); - demo_ddp(argc, argv); -#endif - - return 0; -} \ No newline at end of file diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 0000000..4928a88 --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,9 @@ +add_subdirectory(autograd) +add_subdirectory(module) +add_subdirectory(optimizer) +add_subdirectory(mnist) + +if (TINYTORCH_USE_NCCL) + add_subdirectory(nccl) + add_subdirectory(ddp) +endif () diff --git a/examples/autograd/CMakeLists.txt b/examples/autograd/CMakeLists.txt new file mode 100644 index 0000000..9895726 --- /dev/null +++ b/examples/autograd/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.10) +project(tinytorch_example_autograd) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(${PROJECT_NAME} main.cpp) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party +) + +target_link_libraries(${PROJECT_NAME} TinyTorch_lib) + +set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin) diff --git a/demo/demo_autograd.cpp b/examples/autograd/main.cpp similarity index 95% rename from demo/demo_autograd.cpp rename to examples/autograd/main.cpp index 6b26be6..caec1aa 100644 --- a/demo/demo_autograd.cpp +++ b/examples/autograd/main.cpp @@ -11,8 +11,8 @@ using namespace tinytorch; // https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-tensors-and-autograd -void demo_autograd() { - LOGD("demo_autograd ..."); +int main() { + LOGD("autograd example ..."); Timer timer; timer.start(); @@ -53,4 +53,6 @@ void demo_autograd() { timer.mark(); LOGD("Time cost: %lld ms", timer.elapseMillis()); + + return 0; } diff --git a/examples/ddp/CMakeLists.txt b/examples/ddp/CMakeLists.txt new file mode 100644 index 0000000..7eb891b --- /dev/null +++ b/examples/ddp/CMakeLists.txt @@ -0,0 +1,26 @@ +cmake_minimum_required(VERSION 3.10) +project(tinytorch_example_ddp) + +if (CMAKE_BUILD_TYPE STREQUAL Debug) + add_definitions(-DDEBUG) +endif () + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(${PROJECT_NAME} main.cpp) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party +) + +target_link_libraries(${PROJECT_NAME} TinyTorch_lib) + +set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin) + +# copy assets +add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E remove_directory $/data + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/../mnist/data $/data +) diff --git a/demo/demo_ddp.cpp b/examples/ddp/main.cpp similarity index 98% rename from demo/demo_ddp.cpp rename to examples/ddp/main.cpp index 45973ac..4958997 100644 --- a/demo/demo_ddp.cpp +++ b/examples/ddp/main.cpp @@ -142,8 +142,8 @@ static void test(nn::Module &model, Device device, data::DataLoader &dataLoader) testLoss, correct, total, 100. * correct / (float)total, elapsed); } -void demo_ddp(int argc, char **argv) { - LOGD("demo_ddp ..."); +int main(int argc, char **argv) { + LOGD("DDP training example ..."); ASSERT(argc == 4); int localRank = std::stoi(argv[1]); @@ -154,7 +154,7 @@ void demo_ddp(int argc, char **argv) { LOGD("deviceCount: %d", deviceCount); if (localRank >= deviceCount) { LOGE("Not enough GPUs available. Required: %d, Available: %d", (localRank + 1), deviceCount); - return; + return 1; } auto dpg = distributed::DistributedProcessGroup::getInstance(); @@ -163,7 +163,7 @@ void demo_ddp(int argc, char **argv) { bool success = dpg->initProcessGroup(distributed::NCCL, initMethod, rank, worldSize); if (!success) { LOGE("InitProcessGroup failed"); - return; + return 1; } cuda::setDevice(localRank); @@ -181,7 +181,7 @@ void demo_ddp(int argc, char **argv) { if (trainDataset->size() == 0 || testDataset->size() == 0) { LOGE("Dataset invalid."); - return; + return 1; } auto sampler = @@ -224,4 +224,6 @@ void demo_ddp(int argc, char **argv) { timer.mark(); LOGD("Time cost: %lld ms", timer.elapseMillis()); + + return 0; } diff --git a/demo/CMakeLists.txt b/examples/mnist/CMakeLists.txt similarity index 57% rename from demo/CMakeLists.txt rename to examples/mnist/CMakeLists.txt index 3926ed3..5dc305f 100644 --- a/demo/CMakeLists.txt +++ b/examples/mnist/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.10) -project(TinyTorch_demo) +project(tinytorch_example_mnist) if (CMAKE_BUILD_TYPE STREQUAL Debug) add_definitions(-DDEBUG) @@ -8,32 +8,16 @@ endif () set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(DEMO_SRCS - demo_autograd.cpp - demo_module.cpp - demo_optim.cpp - demo_mnist.cpp - main.cpp -) - -if (TINYTORCH_USE_NCCL) - list(APPEND DEMO_SRCS - demo_nccl.cpp - demo_ddp.cpp - ) -endif () - -add_executable(${PROJECT_NAME} ${DEMO_SRCS}) +add_executable(${PROJECT_NAME} main.cpp) target_include_directories(${PROJECT_NAME} PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../src - ${CMAKE_CURRENT_SOURCE_DIR}/../third_party + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party ) target_link_libraries(${PROJECT_NAME} TinyTorch_lib) set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin) -SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin) # copy assets add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD diff --git a/demo/data/t10k-images-idx3-ubyte b/examples/mnist/data/t10k-images-idx3-ubyte similarity index 100% rename from demo/data/t10k-images-idx3-ubyte rename to examples/mnist/data/t10k-images-idx3-ubyte diff --git a/demo/data/t10k-labels-idx1-ubyte b/examples/mnist/data/t10k-labels-idx1-ubyte similarity index 100% rename from demo/data/t10k-labels-idx1-ubyte rename to examples/mnist/data/t10k-labels-idx1-ubyte diff --git a/demo/data/train-images-idx3-ubyte b/examples/mnist/data/train-images-idx3-ubyte similarity index 100% rename from demo/data/train-images-idx3-ubyte rename to examples/mnist/data/train-images-idx3-ubyte diff --git a/demo/data/train-labels-idx1-ubyte b/examples/mnist/data/train-labels-idx1-ubyte similarity index 100% rename from demo/data/train-labels-idx1-ubyte rename to examples/mnist/data/train-labels-idx1-ubyte diff --git a/demo/demo_mnist.cpp b/examples/mnist/main.cpp similarity index 98% rename from demo/demo_mnist.cpp rename to examples/mnist/main.cpp index fd6f030..b4ea25a 100644 --- a/demo/demo_mnist.cpp +++ b/examples/mnist/main.cpp @@ -143,8 +143,8 @@ static void test(nn::Module &model, Device device, data::DataLoader &dataLoader) testLoss, correct, total, 100. * correct / (float)total, elapsed); } -void demo_mnist() { - LOGD("demo_mnist ..."); +int main() { + LOGD("MNIST training example ..."); TrainArgs args; manualSeed(args.seed); @@ -161,7 +161,7 @@ void demo_mnist() { if (trainDataset->size() == 0 || testDataset->size() == 0) { LOGE("Dataset invalid."); - return; + return 1; } auto trainDataloader = data::DataLoader(trainDataset, args.batchSize); @@ -197,4 +197,6 @@ void demo_mnist() { timer.mark(); LOGD("Total Time cost: %lld ms", timer.elapseMillis()); + + return 0; } diff --git a/examples/module/CMakeLists.txt b/examples/module/CMakeLists.txt new file mode 100644 index 0000000..f008a85 --- /dev/null +++ b/examples/module/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.10) +project(tinytorch_example_module) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(${PROJECT_NAME} main.cpp) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party +) + +target_link_libraries(${PROJECT_NAME} TinyTorch_lib) + +set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin) diff --git a/demo/demo_module.cpp b/examples/module/main.cpp similarity index 96% rename from demo/demo_module.cpp rename to examples/module/main.cpp index 2da0d06..be48042 100644 --- a/demo/demo_module.cpp +++ b/examples/module/main.cpp @@ -12,8 +12,8 @@ using namespace tinytorch; // https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-nn -void demo_module() { - LOGD("demo_module ..."); +int main() { + LOGD("module example ..."); Timer timer; timer.start(); @@ -58,4 +58,6 @@ void demo_module() { timer.mark(); LOGD("Time cost: %lld ms", timer.elapseMillis()); + + return 0; } diff --git a/examples/nccl/CMakeLists.txt b/examples/nccl/CMakeLists.txt new file mode 100644 index 0000000..a91a174 --- /dev/null +++ b/examples/nccl/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.10) +project(tinytorch_example_nccl) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(${PROJECT_NAME} main.cpp) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party +) + +target_link_libraries(${PROJECT_NAME} TinyTorch_lib) + +set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin) diff --git a/demo/demo_nccl.cpp b/examples/nccl/main.cpp similarity index 82% rename from demo/demo_nccl.cpp rename to examples/nccl/main.cpp index 7540513..6350323 100644 --- a/demo/demo_nccl.cpp +++ b/examples/nccl/main.cpp @@ -13,8 +13,8 @@ using namespace tinytorch; namespace tinytorch::distributed { -static void demoAllReduce(int localRank, int rank, int worldSize) { - LOGD("demoAllReduce: %d, %d, %d", localRank, rank, worldSize); +static void allReduceExample(int localRank, int rank, int worldSize) { + LOGD("allReduceExample: %d, %d, %d", localRank, rank, worldSize); auto dpg = DistributedProcessGroup::getInstance(); @@ -44,14 +44,14 @@ static void demoAllReduce(int localRank, int rank, int worldSize) { auto expected = worldSize * (worldSize + 1) / 2; bool correct = std::abs(result[0] - static_cast(expected)) < 1e-5; - std::cout << "Rank " << rank << " correct: " << (correct ? "✓" : "✗") << " (expected: " << expected + std::cout << "Rank " << rank << " correct: " << (correct ? "Y" : "N") << " (expected: " << expected << ", result: " << result[0] << ")" << std::endl; } } // namespace tinytorch::distributed -void demo_nccl(int argc, char** argv) { - LOGD("demo_nccl ..."); +int main(int argc, char** argv) { + LOGD("NCCL example ..."); Timer timer; timer.start(); @@ -64,11 +64,13 @@ void demo_nccl(int argc, char** argv) { LOGD("deviceCount: %d", deviceCount); if (localRank >= deviceCount) { LOGE("Not enough GPUs available. Required: %d, Available: %d", (localRank + 1), deviceCount); - return; + return 1; } - distributed::demoAllReduce(localRank, rank, worldSize); + distributed::allReduceExample(localRank, rank, worldSize); timer.mark(); LOGD("Time cost: %lld ms", timer.elapseMillis()); + + return 0; } diff --git a/examples/optimizer/CMakeLists.txt b/examples/optimizer/CMakeLists.txt new file mode 100644 index 0000000..fdae43f --- /dev/null +++ b/examples/optimizer/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.10) +project(tinytorch_example_optimizer) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(${PROJECT_NAME} main.cpp) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party +) + +target_link_libraries(${PROJECT_NAME} TinyTorch_lib) + +set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin) diff --git a/demo/demo_optim.cpp b/examples/optimizer/main.cpp similarity index 95% rename from demo/demo_optim.cpp rename to examples/optimizer/main.cpp index 4a7d40e..2bf7925 100644 --- a/demo/demo_optim.cpp +++ b/examples/optimizer/main.cpp @@ -12,8 +12,8 @@ using namespace tinytorch; // https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-optim -void demo_optim() { - LOGD("demo_optim ..."); +int main() { + LOGD("optimizer example ..."); Timer timer; timer.start(); @@ -54,4 +54,6 @@ void demo_optim() { timer.mark(); LOGD("Time cost: %lld ms", timer.elapseMillis()); + + return 0; }