diff --git a/.github/workflows/cmake_linux.yml b/.github/workflows/cmake_linux.yml
index 301ed0d..17c4041 100644
--- a/.github/workflows/cmake_linux.yml
+++ b/.github/workflows/cmake_linux.yml
@@ -28,8 +28,8 @@ jobs:
- name: Test
run: cd ${{github.workspace}}/build && ctest
- # - name: Demo
- # run: cd ${{github.workspace}}/demo/bin && ./TinyTorch_demo
+ # - name: Run MNIST Example
+ # run: cd ${{github.workspace}}/examples/mnist/bin && ./tinytorch_example_mnist
# build_linux_gpu:
# name: build_linux_gpu
@@ -58,5 +58,5 @@ jobs:
# - name: Test
# run: cd ${{github.workspace}}/build && ctest
- # - name: Demo
- # run: cd ${{github.workspace}}/demo/bin && ./TinyTorch_demo
+ # - name: Run MNIST Example
+ # run: cd ${{github.workspace}}/examples/mnist/bin && ./tinytorch_example_mnist
diff --git a/.github/workflows/cmake_macos.yml b/.github/workflows/cmake_macos.yml
index 7e91ffe..7e380f5 100644
--- a/.github/workflows/cmake_macos.yml
+++ b/.github/workflows/cmake_macos.yml
@@ -28,5 +28,5 @@ jobs:
- name: Test
run: cd ${{github.workspace}}/build && ctest
- - name: Demo
- run: cd ${{github.workspace}}/demo/bin && ./TinyTorch_demo
+ - name: Run MNIST Example
+ run: cd ${{github.workspace}}/examples/mnist/bin && ./tinytorch_example_mnist
diff --git a/.github/workflows/cmake_windows.yml b/.github/workflows/cmake_windows.yml
index c826c18..41d70c4 100644
--- a/.github/workflows/cmake_windows.yml
+++ b/.github/workflows/cmake_windows.yml
@@ -28,8 +28,8 @@ jobs:
- name: Test
run: cd ${{github.workspace}}/build && ctest
- # - name: Demo
- # run: cd ${{github.workspace}}/demo/bin/${{env.BUILD_TYPE}} && ./TinyTorch_demo.exe
+ # - name: Run MNIST Example
+ # run: cd ${{github.workspace}}/examples/mnist/bin/${{env.BUILD_TYPE}} && ./tinytorch_example_mnist.exe
# build_windows_gpu:
# name: build_windows_gpu
@@ -59,5 +59,5 @@ jobs:
# - name: Test
# run: cd ${{github.workspace}}/build && ctest
- # - name: Demo
- # run: cd ${{github.workspace}}/demo/bin/${{env.BUILD_TYPE}} && ./TinyTorch_demo.exe
+ # - name: Run MNIST Example
+ # run: cd ${{github.workspace}}/examples/mnist/bin/${{env.BUILD_TYPE}} && ./tinytorch_example_mnist.exe
diff --git a/.gitignore b/.gitignore
index 502a905..1a0d23f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
.DS_Store
.idea/
.vs/
-/demo/bin
+/examples/*/bin
out
build
cmake-build-*/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ae48010..db45cd5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.10)
project(TinyTorch)
-option(TINYTORCH_BUILD_DEMO "Whether or not to build demo" ON)
+option(TINYTORCH_BUILD_EXAMPLES "Whether or not to build examples" ON)
option(TINYTORCH_BUILD_TEST "Whether or not to build the tests" OFF)
option(TINYTORCH_USE_CUDA "Use CUDA" ON)
@@ -15,7 +15,7 @@ if (NOT TINYTORCH_USE_CUDA OR APPLE OR MSVC)
set(TINYTORCH_USE_NCCL OFF)
endif ()
-message(STATUS "TINYTORCH_BUILD_DEMO ${TINYTORCH_BUILD_DEMO}")
+message(STATUS "TINYTORCH_BUILD_EXAMPLES ${TINYTORCH_BUILD_EXAMPLES}")
message(STATUS "TINYTORCH_BUILD_TEST ${TINYTORCH_BUILD_TEST}")
message(STATUS "TINYTORCH_USE_CUDA ${TINYTORCH_USE_CUDA}")
message(STATUS "TINYTORCH_USE_NCCL ${TINYTORCH_USE_NCCL}")
@@ -30,8 +30,8 @@ endif ()
add_subdirectory(src)
-if (TINYTORCH_BUILD_DEMO)
- add_subdirectory(demo)
+if (TINYTORCH_BUILD_EXAMPLES)
+ add_subdirectory(examples)
endif ()
if (TINYTORCH_BUILD_TEST)
diff --git a/README.md b/README.md
index 97d2b73..bd09856 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
# TinyTorch
-**TinyTorch** is a lightweight deep learning training framework implemented from scratch in C++.
+A lightweight deep learning training framework implemented from scratch in C++, featuring a PyTorch-style API.
-For more details, please refer to my blog post: [Write a nn training framework from scratch](https://robot9.me/write-nn-framework-from-scratch-tinytorch/)
+For more details, please refer to the blog post: [Write a nn training framework from scratch](https://robot9.me/write-nn-framework-from-scratch-tinytorch/)
[](https://github.com/keith2018/TinyTorch/actions/workflows/cmake_linux.yml)
[](https://github.com/keith2018/TinyTorch/actions/workflows/cmake_macos.yml)
@@ -10,88 +10,96 @@ For more details, please refer to my blog post: [Write a nn training framework f
## Key Features
-* **PyTorch-Style API**: Similar naming conventions as PyTorch (`Tensor`, `Functions`, `nn.Module`, `Optimizer`).
-* **Pure C++ Implementation**: No dependency on external deep learning libraries.
-* **CPU & CUDA Support**: Runs on both CPU and CUDA-enabled GPUs.
-* **Mixed Precision**: Supports FP16, FP32, BF16.
-* **Distributed**: Multi-machine, multi-GPU training & inference.
-* **LLM Inference**: Supports inference for llama/qwen/mistral models: [https://github.com/keith2018/TinyGPT](https://github.com/keith2018/TinyGPT)
+- **PyTorch-style API** — Familiar naming conventions (`Tensor`, `nn.Module`, `Optimizer`, `DataLoader`).
+- **Pure C++ implementation** — No dependency on external deep learning libraries; only a C++17 compiler is required.
+- **CPU & CUDA** — Runs on both CPU (with BLAS acceleration) and CUDA-enabled GPUs.
+- **Mixed precision** — Supports FP16, FP32 and BF16.
+- **Distributed training** — Multi-machine, multi-GPU training & inference via NCCL.
+- **LLM inference** — Supports inference for LLaMA / Qwen / Mistral models: [TinyGPT](https://github.com/keith2018/TinyGPT).
-## Implemented Operators and Components
+## Architecture
-### Activation Functions
-* `relu`, `gelu`, `silu`
-* `softmax`, `logSoftmax`
+TinyTorch implements automatic differentiation by building a dynamic computation graph. Each operation on a `Tensor` creates a `Function` node that records both the forward computation and the backward gradient rule. These nodes are linked via `nextFunctions`, forming a DAG. Calling `backward()` traverses this graph in reverse topological order, propagating gradients via the chain rule.
-### Mathematical Operations
-* `add`, `sub`, `mul`, `div`, `matmul`
-* `sin`, `cos`, `sqrt`, `pow`
-* `maximum`, `minimum`
+
-### Comparison and Logical Operations
-* `lt`, `le`, `gt`, `ge`, `eq`, `ne`
-* `logicNot`, `logicAnd`, `logicOr`
+## Project Structure
-### Statistical and Reduction Operations
-* `min`, `argmin`, `max`, `argmax`
-* `sum`, `mean`, `var`
+```
+TinyTorch/
+├── src/ # Core library (Tensor, Function, nn.Module, Optimizer, ...)
+├── examples/ # Standalone example programs
+│ ├── autograd/ # Automatic differentiation basics
+│ ├── module/ # Building models with nn.Module
+│ ├── optimizer/ # Using built-in optimizers
+│ ├── mnist/ # Full MNIST training pipeline
+│ ├── nccl/ # NCCL collective communication
+│ └── ddp/ # Distributed data-parallel training
+├── test/ # Unit tests
+└── third_party/ # Third-party dependencies
+```
-### Tensor Shape and Indexing Operations
-* `reshape`, `view`, `permute`, `transpose`
-* `flatten`, `unflatten`, `squeeze`, `unsqueeze`
-* `split`, `concat`, `stack`, `hstack`, `vstack`, `narrow`
-* `topk`, `sort`, `cumsum`
-* `gather`, `scatter`
+## Getting Started
-### Neural Network Layers and Loss Functions
-* `linear`
-* `dropout`
-* `maxPool2d`
-* `conv2d`
-* `embedding`
-* `layerNorm`
-* `rmsNorm`
-* `sdpAttention`
-* `mseLoss`
-* `nllLoss`
+### Prerequisites
-### Optimizers
-* `SGD`, `Adagrad`, `RMSprop`, `AdaDelta`, `Adam`, `AdamW`
+- CMake 3.10+
+- C++17 compatible compiler
+- CUDA Toolkit 11.0+ *(optional, for GPU support)*
+- NCCL *(optional, for distributed training)*
-### Other
-* `Dataset`, `DataLoader`, `data.Transform`
+### Build
-## Automatic differentiation
+```bash
+mkdir build
+cmake -B ./build -DCMAKE_BUILD_TYPE=Release
+cmake --build ./build --config Release
+```
-TinyTorch's automatic differentiation (AD) is implemented by building a computation graph. Each operation on a `Tensor` is represented by a `Function` object, which is responsible for both the forward and backward passes. The `Function` nodes are connected via a `nextFunctions` field, creating the dependency graph. During the `backward()` call, the framework traverses this graph in reverse order, computing and propagating gradients using the chain rule.
+#### CMake Options
-
+| Option | Default | Description |
+|--------|---------|-------------|
+| `TINYTORCH_BUILD_EXAMPLES` | `ON` | Build example programs |
+| `TINYTORCH_BUILD_TEST` | `OFF` | Build unit tests |
+| `TINYTORCH_USE_CUDA` | `ON` | Enable CUDA support |
+| `TINYTORCH_USE_NCCL` | `ON` | Enable NCCL support (forced `OFF` when CUDA is disabled, and on Apple / MSVC builds) |
-## Getting Started
+### Run Examples
-### Prerequisites
-* CMake
-* C++17 or a more recent compiler
-* CUDA Toolkit 11.0+ (optional)
+Each example is an independent executable:
-### Build
```bash
-mkdir build
-cmake -B ./build -DCMAKE_BUILD_TYPE=Release
-cmake --build ./build --config Release
+# Autograd basics
+cd examples/autograd/bin && ./tinytorch_example_autograd
+
+# nn.Module usage
+cd examples/module/bin && ./tinytorch_example_module
+
+# Optimizer usage
+cd examples/optimizer/bin && ./tinytorch_example_optimizer
+
+# MNIST training
+cd examples/mnist/bin && ./tinytorch_example_mnist
```
-### Run `MNIST` Demo
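+The distributed examples require NCCL and multiple GPUs, and take their launch parameters from the command line: the DDP example expects exactly three arguments, the first being the local rank (see `examples/ddp/main.cpp`). The commands below show only the executables: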
+
```bash
-cd demo/bin
-./TinyTorch_demo
+# NCCL all-reduce
+cd examples/nccl/bin && ./tinytorch_example_nccl
+
+# Distributed data-parallel training
+cd examples/ddp/bin && ./tinytorch_example_ddp
```
### Run Tests
+
```bash
cd build
ctest
```
## License
+
This code is licensed under the MIT License (see [LICENSE](LICENSE)).
diff --git a/demo/demo.h b/demo/demo.h
deleted file mode 100644
index 961e4f8..0000000
--- a/demo/demo.h
+++ /dev/null
@@ -1,17 +0,0 @@
-/*
- * TinyTorch
- * @author : keith@robot9.me
- *
- */
-
-#pragma once
-
-void demo_autograd();
-void demo_module();
-void demo_optim();
-void demo_mnist();
-
-#ifdef USE_NCCL
-void demo_nccl(int argc, char **argv);
-void demo_ddp(int argc, char **argv);
-#endif
diff --git a/demo/main.cpp b/demo/main.cpp
deleted file mode 100644
index f60e37c..0000000
--- a/demo/main.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * TinyTorch
- * @author : keith@robot9.me
- *
- */
-
-#include "demo.h"
-
-int main(int argc, char **argv) {
- demo_autograd();
- demo_module();
- demo_optim();
- demo_mnist();
-
-#ifdef USE_NCCL
- demo_nccl(argc, argv);
- demo_ddp(argc, argv);
-#endif
-
- return 0;
-}
\ No newline at end of file
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
new file mode 100644
index 0000000..4928a88
--- /dev/null
+++ b/examples/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_subdirectory(autograd)  # automatic differentiation basics
+add_subdirectory(module)  # building models with nn.Module
+add_subdirectory(optimizer)  # using built-in optimizers
+add_subdirectory(mnist)  # full MNIST training pipeline
+
+if (TINYTORCH_USE_NCCL)  # the distributed examples are only configured when NCCL is enabled
+ add_subdirectory(nccl)  # NCCL collective communication
+ add_subdirectory(ddp)  # distributed data-parallel training (copies its data from ../mnist/data)
+endif ()
diff --git a/examples/autograd/CMakeLists.txt b/examples/autograd/CMakeLists.txt
new file mode 100644
index 0000000..9895726
--- /dev/null
+++ b/examples/autograd/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.10)
+project(tinytorch_example_autograd)  # the project name is reused as the executable target name
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+add_executable(${PROJECT_NAME} main.cpp)
+
+target_include_directories(${PROJECT_NAME} PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party
+)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE TinyTorch_lib)  # explicit scope instead of the legacy keyword-less signature
+
+set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)  # executable is written to ./bin next to this file
diff --git a/demo/demo_autograd.cpp b/examples/autograd/main.cpp
similarity index 95%
rename from demo/demo_autograd.cpp
rename to examples/autograd/main.cpp
index 6b26be6..caec1aa 100644
--- a/demo/demo_autograd.cpp
+++ b/examples/autograd/main.cpp
@@ -11,8 +11,8 @@
using namespace tinytorch;
// https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-tensors-and-autograd
-void demo_autograd() {
- LOGD("demo_autograd ...");
+int main() {
+ LOGD("autograd example ...");
Timer timer;
timer.start();
@@ -53,4 +53,6 @@ void demo_autograd() {
timer.mark();
LOGD("Time cost: %lld ms", timer.elapseMillis());
+
+ return 0;
}
diff --git a/examples/ddp/CMakeLists.txt b/examples/ddp/CMakeLists.txt
new file mode 100644
index 0000000..7eb891b
--- /dev/null
+++ b/examples/ddp/CMakeLists.txt
@@ -0,0 +1,26 @@
+cmake_minimum_required(VERSION 3.10)
+project(tinytorch_example_ddp)  # the project name is reused as the executable target name
+
+if (CMAKE_BUILD_TYPE STREQUAL Debug)
+    add_definitions(-DDEBUG)
+endif ()
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+add_executable(${PROJECT_NAME} main.cpp)
+
+target_include_directories(${PROJECT_NAME} PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party
+)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE TinyTorch_lib)  # explicit scope instead of the legacy keyword-less signature
+
+set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)  # executable is written to ./bin next to this file
+
+# copy assets: refresh <executable dir>/data from the data set shipped with the mnist example
+add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E remove_directory $<TARGET_FILE_DIR:${PROJECT_NAME}>/data
+        COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/../mnist/data $<TARGET_FILE_DIR:${PROJECT_NAME}>/data
+)
diff --git a/demo/demo_ddp.cpp b/examples/ddp/main.cpp
similarity index 98%
rename from demo/demo_ddp.cpp
rename to examples/ddp/main.cpp
index 45973ac..4958997 100644
--- a/demo/demo_ddp.cpp
+++ b/examples/ddp/main.cpp
@@ -142,8 +142,8 @@ static void test(nn::Module &model, Device device, data::DataLoader &dataLoader)
testLoss, correct, total, 100. * correct / (float)total, elapsed);
}
-void demo_ddp(int argc, char **argv) {
- LOGD("demo_ddp ...");
+int main(int argc, char **argv) {
+ LOGD("DDP training example ...");
ASSERT(argc == 4);
int localRank = std::stoi(argv[1]);
@@ -154,7 +154,7 @@ void demo_ddp(int argc, char **argv) {
LOGD("deviceCount: %d", deviceCount);
if (localRank >= deviceCount) {
LOGE("Not enough GPUs available. Required: %d, Available: %d", (localRank + 1), deviceCount);
- return;
+ return 1;
}
auto dpg = distributed::DistributedProcessGroup::getInstance();
@@ -163,7 +163,7 @@ void demo_ddp(int argc, char **argv) {
bool success = dpg->initProcessGroup(distributed::NCCL, initMethod, rank, worldSize);
if (!success) {
LOGE("InitProcessGroup failed");
- return;
+ return 1;
}
cuda::setDevice(localRank);
@@ -181,7 +181,7 @@ void demo_ddp(int argc, char **argv) {
if (trainDataset->size() == 0 || testDataset->size() == 0) {
LOGE("Dataset invalid.");
- return;
+ return 1;
}
auto sampler =
@@ -224,4 +224,6 @@ void demo_ddp(int argc, char **argv) {
timer.mark();
LOGD("Time cost: %lld ms", timer.elapseMillis());
+
+ return 0;
}
diff --git a/demo/CMakeLists.txt b/examples/mnist/CMakeLists.txt
similarity index 57%
rename from demo/CMakeLists.txt
rename to examples/mnist/CMakeLists.txt
index 3926ed3..5dc305f 100644
--- a/demo/CMakeLists.txt
+++ b/examples/mnist/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.10)
-project(TinyTorch_demo)
+project(tinytorch_example_mnist)
if (CMAKE_BUILD_TYPE STREQUAL Debug)
add_definitions(-DDEBUG)
@@ -8,32 +8,16 @@ endif ()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
-set(DEMO_SRCS
- demo_autograd.cpp
- demo_module.cpp
- demo_optim.cpp
- demo_mnist.cpp
- main.cpp
-)
-
-if (TINYTORCH_USE_NCCL)
- list(APPEND DEMO_SRCS
- demo_nccl.cpp
- demo_ddp.cpp
- )
-endif ()
-
-add_executable(${PROJECT_NAME} ${DEMO_SRCS})
+add_executable(${PROJECT_NAME} main.cpp)
target_include_directories(${PROJECT_NAME} PRIVATE
- ${CMAKE_CURRENT_SOURCE_DIR}/../src
- ${CMAKE_CURRENT_SOURCE_DIR}/../third_party
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../src
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party
)
target_link_libraries(${PROJECT_NAME} TinyTorch_lib)
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)
-SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin)
# copy assets
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
diff --git a/demo/data/t10k-images-idx3-ubyte b/examples/mnist/data/t10k-images-idx3-ubyte
similarity index 100%
rename from demo/data/t10k-images-idx3-ubyte
rename to examples/mnist/data/t10k-images-idx3-ubyte
diff --git a/demo/data/t10k-labels-idx1-ubyte b/examples/mnist/data/t10k-labels-idx1-ubyte
similarity index 100%
rename from demo/data/t10k-labels-idx1-ubyte
rename to examples/mnist/data/t10k-labels-idx1-ubyte
diff --git a/demo/data/train-images-idx3-ubyte b/examples/mnist/data/train-images-idx3-ubyte
similarity index 100%
rename from demo/data/train-images-idx3-ubyte
rename to examples/mnist/data/train-images-idx3-ubyte
diff --git a/demo/data/train-labels-idx1-ubyte b/examples/mnist/data/train-labels-idx1-ubyte
similarity index 100%
rename from demo/data/train-labels-idx1-ubyte
rename to examples/mnist/data/train-labels-idx1-ubyte
diff --git a/demo/demo_mnist.cpp b/examples/mnist/main.cpp
similarity index 98%
rename from demo/demo_mnist.cpp
rename to examples/mnist/main.cpp
index fd6f030..b4ea25a 100644
--- a/demo/demo_mnist.cpp
+++ b/examples/mnist/main.cpp
@@ -143,8 +143,8 @@ static void test(nn::Module &model, Device device, data::DataLoader &dataLoader)
testLoss, correct, total, 100. * correct / (float)total, elapsed);
}
-void demo_mnist() {
- LOGD("demo_mnist ...");
+int main() {
+ LOGD("MNIST training example ...");
TrainArgs args;
manualSeed(args.seed);
@@ -161,7 +161,7 @@ void demo_mnist() {
if (trainDataset->size() == 0 || testDataset->size() == 0) {
LOGE("Dataset invalid.");
- return;
+ return 1;
}
auto trainDataloader = data::DataLoader(trainDataset, args.batchSize);
@@ -197,4 +197,6 @@ void demo_mnist() {
timer.mark();
LOGD("Total Time cost: %lld ms", timer.elapseMillis());
+
+ return 0;
}
diff --git a/examples/module/CMakeLists.txt b/examples/module/CMakeLists.txt
new file mode 100644
index 0000000..f008a85
--- /dev/null
+++ b/examples/module/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.10)
+project(tinytorch_example_module)  # the project name is reused as the executable target name
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+add_executable(${PROJECT_NAME} main.cpp)
+
+target_include_directories(${PROJECT_NAME} PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party
+)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE TinyTorch_lib)  # explicit scope instead of the legacy keyword-less signature
+
+set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)  # executable is written to ./bin next to this file
diff --git a/demo/demo_module.cpp b/examples/module/main.cpp
similarity index 96%
rename from demo/demo_module.cpp
rename to examples/module/main.cpp
index 2da0d06..be48042 100644
--- a/demo/demo_module.cpp
+++ b/examples/module/main.cpp
@@ -12,8 +12,8 @@
using namespace tinytorch;
// https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-nn
-void demo_module() {
- LOGD("demo_module ...");
+int main() {
+ LOGD("module example ...");
Timer timer;
timer.start();
@@ -58,4 +58,6 @@ void demo_module() {
timer.mark();
LOGD("Time cost: %lld ms", timer.elapseMillis());
+
+ return 0;
}
diff --git a/examples/nccl/CMakeLists.txt b/examples/nccl/CMakeLists.txt
new file mode 100644
index 0000000..a91a174
--- /dev/null
+++ b/examples/nccl/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.10)
+project(tinytorch_example_nccl)  # the project name is reused as the executable target name
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+add_executable(${PROJECT_NAME} main.cpp)
+
+target_include_directories(${PROJECT_NAME} PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party
+)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE TinyTorch_lib)  # explicit scope instead of the legacy keyword-less signature
+
+set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)  # executable is written to ./bin next to this file
diff --git a/demo/demo_nccl.cpp b/examples/nccl/main.cpp
similarity index 82%
rename from demo/demo_nccl.cpp
rename to examples/nccl/main.cpp
index 7540513..6350323 100644
--- a/demo/demo_nccl.cpp
+++ b/examples/nccl/main.cpp
@@ -13,8 +13,8 @@ using namespace tinytorch;
namespace tinytorch::distributed {
-static void demoAllReduce(int localRank, int rank, int worldSize) {
- LOGD("demoAllReduce: %d, %d, %d", localRank, rank, worldSize);
+static void allReduceExample(int localRank, int rank, int worldSize) {
+ LOGD("allReduceExample: %d, %d, %d", localRank, rank, worldSize);
auto dpg = DistributedProcessGroup::getInstance();
@@ -44,14 +44,14 @@ static void demoAllReduce(int localRank, int rank, int worldSize) {
auto expected = worldSize * (worldSize + 1) / 2;
bool correct = std::abs(result[0] - static_cast(expected)) < 1e-5;
- std::cout << "Rank " << rank << " correct: " << (correct ? "✓" : "✗") << " (expected: " << expected
+ std::cout << "Rank " << rank << " correct: " << (correct ? "Y" : "N") << " (expected: " << expected
<< ", result: " << result[0] << ")" << std::endl;
}
} // namespace tinytorch::distributed
-void demo_nccl(int argc, char** argv) {
- LOGD("demo_nccl ...");
+int main(int argc, char** argv) {
+ LOGD("NCCL example ...");
Timer timer;
timer.start();
@@ -64,11 +64,13 @@ void demo_nccl(int argc, char** argv) {
LOGD("deviceCount: %d", deviceCount);
if (localRank >= deviceCount) {
LOGE("Not enough GPUs available. Required: %d, Available: %d", (localRank + 1), deviceCount);
- return;
+ return 1;
}
- distributed::demoAllReduce(localRank, rank, worldSize);
+ distributed::allReduceExample(localRank, rank, worldSize);
timer.mark();
LOGD("Time cost: %lld ms", timer.elapseMillis());
+
+ return 0;
}
diff --git a/examples/optimizer/CMakeLists.txt b/examples/optimizer/CMakeLists.txt
new file mode 100644
index 0000000..fdae43f
--- /dev/null
+++ b/examples/optimizer/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.10)
+project(tinytorch_example_optimizer)  # the project name is reused as the executable target name
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+add_executable(${PROJECT_NAME} main.cpp)
+
+target_include_directories(${PROJECT_NAME} PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party
+)
+
+target_link_libraries(${PROJECT_NAME} PRIVATE TinyTorch_lib)  # explicit scope instead of the legacy keyword-less signature
+
+set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)  # executable is written to ./bin next to this file
diff --git a/demo/demo_optim.cpp b/examples/optimizer/main.cpp
similarity index 95%
rename from demo/demo_optim.cpp
rename to examples/optimizer/main.cpp
index 4a7d40e..2bf7925 100644
--- a/demo/demo_optim.cpp
+++ b/examples/optimizer/main.cpp
@@ -12,8 +12,8 @@
using namespace tinytorch;
// https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-optim
-void demo_optim() {
- LOGD("demo_optim ...");
+int main() {
+ LOGD("optimizer example ...");
Timer timer;
timer.start();
@@ -54,4 +54,6 @@ void demo_optim() {
timer.mark();
LOGD("Time cost: %lld ms", timer.elapseMillis());
+
+ return 0;
}
diff --git a/src/Function/FuncLinalg.h b/src/Function/FuncLinalg.h
index 4324922..a048355 100644
--- a/src/Function/FuncLinalg.h
+++ b/src/Function/FuncLinalg.h
@@ -14,7 +14,7 @@ namespace tinytorch::function {
class FuncMatmul : public Function {
public:
static Tensor forward(AutogradContext* ctx, const Tensor& a, const Tensor& b, bool transA, bool transB) {
- return op::matmul(a, b, transA, transB);
+ return op::matmul(a, b, transA, transB, Tensor{});
}
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
diff --git a/src/Function/FuncNNLayer.h b/src/Function/FuncNNLayer.h
index 6190960..f9ac33d 100644
--- a/src/Function/FuncNNLayer.h
+++ b/src/Function/FuncNNLayer.h
@@ -14,12 +14,7 @@ namespace tinytorch::function {
class FuncLinear : public Function {
public:
static Tensor forward(AutogradContext* ctx, const Tensor& input, const Tensor& weight, const Tensor& bias) {
- auto output = op::matmul(input, weight, false, true);
- if (bias.defined()) {
- op::addInplace(output, bias, 1);
- }
- // TODO fuse
- return output;
+ return op::matmul(input, weight, false, true, bias);
}
static void backward(AutogradContext* ctx, const Tensor& grad) {
@@ -28,10 +23,10 @@ class FuncLinear : public Function {
auto& bias = ctx->savedInputs[2];
if (input.requiresGrad()) {
- input.addGrad(op::matmul(grad, weight, false, false));
+ input.addGrad(op::matmul(grad, weight, false, false, Tensor{}));
}
if (weight.requiresGrad()) {
- weight.addGrad(op::matmul(grad, input, true, false));
+ weight.addGrad(op::matmul(grad, input, true, false, Tensor{}));
}
if (bias.defined() && bias.requiresGrad()) {
bias.addGrad(op::sumOnDim(grad, 0, false));
@@ -161,12 +156,7 @@ class FuncConv2D : public Function {
auto col = op::im2col(input, kernel, stride, padding);
auto colW = op::reshape(weight, IntArrayView{outChannels, -1});
- auto ret = op::matmul(col, colW, false, true);
- if (bias.defined()) {
- ASSERT(bias.dim() == 1);
- ASSERT(bias.shape()[0] == outChannels);
- op::addInplace(ret, bias, 1);
- }
+ auto ret = op::matmul(col, colW, false, true, bias);
ret.reshape_({batch, outChannels, outH, outW});
if (ctx) {
@@ -192,13 +182,13 @@ class FuncConv2D : public Function {
auto colW = op::reshape(weight, IntArrayView{outChannels, -1});
if (input.requiresGrad()) {
- auto gradCol = op::matmul(gradW, colW, false, false);
+ auto gradCol = op::matmul(gradW, colW, false, false, Tensor{});
auto inputGrad = op::col2im(gradCol, input.shape(), kernel, stride, padding);
input.addGrad(std::move(inputGrad));
}
if (weight.requiresGrad()) {
auto col = ctx->popData().toTensor();
- auto gradColW = op::matmul(col, gradW, true, false);
+ auto gradColW = op::matmul(col, gradW, true, false, Tensor{});
auto weightGrad = op::reshape(gradColW.permute(), weight.shape());
weight.addGrad(std::move(weightGrad));
}
@@ -244,7 +234,7 @@ class FuncSDPAttention : public Function {
auto S = key.size(-2);
float scaleFactor = scale.has_value() ? scale.value() : (1.f / std::sqrt(static_cast(query.size(-1))));
- auto attnWeight = op::matmul(query, op::transpose(key, -2, -1), false, false);
+ auto attnWeight = op::matmul(query, op::transpose(key, -2, -1), false, false, Tensor{});
op::mulInplace(attnWeight, Tensor::scalar(scaleFactor, attnWeight.options()));
Tensor attnBias;
@@ -272,7 +262,7 @@ class FuncSDPAttention : public Function {
if (dropoutP > 0.f) {
attnWeight = op::dropout(attnWeight, dropoutP);
}
- return op::matmul(attnWeight, value, false, false);
+ return op::matmul(attnWeight, value, false, false, Tensor{});
}
static void backward(AutogradContext* ctx, const Tensor& grad) { NOT_IMPLEMENTED(); }
};
diff --git a/src/Operation/OpLinalg.cpp b/src/Operation/OpLinalg.cpp
index 6d956dc..90e5b61 100644
--- a/src/Operation/OpLinalg.cpp
+++ b/src/Operation/OpLinalg.cpp
@@ -6,6 +6,8 @@
#include "OpLinalg.h"
+#include "OpElemWise.h"
+
namespace tinytorch::op {
inline SizeVector makePaddedStrides(const IntArrayView &strides, int64_t targetDim) {
@@ -216,10 +218,11 @@ Tensor matmulOpImplDetail(const Tensor &a, const Tensor &b, bool transA = false,
}
}
- gemm(retPtr + batch * m * n, selfPtr + aOffset, otherPtr + bOffset, m, k, n, transA, transB, a.device().index);
+ gemm(retPtr + batch * m * n, selfPtr + aOffset, otherPtr + bOffset, m, k, n, transA, transB, a.device().index,
+ nullptr);
}
} else {
- gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index);
+ gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index, nullptr);
if (prependA) {
retTensor.reshape_({n});
}
@@ -238,8 +241,8 @@ Tensor matmulOpImplDetail(const Tensor &a, const Tensor &b, bool transA = false,
}
template
-Tensor matmulOpImpl(const Tensor &a, const Tensor &b, bool transA, bool transB) {
- // fast path
+Tensor matmulOpImpl(const Tensor &a, const Tensor &b, bool transA, bool transB, const Tensor &bias) {
+ // 2D fast path
if (a.dim() == 2 && b.dim() == 2) {
// a[m, k], b[k, n] -> [m, n]
int64_t m = a.shape(transA ? 1 : 0);
@@ -255,14 +258,22 @@ Tensor matmulOpImpl(const Tensor &a, const Tensor &b, bool transA, bool transB)
const T *otherPtr = b.dataPtr();
T *retPtr = retTensor.dataPtr();
+ if (bias.defined()) {
+ ASSERT(bias.dim() == 1 && bias.shape(0) == n);
+ }
auto gemm = getGemmFunc(a.device().type);
ASSERT(gemm != nullptr);
- gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index);
+ gemm(retPtr, selfPtr, otherPtr, m, k, n, transA, transB, a.device().index,
+ bias.defined() ? bias.dataPtr() : nullptr);
return retTensor;
}
- // slow path
- return matmulOpImplDetail(a, b, transA, transB);
+ // batched / broadcast path
+ auto result = matmulOpImplDetail(a, b, transA, transB);
+ if (bias.defined()) {
+ addInplace(result, bias, 1);
+ }
+ return result;
}
void registerLinalgCommon() {
diff --git a/src/Operation/OpLinalg.h b/src/Operation/OpLinalg.h
index bb58132..d0ba255 100644
--- a/src/Operation/OpLinalg.h
+++ b/src/Operation/OpLinalg.h
@@ -13,7 +13,7 @@ namespace tinytorch::op {
SizeVector broadcastShape(IntArrayView t0, IntArrayView t1, int64_t skipLast);
template
-void gemmImpl(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex);
+void gemmImpl(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex, const T* = nullptr);
template
void gemmStridedBatchedImpl(T*, const T*, const T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool,
@@ -23,7 +23,7 @@ template
void gemmBatchedImpl(T**, const T**, const T**, int64_t, int64_t, int64_t, int64_t, bool, bool, DeviceIndex);
template
-using GemmFunc = void (*)(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex);
+using GemmFunc = void (*)(T*, const T*, const T*, int64_t, int64_t, int64_t, bool, bool, DeviceIndex, const T*);
template
using GemmStridedBatchedFunc = void (*)(T*, const T*, const T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
@@ -74,7 +74,7 @@ GemmBatchedFunc getGemmBatchedFunc(DeviceType deviceType) {
using DotOpFn = Tensor (*)(const Tensor& self, const Tensor& other);
using Im2ColOpFn = Tensor (*)(const Tensor& self, Dim2D kernel, Dim2D stride, Dim2D padding);
using Col2ImOpFn = Tensor (*)(const Tensor& self, IntArrayView shape, Dim2D kernel, Dim2D stride, Dim2D padding);
-using MatmulOpFn = Tensor (*)(const Tensor& a, const Tensor& b, bool transA, bool transB);
+using MatmulOpFn = Tensor (*)(const Tensor& a, const Tensor& b, bool transA, bool transB, const Tensor& bias);
// dot
DEFINE_OP(dot, DotOpFn)
diff --git a/src/Operation/OpLinalgCpu.h b/src/Operation/OpLinalgCpu.h
index 5a83028..371378b 100644
--- a/src/Operation/OpLinalgCpu.h
+++ b/src/Operation/OpLinalgCpu.h
@@ -116,19 +116,29 @@ Tensor col2imOpCpuImpl(const Tensor& self, const IntArrayView shape, Dim2D kerne
}
template
-void gemmCpuImpl(T* c, const T* a, const T* b, int64_t m, int64_t k, int64_t n, bool transA, bool transB) {
+void gemmCpuImpl(T* c, const T* a, const T* b, int64_t m, int64_t k, int64_t n, bool transA, bool transB,
+ const T* bias = nullptr) {
+ if (bias) {
+ // broadcast bias into C: c[i][j] = bias[j]
+ for (int64_t i = 0; i < m; i++) {
+ std::memcpy(c + i * n, bias, n * sizeof(T));
+ }
+ }
// blas
#if defined(__APPLE__) || defined(__BLAS__)
if constexpr (std::is_same_v) {
CBLAS_TRANSPOSE ta = transA ? CblasTrans : CblasNoTrans;
CBLAS_TRANSPOSE tb = transB ? CblasTrans : CblasNoTrans;
+ float betaVal = bias ? 1.0f : 0.0f;
cblas_sgemm(CblasRowMajor, ta, tb, (int)m, (int)n, (int)k, 1.0f, a, transA ? (int)m : (int)k, b,
- transB ? (int)k : (int)n, 0.0f, c, (int)n);
+ transB ? (int)k : (int)n, betaVal, c, (int)n);
return;
}
#endif
// basic
- std::memset(c, 0, m * n * sizeof(T));
+ if (!bias) {
+ std::memset(c, 0, m * n * sizeof(T));
+ }
for (int64_t i = 0; i < m; i++) {
for (int64_t p = 0; p < k; p++) {
T aVal = transA ? a[p * m + i] : a[i * k + p];
@@ -142,20 +152,21 @@ void gemmCpuImpl(T* c, const T* a, const T* b, int64_t m, int64_t k, int64_t n,
template <>
void gemmImpl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n,
- bool transA, bool transB, DeviceIndex device) {
- gemmCpuImpl(c, a, b, m, k, n, transA, transB);
+ bool transA, bool transB, DeviceIndex device, const float* bias) {
+ gemmCpuImpl(c, a, b, m, k, n, transA, transB, bias);
}
template <>
void gemmImpl(Half* c, const Half* a, const Half* b, int64_t m, int64_t k, int64_t n,
- bool transA, bool transB, DeviceIndex device) {
- gemmCpuImpl(c, a, b, m, k, n, transA, transB);
+ bool transA, bool transB, DeviceIndex device, const Half* bias) {
+ gemmCpuImpl(c, a, b, m, k, n, transA, transB, bias);
}
template <>
void gemmImpl(BFloat16* c, const BFloat16* a, const BFloat16* b, int64_t m, int64_t k,
- int64_t n, bool transA, bool transB, DeviceIndex device) {
- gemmCpuImpl(c, a, b, m, k, n, transA, transB);
+ int64_t n, bool transA, bool transB, DeviceIndex device,
+ const BFloat16* bias) {
+ gemmCpuImpl(c, a, b, m, k, n, transA, transB, bias);
}
} // namespace tinytorch::op
diff --git a/src/Operation/OpLinalgCuda.cuh b/src/Operation/OpLinalgCuda.cuh
index 117e53e..41b14dc 100644
--- a/src/Operation/OpLinalgCuda.cuh
+++ b/src/Operation/OpLinalgCuda.cuh
@@ -137,8 +137,23 @@ Tensor col2imOpCudaImpl(const Tensor& self, const IntArrayView shape, Dim2D kern
template
Tensor dotOpCudaImpl(const Tensor& self, const Tensor& other);
+template <typename T>
+__global__ void kBroadcastBias(T* c, const T* bias, int64_t m, int64_t n) {  // row-major m x n: c[i][j] = bias[j]
+  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;  // widen first: no uint32 wrap
+  if (idx < m * n) {
+    c[idx] = bias[idx % n];
+  }
+}
+
inline void gemmCudaF32Impl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n, bool transA,
- bool transB, DeviceIndex device) {
+ bool transB, DeviceIndex device, const float* bias = nullptr) {
+ float beta = 0.f;
+ if (bias) {
+ auto params = cuda::getKernelLaunchParams(device, m * n);
+ CUDA_LAUNCH_KERNEL(kBroadcastBias, params, c, bias, m, n);
+ beta = 1.f;
+ }
+
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -147,14 +162,20 @@ inline void gemmCudaF32Impl(float* c, const float* a, const float* b, int64_t m,
int ldc = static_cast(n);
constexpr float alpha = 1.f;
- constexpr float beta = 0.f;
auto handle = cuda::getCublasHandle(device);
CUBLAS_CHECK(cublasSgemm(handle, opB, opA, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
}
inline void gemmCudaF16Impl(__half* c, const __half* a, const __half* b, int64_t m, int64_t k, int64_t n, bool transA,
- bool transB, DeviceIndex device) {
+ bool transB, DeviceIndex device, const __half* bias = nullptr) {
+ float beta = 0.f;
+ if (bias) {
+ auto params = cuda::getKernelLaunchParams(device, m * n);
+ CUDA_LAUNCH_KERNEL(kBroadcastBias<__half>, params, c, bias, m, n);
+ beta = 1.f;
+ }
+
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -163,7 +184,6 @@ inline void gemmCudaF16Impl(__half* c, const __half* a, const __half* b, int64_t
int ldc = static_cast(n);
constexpr float alpha = 1.f;
- constexpr float beta = 0.f;
auto handle = cuda::getCublasHandle(device);
CUBLAS_CHECK(cublasGemmEx(handle, opB, opA, n, m, k, &alpha, b, CUDA_R_16F, ldb, a, CUDA_R_16F, lda, &beta, c,
@@ -171,7 +191,15 @@ inline void gemmCudaF16Impl(__half* c, const __half* a, const __half* b, int64_t
}
inline void gemmCudaBF16Impl(__nv_bfloat16* c, const __nv_bfloat16* a, const __nv_bfloat16* b, int64_t m, int64_t k,
- int64_t n, bool transA, bool transB, DeviceIndex device) {
+ int64_t n, bool transA, bool transB, DeviceIndex device,
+ const __nv_bfloat16* bias = nullptr) {
+ float beta = 0.f;
+ if (bias) {
+ auto params = cuda::getKernelLaunchParams(device, m * n);
+ CUDA_LAUNCH_KERNEL(kBroadcastBias<__nv_bfloat16>, params, c, bias, m, n);
+ beta = 1.f;
+ }
+
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -180,7 +208,6 @@ inline void gemmCudaBF16Impl(__nv_bfloat16* c, const __nv_bfloat16* a, const __n
int ldc = static_cast(n);
constexpr float alpha = 1.f;
- constexpr float beta = 0.f;
auto handle = cuda::getCublasHandle(device);
CUBLAS_CHECK(cublasGemmEx(handle, opB, opA, n, m, k, &alpha, b, CUDA_R_16BF, ldb, a, CUDA_R_16BF, lda, &beta, c,
@@ -189,22 +216,24 @@ inline void gemmCudaBF16Impl(__nv_bfloat16* c, const __nv_bfloat16* a, const __n
template <>
void gemmImpl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n,
- bool transA, bool transB, DeviceIndex device) {
- gemmCudaF32Impl(c, a, b, m, k, n, transA, transB, device);
+ bool transA, bool transB, DeviceIndex device, const float* bias) {
+ gemmCudaF32Impl(c, a, b, m, k, n, transA, transB, device, bias);
}
template <>
void gemmImpl(Half* c, const Half* a, const Half* b, int64_t m, int64_t k, int64_t n,
- bool transA, bool transB, DeviceIndex device) {
+ bool transA, bool transB, DeviceIndex device, const Half* bias) {
gemmCudaF16Impl(reinterpret_cast<__half*>(c), reinterpret_cast(a), reinterpret_cast(b),
- m, k, n, transA, transB, device);
+ m, k, n, transA, transB, device, reinterpret_cast(bias));
}
template <>
void gemmImpl(BFloat16* c, const BFloat16* a, const BFloat16* b, int64_t m, int64_t k,
- int64_t n, bool transA, bool transB, DeviceIndex device) {
+ int64_t n, bool transA, bool transB, DeviceIndex device,
+ const BFloat16* bias) {
gemmCudaBF16Impl(reinterpret_cast<__nv_bfloat16*>(c), reinterpret_cast(a),
- reinterpret_cast(b), m, k, n, transA, transB, device);
+ reinterpret_cast(b), m, k, n, transA, transB, device,
+ reinterpret_cast(bias));
}
inline void gemmStridedBatchedCudaF32Impl(float* c, const float* a, const float* b, int64_t m, int64_t k, int64_t n,
diff --git a/src/Tensor/Storage.cpp b/src/Tensor/Storage.cpp
index 53f2d4f..2bffe03 100644
--- a/src/Tensor/Storage.cpp
+++ b/src/Tensor/Storage.cpp
@@ -31,8 +31,8 @@ std::shared_ptr Storage::clone() const {
return newStorage;
}
-void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice,
- int64_t nbytes) {
+void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes,
+ const void* stream) {
if (nbytes == 0) {
return;
}
@@ -47,25 +47,31 @@ void Storage::copyOnDevice(void* dst, const Device& dstDevice, const void* src,
// CUDA -> CUDA
if (dstDevice.isCuda() && srcDevice.isCuda()) {
cuda::CudaDeviceGuard guard(dstDevice.index);
- auto& stream = cuda::getCurrentCUDAStream(dstDevice.index);
- CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, stream.stream()));
+ const auto& s =
+ stream ? *static_cast(stream) : cuda::getCurrentCUDAStream(dstDevice.index);
+ CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, s.stream()));
return;
}
// CPU -> CUDA
if (dstDevice.isCuda() && srcDevice.isCpu()) {
cuda::CudaDeviceGuard guard(dstDevice.index);
- auto& stream = cuda::getCurrentCUDAStream(dstDevice.index);
- CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, stream.stream()));
+ const auto& s =
+ stream ? *static_cast(stream) : cuda::getCurrentCUDAStream(dstDevice.index);
+ CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, s.stream()));
return;
}
// CUDA -> CPU
if (dstDevice.isCpu() && srcDevice.isCuda()) {
cuda::CudaDeviceGuard guard(srcDevice.index);
- auto& stream = cuda::getCurrentCUDAStream(srcDevice.index);
- CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToHost, stream.stream()));
- stream.synchronize();
+ const auto& s =
+ stream ? *static_cast(stream) : cuda::getCurrentCUDAStream(srcDevice.index);
+ CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToHost, s.stream()));
+ // synchronize only when no stream was passed in; a caller-supplied stream is left asynchronous for the caller to sync
+ if (!stream) {
+ s.synchronize();
+ }
return;
}
#endif
diff --git a/src/Tensor/Storage.h b/src/Tensor/Storage.h
index 94b85c9..9b71c73 100644
--- a/src/Tensor/Storage.h
+++ b/src/Tensor/Storage.h
@@ -35,8 +35,8 @@ class Storage {
int64_t size() const { return nbytes_; }
Device device() const { return device_; }
- static void copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice,
- int64_t nbytes);
+ static void copyOnDevice(void* dst, const Device& dstDevice, const void* src, const Device& srcDevice, int64_t nbytes,
+ const void* stream = nullptr);
static void copyOnDevice(void* dst, const void* src, int64_t nbytes, const Device& device);
private:
diff --git a/test/test_operation.cpp b/test/test_operation.cpp
index cc47f10..6cde665 100644
--- a/test/test_operation.cpp
+++ b/test/test_operation.cpp
@@ -1223,28 +1223,28 @@ TEST(TEST_Operation, basic_im2col_col2im) {
TEST(TEST_Operation, math_matmul_01) {
Array2d d1 = {{1, 2}, {3, 4}};
Array2d d2 = {{2, 3}, {4, 5}};
- auto y = op::matmul(Tensor(d1), Tensor(d2), false, false);
+ auto y = op::matmul(Tensor(d1), Tensor(d2), false, false, Tensor{});
EXPECT_THAT(y.shape(), ElementsAre(2, 2));
EXPECT_THAT(y.toList(), ElementsAre(10, 13, 22, 29));
Array2d d3 = {{1, 2, 3}, {4, 5, 6}};
Array2d d4 = {{2, 3}, {4, 5}, {6, 7}};
- y = op::matmul(Tensor(d3), Tensor(d4), false, false);
+ y = op::matmul(Tensor(d3), Tensor(d4), false, false, Tensor{});
EXPECT_THAT(y.shape(), ElementsAre(2, 2));
EXPECT_THAT(y.toList(), ElementsAre(28, 34, 64, 79));
Array2d d5 = {{1, 0}, {0, 1}};
Array1d d6 = {1, 2};
- y = op::matmul(Tensor(d5), Tensor(d6), false, false);
+ y = op::matmul(Tensor(d5), Tensor(d6), false, false, Tensor{});
EXPECT_THAT(y.shape(), ElementsAre(2));
EXPECT_THAT(y.toList(), ElementsAre(1, 2));
- y = op::matmul(Tensor(d6), Tensor(d5), false, false);
+ y = op::matmul(Tensor(d6), Tensor(d5), false, false, Tensor{});
EXPECT_THAT(y.shape(), ElementsAre(2));
EXPECT_THAT(y.toList(), ElementsAre(1, 2));
Array1d d7 = {2};
- y = op::matmul(Tensor(d7), Tensor(d7), false, false);
+ y = op::matmul(Tensor(d7), Tensor(d7), false, false, Tensor{});
EXPECT_TRUE(y.dim() == 0);
EXPECT_THAT(y.toList(), ElementsAre(4));
@@ -1255,8 +1255,8 @@ TEST(TEST_Operation, math_matmul_01) {
b.reshape_({1, 2, 4, 2});
auto c = Tensor::arange(0, 1 * 2 * 4);
c.reshape_({1, 4, 2});
- auto d = op::matmul(a, b, false, false);
- auto e = op::matmul(a, c, false, false);
+ auto d = op::matmul(a, b, false, false, Tensor{});
+ auto e = op::matmul(a, c, false, false, Tensor{});
EXPECT_THAT(d.shape(), ElementsAre(1, 2, 2, 2));
EXPECT_THAT(d.toList(), ElementsAre(28, 34, 76, 98, 428, 466, 604, 658));
@@ -1268,13 +1268,13 @@ TEST(TEST_Operation, math_matmul_01) {
TEST(TEST_Operation, math_matmul_02) {
Array2d d1 = {{1, 2}, {3, 4}};
Array2d d2 = {{2, 3}, {4, 5}};
- auto y = op::matmul(Tensor(d1), Tensor(d2), false, true);
+ auto y = op::matmul(Tensor(d1), Tensor(d2), false, true, Tensor{});
EXPECT_THAT(y.shape(), ElementsAre(2, 2));
EXPECT_THAT(y.toList(), ElementsAre(8, 14, 18, 32));
Array2d d3 = {{1, 2, 3}, {4, 5, 6}};
Array2d d4 = {{2, 4, 6}, {3, 5, 7}};
- y = op::matmul(Tensor(d3), Tensor(d4), false, true);
+ y = op::matmul(Tensor(d3), Tensor(d4), false, true, Tensor{});
EXPECT_THAT(y.shape(), ElementsAre(2, 2));
EXPECT_THAT(y.toList(), ElementsAre(28, 34, 64, 79));
}
@@ -1285,7 +1285,7 @@ TEST(TEST_Operation, math_matmul_03) {
a1.reshape_({2, 3, 4});
auto b1 = Tensor::arange(0, 2 * 4 * 2);
b1.reshape_({2, 4, 2});
- auto c1 = op::matmul(a1, b1, false, false);
+ auto c1 = op::matmul(a1, b1, false, false, Tensor{});
EXPECT_THAT(c1.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c1.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 604, 658, 780, 850, 956, 1042));
@@ -1294,7 +1294,7 @@ TEST(TEST_Operation, math_matmul_03) {
a3.reshape_({2, 3, 4});
auto b3 = Tensor::arange(0, 4 * 2);
b3.reshape_({4, 2});
- auto c3 = op::matmul(a3, b3, false, false);
+ auto c3 = op::matmul(a3, b3, false, false, Tensor{});
EXPECT_THAT(c3.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c3.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 172, 226, 220, 290, 268, 354));
@@ -1303,7 +1303,7 @@ TEST(TEST_Operation, math_matmul_03) {
a4.reshape_({3, 4});
auto b4 = Tensor::arange(0, 2 * 4 * 2);
b4.reshape_({2, 4, 2});
- auto c4 = op::matmul(a4, b4, false, false);
+ auto c4 = op::matmul(a4, b4, false, false, Tensor{});
EXPECT_THAT(c4.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c4.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 76, 82, 252, 274, 428, 466));
@@ -1312,7 +1312,7 @@ TEST(TEST_Operation, math_matmul_03) {
a5.reshape_({1, 3, 4});
auto b5 = Tensor::arange(0, 2 * 4 * 2);
b5.reshape_({2, 4, 2});
- auto c5 = op::matmul(a5, b5, false, false);
+ auto c5 = op::matmul(a5, b5, false, false, Tensor{});
EXPECT_THAT(c5.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c5.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 76, 82, 252, 274, 428, 466));
@@ -1321,7 +1321,7 @@ TEST(TEST_Operation, math_matmul_03) {
a6.reshape_({2, 3, 4});
auto b6 = Tensor::arange(0, 1 * 4 * 2);
b6.reshape_({1, 4, 2});
- auto c6 = op::matmul(a6, b6, false, false);
+ auto c6 = op::matmul(a6, b6, false, false, Tensor{});
EXPECT_THAT(c6.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c6.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 172, 226, 220, 290, 268, 354));
@@ -1330,7 +1330,7 @@ TEST(TEST_Operation, math_matmul_03) {
a8.reshape_({2, 3, 4});
auto b8 = Tensor::arange(0, 2 * 2 * 4);
b8.reshape_({2, 2, 4});
- auto c8 = op::matmul(a8, b8, false, true);
+ auto c8 = op::matmul(a8, b8, false, true, Tensor{});
EXPECT_THAT(c8.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c8.toList(), ElementsAre(14, 38, 38, 126, 62, 214, 518, 734, 670, 950, 822, 1166));
@@ -1339,7 +1339,7 @@ TEST(TEST_Operation, math_matmul_03) {
a9.reshape_({2, 4, 3});
auto b9 = Tensor::arange(0, 2 * 4 * 2);
b9.reshape_({2, 4, 2});
- auto c9 = op::matmul(a9, b9, true, false);
+ auto c9 = op::matmul(a9, b9, true, false, Tensor{});
EXPECT_THAT(c9.shape(), ElementsAre(2, 3, 2));
EXPECT_THAT(c9.toList(), ElementsAre(84, 102, 96, 118, 108, 134, 756, 822, 800, 870, 844, 918));
@@ -1348,7 +1348,7 @@ TEST(TEST_Operation, math_matmul_03) {
a10.reshape_({1, 3, 4});
auto b10 = Tensor::arange(0, 1 * 4 * 2);
b10.reshape_({1, 4, 2});
- auto c10 = op::matmul(a10, b10, false, false);
+ auto c10 = op::matmul(a10, b10, false, false, Tensor{});
EXPECT_THAT(c10.shape(), ElementsAre(1, 3, 2));
EXPECT_THAT(c10.toList(), ElementsAre(28, 34, 76, 98, 124, 162));
@@ -1357,7 +1357,7 @@ TEST(TEST_Operation, math_matmul_03) {
a12.reshape_({2, 1, 3, 4});
auto b12 = Tensor::arange(0, 2 * 1 * 4 * 2);
b12.reshape_({2, 1, 4, 2});
- auto c12 = op::matmul(a12, b12, false, false);
+ auto c12 = op::matmul(a12, b12, false, false, Tensor{});
EXPECT_THAT(c12.shape(), ElementsAre(2, 1, 3, 2));
EXPECT_THAT(c12.toList(), ElementsAre(28, 34, 76, 98, 124, 162, 604, 658, 780, 850, 956, 1042));
}
@@ -1368,7 +1368,7 @@ TEST(TEST_Operation, math_matmul_04) {
a1.reshape_({1, 2, 3});
auto b1 = Tensor::arange(0, 6 * 3 * 2);
b1.reshape_({6, 3, 2});
- auto c1 = op::matmul(a1, b1, false, false);
+ auto c1 = op::matmul(a1, b1, false, false, Tensor{});
EXPECT_THAT(c1.shape(), ElementsAre(6, 2, 2));
EXPECT_THAT(c1.toList(), ElementsAre(10, 13, 28, 40, 28, 31, 100, 112, 46, 49, 172, 184, 64, 67, 244, 256, 82,
85, 316, 328, 100, 103, 388, 400));
@@ -1378,7 +1378,7 @@ TEST(TEST_Operation, math_matmul_04) {
a2.reshape_({8, 2, 3});
auto b2 = Tensor::arange(0, 1 * 3 * 2);
b2.reshape_({1, 3, 2});
- auto c2 = op::matmul(a2, b2, false, false);
+ auto c2 = op::matmul(a2, b2, false, false, Tensor{});
EXPECT_THAT(c2.shape(), ElementsAre(8, 2, 2));
EXPECT_THAT(c2.toList(),
ElementsAre(10, 13, 28, 40, 46, 67, 64, 94, 82, 121, 100, 148, 118, 175, 136, 202, 154, 229, 172, 256,
@@ -1389,7 +1389,7 @@ TEST(TEST_Operation, math_matmul_04) {
a3.reshape_({2, 1, 2, 3});
auto b3 = Tensor::arange(0, 2 * 3 * 3 * 2);
b3.reshape_({2, 3, 3, 2});
- auto c3 = op::matmul(a3, b3, false, false);
+ auto c3 = op::matmul(a3, b3, false, false, Tensor{});
EXPECT_THAT(c3.shape(), ElementsAre(2, 3, 2, 2));
EXPECT_THAT(c3.toList(), ElementsAre(10, 13, 28, 40, 28, 31, 100, 112, 46, 49, 172, 184, 424, 445, 604, 634,
550, 571, 784, 814, 676, 697, 964, 994));
diff --git a/third_party/TinyFA b/third_party/TinyFA
index ffc2647..4e18516 160000
--- a/third_party/TinyFA
+++ b/third_party/TinyFA
@@ -1 +1 @@
-Subproject commit ffc264708d49e63f167b067b6d42339340469ca1
+Subproject commit 4e18516165acb029076b2ecc7d733c9ebf4d552a