From f00dee6d571fe0419f62ac76103d51112f214963 Mon Sep 17 00:00:00 2001 From: Adam Lam Date: Wed, 13 May 2026 14:43:50 -0400 Subject: [PATCH] initial formatting --- .../supplemental/pytorch-kernels/README.md | 465 +++++++++--------- .../run_compiled_multiply.py | 14 + .../Vector_Addition/run_compiled_addition.py | 11 + 3 files changed, 261 insertions(+), 229 deletions(-) create mode 100644 playbooks/supplemental/pytorch-kernels/assets/Matrix_Multiplication/run_compiled_multiply.py create mode 100644 playbooks/supplemental/pytorch-kernels/assets/Vector_Addition/run_compiled_addition.py diff --git a/playbooks/supplemental/pytorch-kernels/README.md b/playbooks/supplemental/pytorch-kernels/README.md index 7b7a3af4..ad716d4b 100644 --- a/playbooks/supplemental/pytorch-kernels/README.md +++ b/playbooks/supplemental/pytorch-kernels/README.md @@ -13,7 +13,7 @@ SPDX-License-Identifier: MIT ## Overview -Write a GPU kernel from scratch, compile it, and launch it on an AMD GPU, then watch utilization spike. This playbook shows how GPU computation actually works: you write the kernel code, and it executes in parallel across thousands of threads. +Write a GPU kernel from scratch, compile it, launch it on an AMD GPU, and watch utilization spike. This playbook shows how GPU computation actually works: write the kernel code, and execute it in parallel across thousands of threads. ## What You'll Learn @@ -28,7 +28,7 @@ Write a GPU kernel from scratch, compile it, and launch it on an AMD GPU, then w - How AMD's ROCm/HIP stack lets you write CUDA-style code that runs on AMD GPUs without modification - How to compile a kernel at runtime using `torch.cuda._compile_kernel` - How to build a native C++ kernel extension with `CUDAExtension` + pybind11, importable from Python -- How to measure kernel execution time and monitor live GPU utilization with `rocm-smi` +- How to measure kernel execution time and monitor live GPU utilization with `amd-smi` --- @@ -38,17 +38,17 @@ This playbook covers two approaches for kernel development: | Approach | Entry point | |---|---| -| **JIT Compilation** | `torch.cuda._compile_kernel`, write a kernel as a Python string, no build step | -| **C++ Extension** | `CUDAExtension` + pybind11, compile a `.cu` file into a native `.pyd` and import it | +| **JIT Compilation** | `torch.cuda._compile_kernel`, write a kernel as a Python string, with no build step | +| **C++ Extension** | `CUDAExtension` + pybind11: compile a `.cu` file into a native `.pyd` and import it | | Approach | Entry point | |---|---| -| **JIT Compilation** | `torch.cuda._compile_kernel`, write a kernel as a Python string, no build step | -| **C++ Extension** | `CUDAExtension` + pybind11, compile a `.cu` file into a native `.so` and import it | +| **JIT Compilation** | `torch.cuda._compile_kernel`, write a kernel as a Python string, with no build step | +| **C++ Extension** | `CUDAExtension` + pybind11: compile a `.cu` file into a native `.so` and import it | -Both approaches run on AMD GPUs. This is possible because PyTorch's ROCm build maps the entire CUDA API surface to HIP, `torch.cuda`, `CUDAExtension`, and CUDA kernel syntax all work on AMD hardware transparently. You write CUDA-style code; ROCm handles the translation. +Both approaches run on AMD GPUs. This is possible because PyTorch's ROCm build maps the entire CUDA API surface to HIP. This means `torch.cuda`, `CUDAExtension`, and CUDA kernel syntax all work on AMD hardware transparently. --- @@ -62,14 +62,6 @@ A GPU kernel is a function that runs in parallel across thousands of GPU threads

-### GPU Execution Model: Wavefronts - -GPU threads are scheduled in groups rather than completely independently; threads in a group execute the same instructions simultaneously. On AMD GPUs, these groups are called **wavefronts**. - -A wavefront is the smallest group of threads that the GPU scheduler executes simultaneously. All threads in a wavefront execute the same instruction at the same time. A wavefront on Radeon GPUs consists of 32 threads. - -This will become relevant later when discussing block size choices. - ### Thread Indexing Model When launching a kernel you specify two dimensions: @@ -95,90 +87,110 @@ These variables are combined to compute a globally unique thread index: int idx = blockIdx.x * blockDim.x + threadIdx.x; ``` -Total threads = `gridDim.x * blockDim.x`. Each thread processes one element independently, this is **data parallelism**. The same operation runs on many elements at once with no inter-thread dependency. +Total threads = `gridDim.x * blockDim.x`. Each thread processes one element independently. This is the foundation of **data parallelism**. The same operation runs on many elements at once, with no inter-thread dependency. --- -### AMD GPU Programming: HIP +### GPU Execution Model: Wavefronts -AMD GPUs use **HIP** (Heterogeneous-Compute Interface for Portability), part of the **ROCm** (Radeon Open Compute) platform. **ROCm** is the full AMD open-source GPU compute stack: drivers, compilers, libraries, and runtime. HIP sits on top of ROCm. +AMD GPUs execute threads in groups of **32** called **wavefronts**. All threads in a wavefront run the same instruction simultaneously. This affects optimal block size choices (256 threads = 8 wavefronts = good scheduling efficiency). -HIP is designed to be syntactically close to CUDA. Most CUDA code can be translated to HIP mechanically using the `hipify` tool (which is what generated the `.hip` files in this repo). +### AMD GPU Programming: HIP + ROCm ---- +**ROCm** is AMD's open-source GPU compute stack (drivers, compilers, libraries, runtime). **HIP** sits on top, designed to be syntactically identical to CUDA. PyTorch's ROCm build transparently maps `torch.cuda.*` to HIP, so the same code works on AMD GPUs. -### PyTorch + AMD/HIP +--- -PyTorch ships a ROCm build where the CUDA API surface (`torch.cuda.*`) is transparently backed by HIP. This means: +## Setup -- `torch.cuda.is_available()` works on AMD GPUs with ROCm -- `tensor.to("cuda")` allocates on the AMD GPU -- `torch.version.hip` exposes the HIP version +### Create a Virtual Environment -PyTorch also exposes `torch.cuda._compile_kernel()`, a high-level shortcut to JIT-compile a raw kernel string and get back a callable, without needing a separate build step. + + +On Windows, open a terminal in the directory of your choice and follow the commands to create a venv with ROCm+Pytorch already installed. + +```bash +python -m venv kernel-env --system-site-packages +kernel-env\Scripts\activate +``` + ---- +> **Tip**: Windows users may need to modify their PowerShell Execution Policy (e.g. +> setting it to RemoteSigned or Unrestricted) before running some Powershell commands. -## Setup - -### Prerequisites - Windows -- Install latest: [AMD Adrenalin Software](https://www.amd.com/en/products/software/adrenalin.html) + -### Create a Virtual Environment +On Linux, open a terminal in the directory of your choice and follow the commands to create a venv with ROCm+Pytorch already installed. + ```bash +sudo apt update sudo apt install -y python3-venv -python3 -m venv ~/rocm-env -source ~/rocm-env/bin/activate +python3 -m venv kernel-env --system-site-packages +source kernel-env/bin/activate ``` - + + + + +On Windows, open a terminal in the directory of your choice and follow the commands to create a venv. + ```bash -python -m venv rocm-env -rocm-env\Scripts\activate +python -m venv kernel-env +kernel-env\Scripts\activate ``` - - + ---- +> **Tip**: Windows users may need to modify their PowerShell Execution Policy (e.g. +> setting it to RemoteSigned or Unrestricted) before running some Powershell commands. + + + -### Installing Dependencies +On Linux, open a terminal in the directory of your choice and follow the commands to create a venv. + ```bash -source ~/rocm-env/bin/activate +sudo apt update +sudo apt install -y python3-venv +python3 -m venv kernel-env +source kernel-env/bin/activate +``` + + + + -pip install --upgrade pip setuptools wheel -pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ "rocm[libraries,devel]" -# sudo reboot -source ~/rocm-env/bin/activate -pip install --pre --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ torch==2.10.0 torchaudio torchvision -``` +### Installing Basic Dependencies + + - -```bash -rocm-env\Scripts\activate + + -pip install --upgrade pip setuptools wheel -pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ "rocm[libraries,devel]" -# Reboot +--- -# Open a Powershell terminal and activate Visual Studio environment +### Installing Additional Dependencies +```bash +pip install --upgrade pip setuptools wheel +``` + +Open a Powershell terminal and activate Visual Studio environment C++ dependencies. +```bash cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } } - -rocm-env\Scripts\activate - -pip install --pre --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ torch==2.10.0 torchaudio torchvision ``` #### Set Environment Variables +#### Linux ```bash rocm-sdk init # Initialize the devel libraries @@ -188,7 +200,18 @@ export PATH = "$ROCM_HOME/bin:$PATH" ``` + +#### Linux & HaloBox +```bash +# Set compiler and build settings +export CC=clang +export CXX=clang +export DISTUTILS_USE_SDK=1 +``` + + +#### Windows ```bash rocm-sdk init # Initialize the devel libraries @@ -206,29 +229,35 @@ $env:DISTUTILS_USE_SDK = "1" --- +## Download Required Files + +Create the following directory structure by making the **2 new folders** and downloading the corresponding files: + +| Directory | Files to Download | Description | +|-----------|-------------------|-------------| +| **Vector_Addition/** | [add_one_kernel.py](assets/Vector_Addition/add_one_kernel.py)
[add_one_kernel.cu](assets/Vector_Addition/add_one_kernel.cu)
[setup.py](assets/Vector_Addition/setup.py)
[run_compiled_addition.py](assets/Vector_Addition/run_compiled_addition.py)| JIT and C++ extension files for vector addition kernel | +| **Matrix_Multiplication/** | [matmul_kernel.py](assets/Matrix_Multiplication/matmul_kernel.py)
[matmul_kernel.cu](assets/Matrix_Multiplication/matmul_kernel.cu)
[setup.py](assets/Matrix_Multiplication/setup.py)
[run_compiled_multiply.py](assets/Matrix_Multiplication/run_compiled_multiply.py) | JIT and C++ extension files for matrix multiplication kernel | + + ## Walkthroughs ### Walkthrough 1: Vector Addition -#### Approach A: JIT Compilation: [`add_one_kernel.py`](assets/Vector_Addition/add_one_kernel.py) - -Kernel is written as a raw C++ string inside Python and compiled at runtime via PyTorch's built-in JIT. +#### Approach A: JIT Compilation -#### Why Block Size = 256? -The kernel uses **256 threads per block**. This value is commonly used because it aligns well with the **wavefront execution model of AMD GPUs**. +JIT (Just-In-Time) compilation means the kernel is written as a raw C++ string inside Python and compiled at runtime, without needing extra build steps. -AMD Radeon GPUs execute threads in groups of 32 threads, called a wavefront. -``` -256 threads per block = 8 wavefronts per block - = 8 × 32 threads +To use [add_one_kernel.py](assets/Vector_Addition/add_one_kernel.py), make sure it's downloaded and run: +```bash +cd Vector_Addition # if not already inside the directory +python add_one_kernel.py ``` -This allows the GPU scheduler to keep multiple wavefronts active within a single block, which improves scheduling efficiency and helps keep compute units busy. -**How it works:** +**Key Code Snippets** ```python import torch -# 1. Kernel source as a string +# Snippet 1: Kernel source as a string KERNEL_SOURCE = """ extern "C" __global__ void add_one(float* data, int n) { @@ -240,7 +269,8 @@ __global__ void add_one(float* data, int n) { } """ -# 2. Compile the kernel string, PyTorch calls hipcc under the hood on ROCm + +# Snippet 2: Compile the kernel string. PyTorch calls hipcc under the hood with ROCm add_one_kernel = torch.cuda._compile_kernel(KERNEL_SOURCE, "add_one") x = torch.ones(100_000_000, dtype=torch.float32, device="cuda") @@ -248,7 +278,8 @@ n = x.numel() block_size = 256 grid_size = (n + block_size - 1) // block_size -# 3. Launch: specify grid/block dimensions and pass tensor args directly + +# Snippet 3: Launch: specify the grid/block dimensions and pass tensor arguments directly for _ in range(200): add_one_kernel( grid=(grid_size, 1, 1), @@ -256,30 +287,36 @@ for _ in range(200): args=[x, n], ) -# 4. Test the output -print("First 5 elements:", x[:5].cpu()) #tensor([200001., 200001., 200001., 200001., 200001.]) + +# Snippet 4: Test the output +print("First 5 elements:", x[:5].cpu()) +#Expected output: tensor([200001., 200001., 200001., 200001., 200001.]) ``` -The script also spawns a background thread that polls `rocm-smi` every 100ms to log peak and average GPU utilization during the kernel run. +> **Tip**: The script also spawns a background thread that polls `amd-smi` every 100ms to log peak and average GPU utilization during the kernel run. -**What the workload actually does:** -``` -100,000,000 elements in the tensor - × 1,000 inner loop iterations per kernel launch → +1,000 per element per launch - × 200 outer loop launches → +200,000 per element total +> **Note**: **Why is Block Size 256?**
+> - The kernel uses **256 threads per block** because it aligns well with the **wavefront execution model of AMD GPUs**. +> - Recall that AMD hardware executes threads in groups of 32 threads, resulting in 8 wavefronts per block. (8 wavefronts x 32 threads = 1 block) -Starting value: 1.0 -Final value: 200,001.0 (per element) -``` -The inner `for (int i = 0; i < 1000; i++)` loop is artificial, its only purpose is to make each kernel launch run long enough for `rocm-smi` to capture meaningful utilization. Without it, 200 launches over 100M elements would complete near-instantly and the sampling thread would likely read very low GPU utilization. +**What the workload does:** -**Run:** -```bash -python "Vector_Addition/add_one_kernel.py" -``` +The kernel artificially adds extra work to demonstrate GPU utilization: + +- **100,000,000 elements** in the tensor +- **Inner loop runs 1,000 times** per element per kernel launch +- **200 kernel launches** total + +**Math:** +- Each element: gets incremented by 1 × 1,000 iterations × 200 launches = +200,000 +- Final result: 1.0 (starting value) + 200,000 (additions) = 200,001.0 +**Why the inner loop?** +- Without the `for (int i = 0; i < 1000; i++)` loop, 200 launches would finish instantly and the monitoring tools wouldn't capture meaningful GPU utilization. The artificial work makes each kernel run long enough for monitoring tools to measure performance. + + **Expected output:**[The performance numbers might vary] ``` First 5 elements: tensor([200001., 200001., 200001., 200001., 200001.]) @@ -287,8 +324,16 @@ Elapsed time: 2.753s Peak GPU Utilization: 93% Average GPU Utilization: 65.94% ``` + + -On Windows, `rocm-smi` is not supported. To track GPU utilization, you can use Task Manager, where you should see a brief spike to 100% utilization when you run the program. +>**Note**: On Windows, `amd-smi` is not supported. To track GPU utilization, you can use Task Manager, where you should see a brief spike to 100% utilization when you run the program. +**Expected output:** +``` +First 5 elements: tensor([200001., 200001., 200001., 200001., 200001.]) +Elapsed time: 2.753s +No GPU Usage captured. +``` **Nice work! You just ran your first GPU kernel.** @@ -296,27 +341,29 @@ On Windows, `rocm-smi` is not supported. To track GPU utilization, you can use T #### Approach B: C++ Extension -The full manual path: write the kernel and Python binding in a single `.cu` file, compile it as a native extension using PyTorch's build system, then import and call it from Python. - -**Files:** +The second approach is more manual: write the kernel and Python binding to a single `.cu` file, compile it natively using PyTorch's build system, and import it into Python. +Download the following files if you haven't already: | File | Role | |---|---| | [add_one_kernel.cu](assets/Vector_Addition/add_one_kernel.cu) | Kernel + launcher + pybind11 binding, everything in one file | | [setup.py](assets/Vector_Addition/setup.py) | Build script, uses `CUDAExtension` to compile the `.cu` into a `.pyd` | +| [run_compiled_addition.py](assets/Vector_Addition/run_compiled_addition.py) | Python script that runs the built artifacts | + | File | Role | |---|---| | [add_one_kernel.cu](assets/Vector_Addition/add_one_kernel.cu) | Kernel + launcher + pybind11 binding, everything in one file | | [setup.py](assets/Vector_Addition/setup.py) | Build script, uses `CUDAExtension` to compile the `.cu` into a `.so` | +| [run_compiled_addition.py](assets/Vector_Addition/run_compiled_addition.py) | Python script that runs the built artifacts | -**How it works:** - -**Step 1: The kernel, launcher, and binding** ([add_one_kernel.cu](assets/Vector_Addition/add_one_kernel.cu)): +#### **Step 1: The kernel, launcher, and binding** ([add_one_kernel.cu](assets/Vector_Addition/add_one_kernel.cu)): ```cpp +#include +#include // GPU kernel, one thread per element __global__ void add_one(float* data, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -339,21 +386,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { } ``` -#### Why `hipDeviceSynchronize()`? - -GPU kernel launches are asynchronous. When the CPU launches a kernel like this: -``` -add_one<<>>(data, n); -``` -the CPU immediately continues executing the next instruction without waiting for the GPU to finish. `hipDeviceSynchronize()` forces the CPU to block until the GPU kernel completes. - -**Step 2: Build** +>**Tip**: Why use `hipDeviceSynchronize()`?
+> - GPU kernel launches are asynchronous. When the CPU runs `add_one<<>>(data, n);` it would immediately execute the next instruction without waiting for the GPU. `hipDeviceSynchronize()` forces the CPU to wait until the GPU kernel completes. +#### **Step 2: Build** ```bash pip install --no-build-isolation -v . ``` +>**Note**: This command looks for `setup.py` in the current directory to build the .cu file we have created. + -`CUDAExtension` is a CUDA build helper from `torch.utils.cpp_extension`. On AMD with ROCm, PyTorch **remaps `CUDAExtension` to use `hipcc`** instead of `nvcc`, so the same `setup.py` that would build a CUDA extension on NVIDIA compiles to AMD GPU code without any changes. This is the key mechanism that makes CUDA extension code portable to AMD: PyTorch's ROCm build intercepts the build path and routes it through the HIP compiler. Produces these in the same directory: +`CUDAExtension` is a CUDA build helper from `torch.utils.cpp_extension`. With ROCm, PyTorch **remaps `CUDAExtension` to use `hipcc`** instead of `nvcc`. ROCm intercepts the build path and routes it through the HIP compiler, porting CUDA code to AMD. + +This produces the following files: - `build/`: directory with the `.pyd` files - `add_one_kernel.hip`: the HIP source generated by hipifying the `.cu` file; this is what `hipcc` actually compiled @@ -363,105 +408,65 @@ pip install --no-build-isolation -v . - `add_one_kernel.hip`: the HIP source generated by hipifying the `.cu` file; this is what `hipcc` actually compiled -**Step 3: Use from Python** -```python -import os, sys -import torch -os.chdir("Vector_Addition") -sys.path.insert(0, os.getcwd()) -import add_one_ext - -x = torch.ones(10, device="cuda") -add_one_ext.add_one(x) -print(x[:5].cpu()) +#### **Step 3: Use from Python** ([run_compiled_addition.py](assets/Vector_Addition/run_compiled_addition.py)): +Execute this script to see the kernel in action: +```bash +cd Vector_Addition # if not already in directory +python run_compiled_addition.py ``` **Expected output:** -```python ->>> x = torch.ones(10, device="cuda") ->>> x -tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0') ->>> add_one_ext.add_one(x) ->>> x -tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.], device='cuda:0') +``` +Before: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0') +After: tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.], device='cuda:0') ``` --- ### Walkthrough 2: Matrix Multiplication -Given matrices **A** (M×N) and **B** (N×K), compute **C = A * B** (M×K). Each element `C[i][j]` is the dot product of row `i` of A with column `j` of B, completely independent of every other output element, making this a natural fit for GPU parallelism. - -#### The Math +Matrix multiplication computes **C = A × B** where: +- **A** is M×N (rows × columns) +- **B** is N×K +- **C** is M×K (the result) Each output element is defined as: - $$C[row, col] = \sum_{n=0}^{N-1} A[row, n] \cdot B[n, col]$$ -Each output element is assigned to exactly one thread, and threads don't depend on each other's results, thread `(0,0)` and thread `(1,5)` run simultaneously with no coordination. However, within a single thread the dot product is **sequential**: the `n` loop iterates N times, accumulating one multiply-add per step. - -#### Row-Major Memory Layout - -GPU memory is **flat (1D)**. A 2D matrix stored in row-major order lays out each row contiguously, one after another. - -For a 2×3 matrix A: - -``` -A = [ a00 a01 a02 - a10 a11 a12 ] - -Stored in memory: - Index: 0 1 2 3 4 5 - Value: a00 a01 a02 a10 a11 a12 -``` - -To reach `A[row][col]`, skip `row` full rows (each `N` elements wide), then advance `col` steps: +Each element of C is calculated independently, making this perfect for GPU parallelism. -$$A[row, col] = A[row \times N + col]$$ +#### How It Maps to GPU Threads -The same principle applies to B (column width K): +Unlike vector addition (1D), matrix multiplication produces a **2D output**, so we use a **2D grid of threads**: -$$B[n, col] = B[n \times K + col]$$ - -Substituting into the matmul formula gives the exact inner loop in the kernel: +| | Vector Addition | Matrix Multiplication | +|---|---|---| +| **Output shape** | 1D array | 2D matrix (M×K) | +| **Thread mapping** | 1 thread → 1 element | 1 thread → 1 output element | +| **Launch pattern** | 1D grid: `(grid_x, 1, 1)` | 2D grid: `(grid_x, grid_y, 1)` | +| **Block size** | `(256, 1, 1)` | `(16, 16, 1)` = 256 threads | -$$C[row, col] = \sum_{n=0}^{N-1} A[row \times N + n] \cdot B[n \times K + col]$$ +Each thread computes one element of the output matrix C. Thread at position `(row, col)` computes `C[row][col]` by multiplying the corresponding row of A with the corresponding column of B. -#### 2D thread indexing +**Memory Layout**: GPU memory is flat (1D), but matrices are stored row-by-row. To access `A[row][col]`, the kernel uses `A[row * N + col]`. -Vector addition maps one thread to one element of a 1D array. Matrix multiplication maps one thread to one element of a 2D output matrix, so the natural launch shape is a **2D grid of 2D blocks**. -| | Vector Addition | Matrix Multiplication | -|---|---|---| -| Output shape | 1D vector, length N | 2D matrix, M×K | -| Thread grid | 1D: `(grid_x, 1, 1)` | 2D: `(grid_x, grid_y, 1)` | -| Thread block | 1D: `(256, 1, 1)` = 256 threads | 2D: `(16, 16, 1)` = 256 threads | -| Thread index | `idx = blockIdx.x * blockDim.x + threadIdx.x` | `row = blockIdx.y * blockDim.y + threadIdx.y` & `col = blockIdx.x * blockDim.x + threadIdx.x` | -| Work per thread | `data[idx] += 1` | `C[row][col] = Σ A[row][k] * B[k][col]` | +#### Approach A: JIT Compilation: -The block is still 256 threads total (16×16), matching the convention from Walkthrough 1, but arranged in a square to align naturally with the 2D output. +Like Walkthrough 1, the kernel is written as a raw C++ string inside Python and compiled at runtime via PyTorch's built-in JIT. -``` -Grid (2D) -└── Block [bx, by] ... - └── Thread [tx, ty] → computes C[by*16+ty][bx*16+tx] -``` -The grid covers the full output: -``` -grid_x = ceil(K / 16) # enough blocks to span all K columns -grid_y = ceil(M / 16) # enough blocks to span all M rows +To use [matmul_kernel.py](assets/Matrix_Multiplication/matmul_kernel.py), make sure it's downloaded and run: +```bash +cd Matrix_Multiplication # if not already inside the directory +python matmul_kernel.py ``` -#### Approach A: JIT Compilation: [`matmul_kernel.py`](assets/Matrix_Multiplication/matmul_kernel.py) - -Kernel is written as a raw C++ string inside Python and compiled at runtime via PyTorch's built-in JIT. Identical workflow to Walkthrough 1, only the kernel body and launch dimensions change. - -**How it works:** +**Key Code Snippets** ```python import torch -# 1. Kernel source, 2D indexing to map threads onto the M×K output matrix +# Snippet 1: Kernel source as a string KERNEL_SOURCE = """ extern "C" __global__ void matmul(float* A, float* B, float* C, int M, int N, int K) { @@ -477,6 +482,8 @@ __global__ void matmul(float* A, float* B, float* C, int M, int N, int K) { } } """ + +# Snippet 2: Creating the Matrix - 2D indexing to map threads onto the M×K output matrix # Inputs: A is M x N, B is N x K, C is M x K M, N, K = 1024, 512, 768 @@ -488,10 +495,12 @@ BLOCK = 16 grid_x = (K + BLOCK - 1) // BLOCK grid_y = (M + BLOCK - 1) // BLOCK -# 2. Compile the kernel string + +# Snippet 3: Compile the kernel string matmul_kernel = torch.cuda._compile_kernel(KERNEL_SOURCE, "matmul") -# 3. Launch with a 2D grid, grid_x covers columns (K), grid_y covers rows (M) + +# Snippet 4:. Launch with a 2D grid, grid_x covers columns (K), grid_y covers rows (M) BLOCK = 16 matmul_kernel( grid=(grid_x, grid_y, 1), @@ -504,49 +513,56 @@ max_err = (C - C_ref).abs().max().item() print(f"Max error vs torch.mm: {max_err:.6f}") ``` -The row-major memory layout of the tensors maps directly to how the kernel indexes the flat pointers: -- `A[row * N + n]`: row `row`, column `n` -- `B[n * K + col]`: row `n`, column `col` - -The script spawns the same background monitoring thread from Walkthrough 1 (`rocm-smi` polled every 100ms) and verifies the result against `torch.mm`. Floating-point arithmetic on GPUs may produce small numerical differences compared to CPU implementations due to parallel reduction order. This is why we verify the result using a tolerance (`max error`) instead of exact equality. +The script verifies the result against `torch.mm` with a small tolerance. Floating-point arithmetic on GPUs may produce small numerical differences compared to CPU implementations due to parallel reduction order. -**Run:** -```bash -python "Matrix_Multiplication/matmul_kernel.py" + +**Expected output:**[The performance numbers might vary] +``` +Elapsed time: 2.753s +Max error vs torch.mm: 0.000160 +Peak GPU Utilization: 93% +Average GPU Utilization: 65.94% ``` + -**Expected output:**[The performance numbers might vary] + +>**Note**: On Windows, `amd-smi` is not supported. To track GPU utilization, you can use Task Manager, where you should see a brief spike to 100% utilization when you run the program. +**Expected output:** ``` -Elapsed time: 0.255s +Elapsed time: 2.753s Max error vs torch.mm: 0.000160 -Peak GPU Utilization: 100% -Average GPU Utilization: 55.00% +No GPU Usage captured. ``` + --- #### Approach B: C++ Extension -The full manual path: write the kernel and Python binding in a `.cu` file, compile it as a native extension, then import and call it from Python. Mirrors the structure of `add_one_kernel.cu` exactly, only the kernel signature and launcher logic differ. +The second approach is more manual: write the kernel and Python binding to a single `.cu` file, compile it natively using PyTorch's build system, and import it into Python. -**Files:** +Download the following files if you haven't already: +#### Windows | File | Role | |---|---| | [matmul_kernel.cu](assets/Matrix_Multiplication/matmul_kernel.cu) | Kernel + launcher + pybind11 binding | | [setup.py](assets/Matrix_Multiplication/setup.py) | Build script, uses `CUDAExtension` to compile the `.cu` into a `.pyd` | +| [run_compiled_multiply.py](assets/Matrix_Multiplication/run_compiled_multiply.py) | Python script that runs the built artifacts | +#### Linux | File | Role | |---|---| | [matmul_kernel.cu](assets/Matrix_Multiplication/matmul_kernel.cu) | Kernel + launcher + pybind11 binding | | [setup.py](assets/Matrix_Multiplication/setup.py) | Build script, uses `CUDAExtension` to compile the `.cu` into a `.so` | +| [run_compiled_multiply.py](assets/Matrix_Multiplication/run_compiled_multiply.py) | Python script that runs the built artifacts | -**How it works:** - -**Step 1: The kernel, launcher, and binding** ([matmul_kernel.cu](assets/Matrix_Multiplication/matmul_kernel.cu)): +#### **Step 1: The kernel, launcher, and binding** ([matmul_kernel.cu](assets/Matrix_Multiplication/matmul_kernel.cu)): ```cpp +#include +#include #define BLOCK 16 // GPU kernel, one thread per output element of C @@ -589,49 +605,38 @@ Compared to `add_one_launcher` in Walkthrough 1, the launcher here: - Allocates and returns the output tensor C, rather than mutating in-place - Uses `dim3` for both grid and block to express the 2D launch shape -**Step 2: Build** - +#### **Step 2: Build** ```bash pip install --no-build-isolation -v . ``` +>**Note**: This command looks for `setup.py` in the current directory to build the .cu file we have created. + -Produces these in the same directory: +This produces the following files: +#### Windows - `build/`: directory with the `.pyd` files - `matmul_kernel.hip`: the HIP source generated by hipifying the `.cu` file; this is what `hipcc` actually compiled +#### Linux - `build/`: directory with the `.so` files - `matmul_kernel.hip`: the HIP source generated by hipifying the `.cu` file; this is what `hipcc` actually compiled -The same `CUDAExtension` → `hipcc` remapping as walkthrough 1 applies here unchanged. - -**Step 3: Use from Python** -```python -import os, sys -import torch -os.chdir("Matrix_Multiplication") -sys.path.insert(0, os.getcwd()) -import matmul_ext - -A = torch.tensor([[1., 2.], - [3., 4.]], device="cuda") - -B = torch.tensor([[5., 6.], - [7., 8.]], device="cuda") - -C = matmul_ext.matmul(A, B) +#### **Step 3: Use from Python** ([run_compiled_multiply.py](assets/Matrix_Multiplication/run_compiled_multiply.py)): +Execute this script to see the kernel in action: +```bash +cd Matrix_Multiplication # if not already in directory +python run_compiled_multiply.py ``` **Expected output:** -```python ->>> C -tensor([[19., 22.], +``` +Result: tensor([[19., 22.], [43., 50.]], device='cuda:0') ->>> (C - torch.mm(A, B)).abs().max() -tensor(0., device='cuda:0') ``` + **Awesome! You just implemented matrix multiplication on the GPU.** This is a major milestone because matrix multiplication is the backbone of modern machine learning operations like: - Neural network layers - Attention mechanisms @@ -642,14 +647,16 @@ tensor(0., device='cuda:0') ## Next Steps -Once you understand the naive matmul kernel, you can explore more advanced GPU strategies. You may take one step further and implement these improvements: -- Instead of reading every value of A and B from global memory repeatedly, blocks of threads load tiles into shared memory and reuse them across many multiply-add operations. - -- Each product `A[row][k] * B[k][col]` is independent. Instead of one thread computing the full dot product, multiple threads can compute partial sums and then reduce them together. +You've learned to write, compile, and launch GPU kernels using both JIT compilation and C++ extensions for basic parallel operations. -- Rather than computing a single element per thread, a thread can compute a small tile (e.g., 4×4) of the output. This increases data reuse from shared memory and improves arithmetic intensity. +**Performance optimizations:** +- **Shared memory tiling** - Cache data blocks to reduce global memory access +- **Memory coalescing** - Optimize memory access patterns for bandwidth -You may also implement real-world GPU workloads: -- **2D Convolution (Image Filtering)**: A small filter (kernel) slides across an image, computing each output pixel from a weighted sum of neighboring pixels. This introduces stencil computations and shared memory tiling, where threads reuse overlapping image regions to reduce global memory access. +**Real-world algorithms:** +- **2D Convolution** - Image filtering and neural network layers +- **Reductions** - Parallel sums, max finding, softmax operations -- **Softmax Function**: Softmax converts a vector of numbers into probabilities that sum to 1, commonly used in neural network outputs. Implementing it efficiently on GPU introduces parallel reductions and numerical stability techniques while processing large vectors. +**Production considerations:** +- **Error handling** - Bounds checking and device management +- **PyTorch integration** - Custom operators with autograd support \ No newline at end of file diff --git a/playbooks/supplemental/pytorch-kernels/assets/Matrix_Multiplication/run_compiled_multiply.py b/playbooks/supplemental/pytorch-kernels/assets/Matrix_Multiplication/run_compiled_multiply.py new file mode 100644 index 00000000..38b6e192 --- /dev/null +++ b/playbooks/supplemental/pytorch-kernels/assets/Matrix_Multiplication/run_compiled_multiply.py @@ -0,0 +1,14 @@ +import os, sys +import torch +os.chdir("Matrix_Multiplication") +sys.path.insert(0, os.getcwd()) +import matmul_ext + +A = torch.tensor([[1., 2.], + [3., 4.]], device="cuda") + +B = torch.tensor([[5., 6.], + [7., 8.]], device="cuda") + +C = matmul_ext.matmul(A, B) +print("Result:", C.cpu()) \ No newline at end of file diff --git a/playbooks/supplemental/pytorch-kernels/assets/Vector_Addition/run_compiled_addition.py b/playbooks/supplemental/pytorch-kernels/assets/Vector_Addition/run_compiled_addition.py new file mode 100644 index 00000000..06f97997 --- /dev/null +++ b/playbooks/supplemental/pytorch-kernels/assets/Vector_Addition/run_compiled_addition.py @@ -0,0 +1,11 @@ +import os, sys +import torch +os.chdir("Vector_Addition") +sys.path.insert(0, os.getcwd()) +import add_one_ext + +x = torch.ones(10, device="cuda") +print("Before:", x.cpu()) + +add_one_ext.add_one(x) +print("After:", x[:5].cpu()) \ No newline at end of file