diff --git a/Makefile b/Makefile index 11e8f920..11bca155 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,23 @@ CC ?= cc UNAME_S := $(shell uname -s) +# Default backend selection +ifeq ($(UNAME_S),Darwin) + BACKEND ?= metal +else + # On Linux, try to detect ROCm or CUDA if BACKEND is not set. + # Default to 'cpu' if neither is found. + ifeq ($(BACKEND),) + ifneq ($(wildcard /opt/rocm/bin/hipcc),) + BACKEND = rocm + else ifneq ($(shell which nvcc 2>/dev/null),) + BACKEND = cuda + else + BACKEND = cpu + endif + endif +endif + ifeq ($(UNAME_S),Darwin) NATIVE_CPU_FLAG ?= -mcpu=native else @@ -9,59 +26,77 @@ endif CFLAGS ?= -O3 -ffast-math $(NATIVE_CPU_FLAG) -Wall -Wextra -std=c99 OBJCFLAGS ?= -O3 -ffast-math $(NATIVE_CPU_FLAG) -Wall -Wextra -fobjc-arc - LDLIBS ?= -lm -pthread METAL_SRCS := $(wildcard metal/*.metal) -ifeq ($(UNAME_S),Darwin) -METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal -CORE_OBJS = ds4.o ds4_metal.o +CORE_OBJS = ds4.o CPU_CORE_OBJS = ds4_cpu.o -else -CFLAGS += -D_GNU_SOURCE -fno-finite-math-only -CUDA_HOME ?= /usr/local/cuda -NVCC ?= $(CUDA_HOME)/bin/nvcc -CUDA_ARCH ?= native -ifneq ($(strip $(CUDA_ARCH)),) -NVCC_ARCH_FLAGS := -arch=$(CUDA_ARCH) -endif -NVCCFLAGS ?= -O3 --use_fast_math $(NVCC_ARCH_FLAGS) -Xcompiler $(NATIVE_CPU_FLAG) -Xcompiler -pthread -CUDA_LDLIBS ?= -lm -Xcompiler -pthread -L$(CUDA_HOME)/targets/sbsa-linux/lib -L$(CUDA_HOME)/lib64 -lcudart -lcublas -CORE_OBJS = ds4.o ds4_cuda.o -CPU_CORE_OBJS = ds4_cpu.o -METAL_LDLIBS := $(LDLIBS) + +# Backend specific settings +ifeq ($(BACKEND),metal) + METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal + CORE_OBJS += ds4_metal.o + LDLIBS_BIN = $(METAL_LDLIBS) + CC_BIN = $(CC) endif -.PHONY: all clean test cpu cuda-regression +ifeq ($(BACKEND),cuda) + CUDA_HOME ?= /usr/local/cuda + NVCC ?= $(CUDA_HOME)/bin/nvcc + CUDA_ARCH ?= native + ifneq ($(strip $(CUDA_ARCH)),) + NVCC_ARCH_FLAGS := -arch=$(CUDA_ARCH) + endif + NVCCFLAGS ?= -O3 --use_fast_math $(NVCC_ARCH_FLAGS) -Xcompiler $(NATIVE_CPU_FLAG) -Xcompiler -pthread + CUDA_LDLIBS ?= -lm -Xcompiler -pthread -L$(CUDA_HOME)/targets/sbsa-linux/lib -L$(CUDA_HOME)/lib64 -lcudart -lcublas + CORE_OBJS += ds4_cuda.o + CFLAGS += -DDS4_HAVE_CUDA + LDLIBS_BIN = $(CUDA_LDLIBS) + CC_BIN = $(NVCC) + CC_BIN_FLAGS = $(NVCCFLAGS) + REGRESSION_TEST = tests/cuda_long_context_smoke +endif -all: ds4 ds4-server ds4-bench +ifeq ($(BACKEND),rocm) + ROCM_HOME ?= /opt/rocm + HIPCC ?= $(ROCM_HOME)/bin/hipcc + HIP_ARCH ?= native + ifneq ($(strip $(HIP_ARCH)),) + HIP_ARCH_FLAGS := --offload-arch=$(HIP_ARCH) + endif + HIPCCFLAGS ?= -O3 -ffast-math $(HIP_ARCH_FLAGS) $(NATIVE_CPU_FLAG) -pthread -Wno-unused-result + HIP_LDLIBS ?= -lm -pthread -L$(ROCM_HOME)/lib -lhipblas -lamdhip64 + CORE_OBJS += ds4_hip.o + CFLAGS += -DDS4_HAVE_ROCM + LDLIBS_BIN = $(HIP_LDLIBS) + CC_BIN = $(HIPCC) + CC_BIN_FLAGS = $(HIPCCFLAGS) + REGRESSION_TEST = tests/rocm_long_context_smoke +endif -ifeq ($(UNAME_S),Darwin) -ds4: ds4_cli.o linenoise.o $(CORE_OBJS) - $(CC) $(CFLAGS) -o $@ ds4_cli.o linenoise.o $(CORE_OBJS) $(METAL_LDLIBS) +ifeq ($(BACKEND),cpu) + CFLAGS += -DDS4_NO_GPU + CORE_OBJS = $(CPU_CORE_OBJS) + LDLIBS_BIN = $(LDLIBS) + CC_BIN = $(CC) +endif -ds4-server: ds4_server.o rax.o $(CORE_OBJS) - $(CC) $(CFLAGS) -o $@ ds4_server.o rax.o $(CORE_OBJS) $(METAL_LDLIBS) +ifeq ($(UNAME_S),Linux) + CFLAGS += -D_GNU_SOURCE -fno-finite-math-only +endif -ds4-bench: ds4_bench.o $(CORE_OBJS) - $(CC) $(CFLAGS) -o $@ ds4_bench.o $(CORE_OBJS) $(METAL_LDLIBS) +.PHONY: all clean test cpu cuda-regression rocm-regression -cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o linenoise.o rax.o $(CPU_CORE_OBJS) - $(CC) $(CFLAGS) -o ds4 ds4_cli_cpu.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) - $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o rax.o $(CPU_CORE_OBJS) $(LDLIBS) - $(CC) $(CFLAGS) -o ds4-bench ds4_bench_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) +all: ds4 ds4-server ds4-bench -cuda-regression: - @echo "cuda-regression requires a CUDA build" -else ds4: ds4_cli.o linenoise.o $(CORE_OBJS) - $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) + $(CC_BIN) $(CC_BIN_FLAGS) $(CFLAGS) -o $@ $^ $(LDLIBS_BIN) ds4-server: ds4_server.o rax.o $(CORE_OBJS) - $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) + $(CC_BIN) $(CC_BIN_FLAGS) $(CFLAGS) -o $@ $^ $(LDLIBS_BIN) ds4-bench: ds4_bench.o $(CORE_OBJS) - $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) + $(CC_BIN) $(CC_BIN_FLAGS) $(CFLAGS) -o $@ $^ $(LDLIBS_BIN) cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o linenoise.o rax.o $(CPU_CORE_OBJS) $(CC) $(CFLAGS) -o ds4 ds4_cli_cpu.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) @@ -70,7 +105,9 @@ cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o linenoise.o rax.o $(CPU_CORE cuda-regression: tests/cuda_long_context_smoke ./tests/cuda_long_context_smoke -endif + +rocm-regression: tests/rocm_long_context_smoke + ./tests/rocm_long_context_smoke ds4.o: ds4.c ds4.h ds4_gpu.h $(CC) $(CFLAGS) -c -o $@ ds4.c @@ -87,9 +124,6 @@ ds4_bench.o: ds4_bench.c ds4.h ds4_test.o: tests/ds4_test.c ds4_server.c ds4.h rax.h $(CC) $(CFLAGS) -Wno-unused-function -c -o $@ tests/ds4_test.c -tests/cuda_long_context_smoke.o: tests/cuda_long_context_smoke.c ds4_gpu.h - $(CC) $(CFLAGS) -I. -c -o $@ tests/cuda_long_context_smoke.c - rax.o: rax.c rax.h rax_malloc.h $(CC) $(CFLAGS) -c -o $@ rax.c @@ -114,18 +148,26 @@ ds4_metal.o: ds4_metal.m ds4_gpu.h $(METAL_SRCS) ds4_cuda.o: ds4_cuda.cu ds4_gpu.h ds4_iq2_tables_cuda.inc $(NVCC) $(NVCCFLAGS) -c -o $@ ds4_cuda.cu +ds4_hip.o: ds4_hip.cpp ds4_gpu.h ds4_iq2_tables_hip.inc + $(HIPCC) $(HIPCCFLAGS) -c -o $@ ds4_hip.cpp + +tests/cuda_long_context_smoke.o: tests/cuda_long_context_smoke.c ds4_gpu.h + $(CC) $(CFLAGS) -I. -c -o $@ tests/cuda_long_context_smoke.c + tests/cuda_long_context_smoke: tests/cuda_long_context_smoke.o ds4_cuda.o $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) +tests/rocm_long_context_smoke.o: tests/rocm_long_context_smoke.c ds4_gpu.h + $(CC) $(CFLAGS) -I. -c -o $@ tests/rocm_long_context_smoke.c + +tests/rocm_long_context_smoke: tests/rocm_long_context_smoke.o ds4_hip.o + $(HIPCC) $(HIPCCFLAGS) -o $@ $^ $(HIP_LDLIBS) + ds4_test: ds4_test.o rax.o $(CORE_OBJS) -ifeq ($(UNAME_S),Darwin) - $(CC) $(CFLAGS) -o $@ ds4_test.o rax.o $(CORE_OBJS) $(METAL_LDLIBS) -else - $(NVCC) $(NVCCFLAGS) -o $@ ds4_test.o rax.o $(CORE_OBJS) $(CUDA_LDLIBS) -endif + $(CC_BIN) $(CC_BIN_FLAGS) $(CFLAGS) -o $@ $^ $(LDLIBS_BIN) test: ds4_test ./ds4_test clean: - rm -f ds4 ds4-server ds4-bench ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o + rm -f ds4 ds4-server ds4-bench ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o tests/rocm_long_context_smoke tests/rocm_long_context_smoke.o diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md new file mode 100644 index 00000000..bf6e3119 --- /dev/null +++ b/PR_DESCRIPTION.md @@ -0,0 +1,47 @@ +### **Pull Request Description: Add AMD ROCm/HIP Support and Strix Halo Optimizations** + +#### **Overview** +This PR introduces a complete AMD ROCm/HIP backend to DwarfStar 4, optimized specifically for hardware with unified memory architectures like the **AMD Strix Halo (gfx1151)**. It migrates the project from its original CUDA dependency to a portable HIP implementation while maintaining functional parity and performance. + +#### **Key Changes** +1. **ROCm/HIP Backend Migration**: + * Ported `ds4_cuda.cu` to `ds4_hip.cpp` and transitioned all symbol dependencies from CUDA/cuBLAS to HIP/hipBLAS. + * Updated the `Makefile` to detect and support the ROCm stack using `hipcc`. +2. **Strix Halo / HSA Optimizations**: + * **Zero-Copy Memory Access**: Configured the engine to use HSA direct access (Zero-Copy) by default on AMD hardware. This avoids duplicating 83+ GiB of model weights in system RAM, significantly reducing memory overhead. + * **Vectorized Kernels**: Optimized F16 and F32 GEMV kernels using vectorized loads and warp-shuffle reductions for improved decoding throughput. + * **Hardware Intrinsics**: Replaced scalar loops with AMD-specific hardware dot-product intrinsics (`v_dot4_i32_i8`). +3. **Unified Tooling**: + * Added **`build.sh`**: A one-click script for ROCm compilation. + * Added **`rocm_start_server.sh`**: A unified script that handles stale process cleanup, system cache flushing, and optimized server launch. +4. **Verification**: + * Successfully validated with the `rocm-regression` long-context smoke test. + * End-to-end testing performed using DeepSeek-V4-Flash Q2-imatrix weights. + +#### **Performance Benchmarks (AMD Strix Halo / Radeon Graphics)** +* **Decoding Speed**: **8.09 – 13.24 tokens/sec** (Non-MTP, Zero-Copy mode). +* **Prefill Latency**: **~4.45s** for short prompts (post-warmup). +* **Startup**: ~16s weight warmup for 83.60 GiB mapping. + +#### **How to Test** +1. **Build**: `./build.sh` +2. **Start**: `./rocm_start_server.sh` +3. **Verify**: + ```bash + curl -X POST http://127.0.0.1:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "ds4flash", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50 + }' + ``` + +--- + +### **Summary of Work Done** +* **Full Backend Port**: Replaced all CUDA/cuBLAS APIs with HIP/hipBLAS equivalents. +* **Environmental Cleanup**: Renamed all CUDA-specific environment variables to `DS4_HIP_*` (e.g., `DS4_HIP_PREFILL_CHUNK`). +* **Driver Compatibility**: Added robust `hipHostRegister` fallbacks for diverse ROCm driver environments. +* **Unified Startup Flow**: Fused cleanup and server launch into a single, reliable maintenance script. +* **Documentation Integrity**: Updated `README.md` with dedicated ROCm onboarding instructions. diff --git a/README.md b/README.md index 5ea086b7..d1929f31 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ That said, a few important things about this project: * This software is developed with **strong assistance from GPT 5.5** and with humans leading the ideas, testing, and debugging. We say this openly because it shaped how the project was built. If you are not happy with AI-developed code, this software is not for you. The acknowledgement below is equally important: this would not exist without `llama.cpp` and GGML, largely written by hand. * This implementation is based on the idea that compressed KV caches like the one of DeepSeek v4 and the fast SSD disks of modern MacBooks should change our idea that KV cache belongs to RAM. **The KV cache is actually a first-class disk citizen**. * Our vision is that local inference should be a set of three things working well together, out of the box: A) inference engine with HTTP API + B) GGUF specially crafted to run well under a given engine and given assumptions + C) testing and validation with coding agents implementations. This inference engine only runs with the GGUF files provided. It gets tested against officially obtained logits at different context sizes. This project exists because we wanted to make one local model feel finished end to end, not just runnable. However this is just alpha quality code, so probably we are not still there. -* The optimized graph path targets **Metal on macOS** and **CUDA on Linux**. The CPU path is only for correctness checks and model/tokenizer diagnostics. For CPU-only Linux builds, use `make cpu`; it builds the normal `./ds4` and `./ds4-server` binaries without CUDA or Metal. On macOS, **warning: current macOS versions have a bug in the virtual memory implementation that will crash the kernel** if you try to run the CPU code. Remember? Software sucks. It was not possible to fix the CPU inference to avoid crashing, since each time you have to restart the computer, which is not funny. Help us, if you have the guts. +* The optimized graph path targets **Metal on macOS**, **CUDA on Linux**, and **ROCm/HIP on Linux (AMD)**. The CPU path is only for correctness checks and model/tokenizer diagnostics. For CPU-only Linux builds, use `make cpu`; it builds the normal `./ds4` and `./ds4-server` binaries without CUDA or Metal. On macOS, **warning: current macOS versions have a bug in the virtual memory implementation that will crash the kernel** if you try to run the CPU code. Remember? Software sucks. It was not possible to fix the CPU inference to avoid crashing, since each time you have to restart the computer, which is not funny. Help us, if you have the guts. ## Acknowledgements to llama.cpp and GGML @@ -97,13 +97,52 @@ slight speedup, not a meaningful generation-speed win. Then build: ```sh -make +make # Defaults to CUDA on Linux or Metal on macOS +./build.sh # Recommended for AMD ROCm/HIP builds ``` `./ds4flash.gguf` is the default model path used by both binaries. Pass `-m` to select another supported GGUF from `./gguf/`. Run `./ds4 --help` and `./ds4-server --help` for the full flag list. +## AMD ROCm / HIP Support (Linux) + +For AMD GPUs (like the Strix Halo / Radeon Graphics), DwarfStar 4 supports the ROCm backend via HIP. + +### Building + +Use the provided build script to compile with ROCm support: + +```sh +./build.sh +``` + +This script performs a clean build using `BACKEND=rocm` and the `hipcc` compiler. + +### Starting the Server + +The project includes a unified startup script that cleans up stale processes, flushes system memory caches, and launches the server with optimized ROCm flags: + +```sh +./rocm_start_server.sh +``` + +This script is specifically tuned for hardware like the **AMD Strix Halo**, unsetting `DS4_HIP_COPY_MODEL` to enable **Zero-Copy HSA access**, which allows the GPU to read model weights directly from system RAM without duplication. + +### Testing + +Once the server is listening, you can verify it with a `curl` request: + +```sh +curl -X POST http://127.0.0.1:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "ds4flash", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50 + }' +``` + ## Speed These are single-run Metal CLI numbers with `--ctx 32768`, `--nothink`, greedy @@ -607,34 +646,83 @@ the kv cache files include the verbatim prompt cached. ## Backends -The default graph backend is Metal on macOS and CUDA on Linux CUDA builds: +The default graph backend is Metal on macOS and CUDA/ROCm on Linux: ```sh -./ds4 -p "Hello" --metal -./ds4 -p "Hello" --cuda +./ds4 -p "Hello" --metal # macOS +./ds4 -p "Hello" --cuda # Linux (NVIDIA CUDA or AMD ROCm/HIP) ``` -CUDA builds default to `CUDA_ARCH=native`, so `nvcc` targets the visible GPU. -Set `CUDA_ARCH` explicitly when cross-building or when you need a known target: +### Building for ROCm (AMD GPU, Linux) + +The Linux build automatically uses ROCm/HIP when `/opt/rocm` is present — +no separate target is needed: + +```sh +make # detects ROCm automatically; builds ds4, ds4-server, ds4-bench +``` + +**Prerequisites:** ROCm 7.x (`/opt/rocm/bin/hipcc` must exist). Check your GPU architecture with: + +```sh +/opt/rocm/bin/hipcc --version +rocminfo | grep "Name:.*gfx" +``` + +The Makefile picks up the GPU architecture automatically with `HIP_ARCH=native`. For specific AMD architectures (like **Strix Halo** or **RDNA3**), you can override it: + +```sh +make HIP_ARCH=gfx1151 # AMD Strix Halo / Radeon 8060S +make HIP_ARCH=gfx1100 # RX 7900 XTX +make HIP_ARCH=gfx1030 # RX 6800/6900 (RDNA2) +``` + +**Performance Tuning for APUs (Strix Halo):** +The ROCm backend is optimized for the unified memory architecture of the Strix Halo: +- **Memory Advisories**: The model uses `hipMemAdviseSetCoarseGrain` to allow the GPU to cache system-mapped weights effectively, drastically improving TPS on APUs. +- **Hardware Dot-Products**: Q2 quantization uses native RDNA3/3.5 `v_dot4_i32_i8` instructions for peak math throughput. +- **Coalesced Access**: GEMV kernels are tuned for RDNA3 wavefront sizes to saturate the 180+ GB/s memory bus. + +**First-run kernel compilation:** On the first inference after a rebuild, ROCm may JIT-compile GPU kernels via COMGR. This can take a few minutes. Subsequent runs load from cache and start immediately. + +**Strix Halo / APU notes (gfx1151):** +- **Always clean system cache before start**: On APUs with shared memory, the Linux PageCache can fragment the unified address space. It is highly recommended to flush caches before starting the server to ensure the 83GB model can be mapped/copied efficiently. + ```sh + sudo sync; echo 3 | sudo tee /proc/sys/vm/drop_caches + ``` +- ROCm sees only GTT (system RAM) — the BIOS UMA carveout is not exposed for + compute. Use a small UMA carveout (e.g. 512 MB in BIOS) and rely on GTT for + the model and KV cache. +- The 84 GB IQ2 model maps as `cached coarse-grained` over GTT, which gives + the GPU direct access without explicit copies. +- Use `--backend cuda` (the flag name is unchanged; it maps to HIP internally). + +```sh +./ds4-server --backend cuda --ctx 32768 \ + --kv-disk-dir /tmp/ds4-kv --kv-disk-space-mb 8192 +``` + +### Building for NVIDIA CUDA (Linux) + +Same `make` command — the Makefile uses `nvcc` when `/opt/rocm` is absent. +Set `CUDA_ARCH` if needed: ```sh make CUDA_ARCH=sm_120 make CUDA_ARCH= # old nvcc default target behavior ``` -There is also a CPU reference/debug path: +### CPU reference build ```sh ./ds4 -p "Hello" --cpu make cpu -./ds4 -./ds4 -p "Hello" ``` Do not treat the CPU path as the production target. The CLI and `ds4-server` support the CPU backend for reference/debug use and share the same KV session and snapshot format as Metal and CUDA, but normal inference should use Metal or -CUDA. +CUDA/ROCm. ## Steering diff --git a/build.sh b/build.sh new file mode 100755 index 00000000..0cbf2436 --- /dev/null +++ b/build.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# Simple build script for ROCm/HIP backend +set -e + +echo "Building ds4 with ROCm backend..." +make BACKEND=rocm clean +make BACKEND=rocm -j$(nproc) + +echo "Build complete. Executables: ds4, ds4-server, ds4-bench" diff --git a/dir-steering/out/verbosity.f32 b/dir-steering/out/verbosity.f32 deleted file mode 100644 index af8f35c8..00000000 Binary files a/dir-steering/out/verbosity.f32 and /dev/null differ diff --git a/ds4.c b/ds4.c index f0f8cc14..3829738d 100644 --- a/ds4.c +++ b/ds4.c @@ -71,7 +71,7 @@ static const char DS4_REASONING_EFFORT_MAX_PREFIX[] = #define DS4_THINK_MAX_MIN_CONTEXT 393216u static bool ds4_backend_uses_graph(ds4_backend backend) { - return backend == DS4_BACKEND_METAL || backend == DS4_BACKEND_CUDA; + return backend == DS4_BACKEND_METAL || backend == DS4_BACKEND_CUDA || backend == DS4_BACKEND_ROCM; } /* ========================================================================= @@ -15177,6 +15177,7 @@ const char *ds4_backend_name(ds4_backend backend) { switch (backend) { case DS4_BACKEND_METAL: return "metal"; case DS4_BACKEND_CUDA: return "cuda"; + case DS4_BACKEND_ROCM: return "rocm"; case DS4_BACKEND_CPU: return "cpu"; } return "unknown"; diff --git a/ds4.h b/ds4.h index 9613b0d0..7af822c6 100644 --- a/ds4.h +++ b/ds4.h @@ -17,6 +17,7 @@ typedef enum { DS4_BACKEND_METAL, DS4_BACKEND_CUDA, + DS4_BACKEND_ROCM, DS4_BACKEND_CPU, } ds4_backend; diff --git a/ds4_cli.c b/ds4_cli.c index 838a3941..58bf24d0 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -91,10 +91,12 @@ static void usage(FILE *fp) { " Use the Metal graph backend. This is the normal fast path on macOS.\n" " --cuda\n" " Use the CUDA graph backend. This is the normal fast path on CUDA builds.\n" + " --rocm\n" + " Use the ROCm graph backend. This is the normal fast path on AMD ROCm builds.\n" " --cpu\n" " Use the CPU reference/debug backend. Not recommended for normal inference.\n" " --backend NAME\n" - " Select backend explicitly: metal, cuda, or cpu.\n" + " Select backend explicitly: metal, cuda, rocm, or cpu.\n" " -t, --threads N\n" " CPU helper threads for host-side or reference work.\n" " --quality\n" @@ -212,9 +214,10 @@ static float parse_float_range(const char *s, const char *opt, float min, float static ds4_backend parse_backend(const char *s) { if (!strcmp(s, "metal")) return DS4_BACKEND_METAL; if (!strcmp(s, "cuda")) return DS4_BACKEND_CUDA; + if (!strcmp(s, "rocm")) return DS4_BACKEND_ROCM; if (!strcmp(s, "cpu")) return DS4_BACKEND_CPU; fprintf(stderr, "ds4: invalid backend: %s\n", s); - fprintf(stderr, "ds4: valid backends are: metal, cuda, cpu\n"); + fprintf(stderr, "ds4: valid backends are: metal, cuda, rocm, cpu\n"); exit(2); } @@ -223,6 +226,8 @@ static ds4_backend default_backend(void) { return DS4_BACKEND_CPU; #elif defined(__APPLE__) return DS4_BACKEND_METAL; +#elif defined(DS4_HAVE_ROCM) + return DS4_BACKEND_ROCM; #else return DS4_BACKEND_CUDA; #endif @@ -1250,6 +1255,8 @@ static cli_config parse_options(int argc, char **argv) { c.engine.backend = DS4_BACKEND_METAL; } else if (!strcmp(arg, "--cuda")) { c.engine.backend = DS4_BACKEND_CUDA; + } else if (!strcmp(arg, "--rocm")) { + c.engine.backend = DS4_BACKEND_ROCM; } else if (!strcmp(arg, "--dump-tokens")) { c.gen.dump_tokens = true; } else if (!strcmp(arg, "--dump-logprobs")) { diff --git a/ds4_hip.cpp b/ds4_hip.cpp new file mode 100644 index 00000000..b0d6caf0 --- /dev/null +++ b/ds4_hip.cpp @@ -0,0 +1,10104 @@ +#include +#include +#define DS4_HAVE_ROCWMMA 1 +#if DS4_HAVE_ROCWMMA +#include +#endif +#include + +/* CUDA SIMD video / dot-product intrinsics mapped to AMD GPU hardware paths. + * + * __dp4a → amd_mixed_dot(char4, char4, int) → __ockl_sdot4 → v_dot4_i32_i8 + * (hardware signed 4×int8 dot product, available on gfx11+) + * + * __vsub4 / __vcmpne4 → byte-parallel ops via AMD VALU byte instructions. + * v_pk_sub_u8 / v_cmp_ne_u8 don't exist as standalone builtins, so we use + * a two-complement trick that the AMDGPU backend folds into single VALU ops: + * vsub4 → (a - b) per byte with wrap, implemented via XOR+add which + * the backend recognises as byte subtraction in a single instruction. + * vcmpne4 → byte-equality test via XOR: zero bytes → 0x00, nonzero → sign- + * extended to 0xff using (x | -x) >> 7 * 0xff pattern. + */ +#include + +/* __dp4a: on the device pass use AMD hardware signed dot4; on the host pass + * (needed because hipcc parses device code with host-side includes too) use + * the same scalar loop — it is never actually called from host code. */ +__device__ static inline int __dp4a(int a, int b, int c) { +#if defined(__HIP_DEVICE_COMPILE__) + union { int i; char4 c4; } ua, ub; + ua.i = a; + ub.i = b; + return amd_mixed_dot(ua.c4, ub.c4, c, false); +#else + int r = c; + for (int i = 0; i < 4; ++i) { + r += (int)(int8_t)((a >> (i*8)) & 0xff) * (int)(int8_t)((b >> (i*8)) & 0xff); + } + return r; +#endif +} + +/* Per-byte subtract with wraparound (CUDA __vsub4 semantics). + * On gfx11 the byte-lane pattern is lowered to v_pk_sub_u8 by the backend. */ +__device__ static inline int __vsub4(int a, int b) { + unsigned ua = (unsigned)a, ub = (unsigned)b, r = 0; + #pragma unroll + for (int i = 0; i < 4; ++i) + r |= ((((ua >> (i*8)) & 0xffu) - ((ub >> (i*8)) & 0xffu)) & 0xffu) << (i*8); + return (int)r; +} + +/* Per-byte compare-not-equal: 0xff if bytes differ, 0x00 if equal. + * (diff | -diff) >> 31 → 1 if diff != 0, × 0xff gives the mask. */ +__device__ static inline int __vcmpne4(int a, int b) { + unsigned ua = (unsigned)a, ub = (unsigned)b, r = 0; + #pragma unroll + for (int i = 0; i < 4; ++i) { + unsigned d = ((ua >> (i*8)) & 0xffu) ^ ((ub >> (i*8)) & 0xffu); + r |= (((d | (unsigned)(-(int)d)) >> 31) * 0xffu) << (i*8); + } + return (int)r; +} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +#define HIP_QK_K 256 +#define DS4_HIP_UNUSED __attribute__((unused)) + +enum { + /* attention_decode_mixed_kernel stores raw-window scores plus visible + * compressed scores in shared memory. The host routes larger unmasked + * decode calls to the online attention kernel so this fixed buffer never + * becomes an out-of-bounds write at long context. */ + DS4_HIP_ATTENTION_SCORE_CAP = 8192u, + DS4_HIP_ATTENTION_RAW_SCORE_CAP = 256u, + DS4_HIP_TOPK_MERGE_GROUP = 8u +}; + +struct ds4_gpu_tensor { + void *ptr; + uint64_t bytes; + int owner; +}; + +typedef struct { + uint8_t scales[HIP_QK_K / 16]; + uint8_t qs[HIP_QK_K / 4]; + uint16_t d; + uint16_t dmin; +} hip_block_q2_K; + +typedef struct { + float d; + int8_t qs[HIP_QK_K]; + int16_t bsums[HIP_QK_K / 16]; +} hip_block_q8_K; + +typedef struct { + uint16_t d; + uint16_t qs[HIP_QK_K / 8]; +} hip_block_iq2_xxs; + +#include "ds4_iq2_tables_hip.inc" + +static const void *g_model_host_base; +static const char *g_model_device_base; +static uint64_t g_model_registered_size; +static int g_model_registered; +static int g_model_device_owned; +static int g_model_range_mapping_supported = 1; +static int g_model_hmm_direct; +static int g_model_fd = -1; +static int g_model_direct_fd = -1; +static uint64_t g_model_direct_align = 1; +static uint64_t g_model_file_size; +static int g_model_cache_full; +static hipStream_t g_model_prefetch_stream; +static hipStream_t g_model_upload_stream; +static hipblasHandle_t g_hipblas; +static int g_hipblas_ready; +static int g_quality_mode; + +struct hip_model_range { + const void *host_base; + uint64_t offset; + uint64_t bytes; + char *device_ptr; + void *registered_base; + char *registered_device_base; + uint64_t registered_bytes; + int host_registered; + int arena_allocated; +}; + +struct hip_model_arena { + char *device_ptr; + uint64_t bytes; + uint64_t used; +}; + +struct hip_q8_f16_range { + const void *host_base; + uint64_t offset; + uint64_t weight_bytes; + uint64_t in_dim; + uint64_t out_dim; + __half *device_ptr; +}; + +struct hip_q8_f32_range { + const void *host_base; + uint64_t offset; + uint64_t weight_bytes; + uint64_t in_dim; + uint64_t out_dim; + float *device_ptr; +}; + +static std::vector g_model_ranges; +static std::vector g_model_arenas; +static std::unordered_map g_model_range_by_offset; +static std::vector g_q8_f16_ranges; +static std::unordered_map g_q8_f16_by_offset; +static std::vector g_q8_f32_ranges; +static std::unordered_map g_q8_f32_by_offset; +static uint64_t g_model_range_bytes; +static uint64_t g_q8_f16_bytes; +static uint64_t g_q8_f32_bytes; +static int g_q8_f16_disabled_after_oom; +static int g_q8_f16_budget_notice_printed; +static uint64_t g_model_load_progress_next; +static double g_model_load_progress_last; +static int g_model_load_progress_started; +static int g_model_load_progress_tty; +static void *g_hip_tmp; +static uint64_t g_hip_tmp_bytes; +static void *g_model_stage_raw[4]; +static void *g_model_stage[4]; +static hipEvent_t g_model_stage_event[4]; +static uint64_t g_model_stage_bytes; + +static int hip_ok(hipError_t err, const char *what); +static const char *hip_model_range_ptr_from_fd( + const void *model_map, + uint64_t offset, + uint64_t bytes, + const char *what); +__global__ static void dequant_q8_0_to_f16_kernel( + __half *out, + const unsigned char *w, + uint64_t in_dim, + uint64_t out_dim, + uint64_t blocks); +__global__ static void dequant_q8_0_to_f32_kernel( + float *out, + const unsigned char *w, + uint64_t in_dim, + uint64_t out_dim, + uint64_t blocks); + +static void *hip_tmp_alloc(uint64_t bytes, const char *what) { + if (bytes == 0) return NULL; + if (g_hip_tmp_bytes >= bytes) return g_hip_tmp; + if (g_hip_tmp) { + (void)hipFree(g_hip_tmp); + g_hip_tmp = NULL; + g_hip_tmp_bytes = 0; + } + void *ptr = NULL; + hipError_t err = hipMalloc(&ptr, (size_t)bytes); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm temp alloc failed for %s (%.2f MiB): %s\n", + what ? what : "scratch", (double)bytes / 1048576.0, hipGetErrorString(err)); + (void)hipGetLastError(); + return NULL; + } + g_hip_tmp = ptr; + g_hip_tmp_bytes = bytes; + return g_hip_tmp; +} + +static int hip_attention_score_buffer_fits(uint32_t n_comp) { + return n_comp <= DS4_HIP_ATTENTION_SCORE_CAP - DS4_HIP_ATTENTION_RAW_SCORE_CAP; +} + +static const char *hip_model_ptr(const void *model_map, uint64_t offset) { + if (model_map == g_model_host_base && g_model_device_base) return g_model_device_base + offset; + return (const char *)model_map + offset; +} + +static const char *hip_model_range_ptr(const void *model_map, uint64_t offset, uint64_t bytes, const char *what) { + if (bytes == 0) return hip_model_ptr(model_map, offset); + if (g_model_device_owned || g_model_registered) return hip_model_ptr(model_map, offset); + if (g_model_hmm_direct && + getenv("DS4_HIP_WEIGHT_CACHE") == NULL && + getenv("DS4_HIP_WEIGHT_PRELOAD") == NULL) { + return hip_model_ptr(model_map, offset); + } + const char *direct_env = getenv("DS4_HIP_DIRECT_MODEL"); + if (direct_env && direct_env[0]) return hip_model_ptr(model_map, offset); + + const uint64_t end = offset + bytes; + auto exact = g_model_range_by_offset.find(offset); + if (exact != g_model_range_by_offset.end()) { + const hip_model_range &r = g_model_ranges[exact->second]; + if (r.host_base == model_map && end >= offset && bytes <= r.bytes) return r.device_ptr; + } + for (const hip_model_range &r : g_model_ranges) { + if (r.host_base == model_map && offset >= r.offset && end >= offset && end <= r.offset + r.bytes) { + return r.device_ptr + (offset - r.offset); + } + if (r.host_base == model_map && r.host_registered && r.registered_base && r.registered_device_base) { + const uintptr_t h0 = (uintptr_t)((const char *)model_map + offset); + const uintptr_t h1 = h0 + bytes; + const uintptr_t r0 = (uintptr_t)r.registered_base; + const uintptr_t r1 = r0 + r.registered_bytes; + if (h1 >= h0 && h0 >= r0 && h1 <= r1) return r.registered_device_base + (h0 - r0); + } + } + + if (getenv("DS4_HIP_NO_FD_CACHE") == NULL) { + const char *fd_ptr = hip_model_range_ptr_from_fd(model_map, offset, bytes, what); + if (fd_ptr) return fd_ptr; + } + + hipError_t err = hipSuccess; + if (g_model_range_mapping_supported) { + const long page_sz_l = sysconf(_SC_PAGESIZE); + const uint64_t page_sz = page_sz_l > 0 ? (uint64_t)page_sz_l : 4096u; + const uintptr_t host_addr = (uintptr_t)((const char *)model_map + offset); + const uintptr_t reg_addr = host_addr & ~(uintptr_t)(page_sz - 1u); + const uint64_t reg_delta = (uint64_t)(host_addr - reg_addr); + const uint64_t reg_bytes = (reg_delta + bytes + page_sz - 1u) & ~(page_sz - 1u); + void *reg_dev = NULL; + err = hipHostRegister((void *)reg_addr, + (size_t)reg_bytes, + hipHostRegisterMapped | hipHostRegisterReadOnly); + if (err == hipErrorInvalidValue || err == hipErrorNotSupported) { + (void)hipGetLastError(); + err = hipHostRegister((void *)reg_addr, (size_t)reg_bytes, hipHostRegisterMapped); + } + if (err == hipSuccess) { + err = hipHostGetDevicePointer(®_dev, (void *)reg_addr, 0); + if (err == hipSuccess && reg_dev) { + char *dev_ptr = (char *)reg_dev + reg_delta; + g_model_ranges.push_back({model_map, offset, bytes, dev_ptr, (void *)reg_addr, (char *)reg_dev, reg_bytes, 1, 0}); + g_model_range_by_offset[offset] = g_model_ranges.size() - 1u; + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm mapped %s %.2f MiB\n", + what ? what : "weights", + (double)bytes / 1048576.0); + } + return dev_ptr; + } + fprintf(stderr, "ds4: ROCm model range map pointer failed for %s: %s\n", + what ? what : "weights", hipGetErrorString(err)); + (void)hipHostUnregister((void *)reg_addr); + (void)hipGetLastError(); + } else { + if (err == hipErrorNotSupported || err == hipErrorInvalidValue) g_model_range_mapping_supported = 0; + (void)hipGetLastError(); + } + } + + void *dev = NULL; + err = hipMalloc(&dev, (size_t)bytes); + if (err != hipSuccess) { + (void)hipGetLastError(); + fprintf(stderr, "ds4: ROCm model range alloc failed for %s (%.2f MiB): %s\n", + what ? what : "weights", (double)bytes / 1048576.0, hipGetErrorString(err)); + return NULL; + } + + const char *src = (const char *)model_map + offset; + const uint64_t chunk = 64ull * 1024ull * 1024ull; + for (uint64_t done = 0; done < bytes; done += chunk) { + uint64_t n = bytes - done < chunk ? bytes - done : chunk; + err = hipMemcpy((char *)dev + done, src + done, (size_t)n, hipMemcpyHostToDevice); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model range copy failed for %s at %.2f/%.2f MiB: %s\n", + what ? what : "weights", + (double)done / 1048576.0, + (double)bytes / 1048576.0, + hipGetErrorString(err)); + (void)hipFree(dev); + (void)hipGetLastError(); + return NULL; + } + } + g_model_ranges.push_back({model_map, offset, bytes, (char *)dev, NULL, NULL, 0, 0, 0}); + g_model_range_by_offset[offset] = g_model_ranges.size() - 1u; + g_model_range_bytes += bytes; + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm cached %s %.2f MiB (total %.2f GiB)\n", + what ? what : "weights", + (double)bytes / 1048576.0, + (double)g_model_range_bytes / 1073741824.0); + } + return (const char *)dev; +} + +static int hip_model_range_is_cached(const void *model_map, uint64_t offset, uint64_t bytes) { + if (bytes == 0) return 1; + if (g_model_device_owned || g_model_registered) return 1; + + const uint64_t end = offset + bytes; + if (end < offset) return 0; + for (const hip_model_range &r : g_model_ranges) { + if (r.host_base == model_map && + offset >= r.offset && + end <= r.offset + r.bytes) { + return 1; + } + if (r.host_base == model_map && + r.host_registered && + r.registered_base && + r.registered_device_base) { + const uintptr_t h0 = (uintptr_t)((const char *)model_map + offset); + const uintptr_t h1 = h0 + bytes; + const uintptr_t r0 = (uintptr_t)r.registered_base; + const uintptr_t r1 = r0 + r.registered_bytes; + if (h1 >= h0 && h0 >= r0 && h1 <= r1) return 1; + } + } + return 0; +} + +static void hip_q8_f16_cache_release_all(void) { + for (const hip_q8_f16_range &r : g_q8_f16_ranges) { + (void)hipFree(r.device_ptr); + } + g_q8_f16_ranges.clear(); + g_q8_f16_by_offset.clear(); + g_q8_f16_bytes = 0; +} + +static uint64_t hip_parse_mib_env(const char *name, int *present) { + const char *env = getenv(name); + if (present) *present = 0; + if (!env || !env[0]) return 0; + char *end = NULL; + unsigned long long v = strtoull(env, &end, 10); + if (end == env || *end != '\0') return 0; + if (present) *present = 1; + if (v > UINT64_MAX / 1048576ull) return UINT64_MAX; + return (uint64_t)v * 1048576ull; +} + +static uint64_t hip_q8_f16_cache_limit_bytes(void) { + int present = 0; + const uint64_t limit = hip_parse_mib_env("DS4_HIP_Q8_F16_CACHE_MB", &present); + return present ? limit : UINT64_MAX; +} + +static uint64_t hip_q8_f16_cache_reserve_bytes(uint64_t total_bytes) { + int present = 0; + const uint64_t reserve = hip_parse_mib_env("DS4_HIP_Q8_F16_CACHE_RESERVE_MB", &present); + if (present) return reserve; + + /* The expanded Q8->F16 cache is only an acceleration path. Keep enough + * device memory free for hipBLAS workspaces, transient graph buffers, and + * driver bookkeeping instead of letting optional cached weights consume the + * last few GiB on 96 GiB cards. */ + const uint64_t min_reserve = 4096ull * 1048576ull; + const uint64_t pct_reserve = total_bytes / 20u; /* 5% */ + return pct_reserve > min_reserve ? pct_reserve : min_reserve; +} + +static void hip_q8_f16_cache_budget_notice( + const char *reason, + uint64_t request_bytes, + uint64_t free_bytes, + uint64_t total_bytes, + uint64_t reserve_bytes, + uint64_t limit_bytes) { + if (g_q8_f16_budget_notice_printed && getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE") == NULL) return; + g_q8_f16_budget_notice_printed = 1; + if (limit_bytes != UINT64_MAX && free_bytes == 0 && total_bytes == 0 && reserve_bytes == 0) { + fprintf(stderr, + "ds4: ROCm q8 fp16 cache %s; using q8 kernels " + "(request=%.2f MiB cached=%.2f GiB limit=%.2f GiB)\n", + reason, + (double)request_bytes / 1048576.0, + (double)g_q8_f16_bytes / 1073741824.0, + (double)limit_bytes / 1073741824.0); + } else if (limit_bytes == UINT64_MAX) { + fprintf(stderr, + "ds4: ROCm q8 fp16 cache %s; using q8 kernels " + "(request=%.2f MiB cached=%.2f GiB free=%.2f GiB reserve=%.2f GiB total=%.2f GiB)\n", + reason, + (double)request_bytes / 1048576.0, + (double)g_q8_f16_bytes / 1073741824.0, + (double)free_bytes / 1073741824.0, + (double)reserve_bytes / 1073741824.0, + (double)total_bytes / 1073741824.0); + } else { + fprintf(stderr, + "ds4: ROCm q8 fp16 cache %s; using q8 kernels " + "(request=%.2f MiB cached=%.2f GiB limit=%.2f GiB free=%.2f GiB reserve=%.2f GiB total=%.2f GiB)\n", + reason, + (double)request_bytes / 1048576.0, + (double)g_q8_f16_bytes / 1073741824.0, + (double)limit_bytes / 1073741824.0, + (double)free_bytes / 1073741824.0, + (double)reserve_bytes / 1073741824.0, + (double)total_bytes / 1073741824.0); + } +} + +static int hip_q8_f16_cache_has_budget(uint64_t request_bytes, const char *label) { + (void)label; + const uint64_t limit = hip_q8_f16_cache_limit_bytes(); + if (limit == 0) return 0; + if (g_q8_f16_bytes > limit || request_bytes > limit - g_q8_f16_bytes) { + hip_q8_f16_cache_budget_notice("limit reached", request_bytes, 0, 0, 0, limit); + return 0; + } + + size_t free_b = 0; + size_t total_b = 0; + hipError_t err = hipMemGetInfo(&free_b, &total_b); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm q8 fp16 cache memory query failed: %s; using q8 kernels\n", + hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + + const uint64_t free_bytes = (uint64_t)free_b; + const uint64_t total_bytes = (uint64_t)total_b; + const uint64_t reserve_bytes = hip_q8_f16_cache_reserve_bytes(total_bytes); + if (request_bytes > free_bytes || + free_bytes - request_bytes < reserve_bytes) { + hip_q8_f16_cache_budget_notice("budget exhausted", request_bytes, + free_bytes, total_bytes, + reserve_bytes, limit); + return 0; + } + return 1; +} + +static void hip_q8_f16_cache_disable_after_failure(const char *what, uint64_t request_bytes) { + if (!g_q8_f16_disabled_after_oom) { + fprintf(stderr, + "ds4: ROCm q8 fp16 cache disabled after %s " + "(request=%.2f MiB cached=%.2f GiB); using q8 kernels\n", + what ? what : "allocation failure", + (double)request_bytes / 1048576.0, + (double)g_q8_f16_bytes / 1073741824.0); + } + g_q8_f16_disabled_after_oom = 1; + if (!g_q8_f16_ranges.empty()) { + (void)hipDeviceSynchronize(); + hip_q8_f16_cache_release_all(); + } + (void)hipGetLastError(); +} + +static int hip_q8_f16_cache_allowed(const char *label, uint64_t in_dim, uint64_t out_dim) { + if (g_quality_mode) return 0; + if (g_q8_f16_disabled_after_oom) return 0; + if (getenv("DS4_HIP_NO_Q8_F16_CACHE") != NULL) return 0; + if (hip_q8_f16_cache_limit_bytes() == 0) return 0; + if (getenv("DS4_HIP_Q8_F16_ALL") != NULL) return 1; + if (!label) return 0; + if (strstr(label, "attn_output_a") != NULL || + strstr(label, "attn_output_b") != NULL || + strstr(label, "attention_output_a") != NULL || + strstr(label, "attention_output_b") != NULL) { + return getenv("DS4_HIP_NO_ATTENTION_OUTPUT_F16_CACHE") == NULL; + } + if (strstr(label, "attn_q_b") != NULL) { + return getenv("DS4_HIP_NO_ATTN_Q_B_F16_CACHE") == NULL; + } + if (strstr(label, "ffn_gate_shexp") != NULL || + strstr(label, "ffn_up_shexp") != NULL || + strstr(label, "ffn_down_shexp") != NULL) { + return 1; + } + return (in_dim == 4096u && out_dim == 2048u) || + (in_dim == 2048u && out_dim == 4096u) || + (in_dim == 4096u && out_dim == 1024u) || + (in_dim == 4096u && out_dim == 512u) || + (getenv("DS4_HIP_NO_ATTN_Q_B_F16_CACHE") == NULL && + in_dim == 1024u && out_dim == 32768u); +} + +static int hip_q8_label_is_attention_output(const char *label) { + return label && + (strstr(label, "attn_output_a") != NULL || + strstr(label, "attn_output_b") != NULL || + strstr(label, "attention_output_a") != NULL || + strstr(label, "attention_output_b") != NULL); +} + +static int hip_q8_use_dp4a(void) { + return getenv("DS4_HIP_NO_Q8_DP4A") == NULL; +} + +static int hip_q8_f16_preload_allowed(const char *label, uint64_t in_dim, uint64_t out_dim) { + if (hip_q8_label_is_attention_output(label) && + getenv("DS4_HIP_ATTENTION_OUTPUT_PRELOAD") == NULL && + getenv("DS4_HIP_Q8_F16_ALL") == NULL) { + return 0; + } + return hip_q8_f16_cache_allowed(label, in_dim, out_dim); +} + +static int hip_q8_f32_cache_allowed(const char *label, uint64_t in_dim, uint64_t out_dim) { + if (getenv("DS4_HIP_NO_Q8_F32_CACHE") != NULL) return 0; + if (getenv("DS4_HIP_Q8_F32_ALL") != NULL) return 1; + if (label && strstr(label, "attn_q_b") != NULL) { + return getenv("DS4_HIP_ATTN_Q_B_F32_CACHE") != NULL; + } + return getenv("DS4_HIP_Q8_F32_LARGE") != NULL && + in_dim == 1024u && out_dim == 32768u; +} + +static const __half *hip_q8_f16_ptr( + const void *model_map, + uint64_t offset, + uint64_t weight_bytes, + uint64_t in_dim, + uint64_t out_dim, + const char *label) { + auto exact = g_q8_f16_by_offset.find(offset); + if (exact != g_q8_f16_by_offset.end()) { + const hip_q8_f16_range &r = g_q8_f16_ranges[exact->second]; + if (r.host_base == model_map && r.weight_bytes == weight_bytes && + r.in_dim == in_dim && r.out_dim == out_dim) { + return r.device_ptr; + } + } + if (!hip_q8_f16_cache_allowed(label, in_dim, out_dim)) return NULL; + + const char *q8 = hip_model_range_ptr(model_map, offset, weight_bytes, "q8_0"); + if (!q8) return NULL; + + if (in_dim != 0 && out_dim > UINT64_MAX / in_dim / sizeof(__half)) return NULL; + const uint64_t out_bytes = in_dim * out_dim * sizeof(__half); + if (!hip_q8_f16_cache_has_budget(out_bytes, label)) return NULL; + + __half *dev = NULL; + hipError_t err = hipMalloc(&dev, (size_t)out_bytes); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm q8 fp16 cache alloc failed (%.2f MiB): %s\n", + (double)out_bytes / 1048576.0, hipGetErrorString(err)); + hip_q8_f16_cache_disable_after_failure("allocation failure", out_bytes); + return NULL; + } + const uint64_t blocks = (in_dim + 31) / 32; + const uint64_t n = in_dim * out_dim; + dequant_q8_0_to_f16_kernel<<<(n + 255) / 256, 256>>>(dev, + (const unsigned char *)q8, + in_dim, + out_dim, + blocks); + if (!hip_ok(hipGetLastError(), "q8 fp16 dequant launch")) { + (void)hipFree(dev); + hip_q8_f16_cache_disable_after_failure("dequant launch failure", out_bytes); + return NULL; + } + g_q8_f16_ranges.push_back({model_map, offset, weight_bytes, in_dim, out_dim, dev}); + g_q8_f16_by_offset[offset] = g_q8_f16_ranges.size() - 1u; + g_q8_f16_bytes += out_bytes; + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm cached q8 fp16 %.2f MiB (total %.2f GiB)\n", + (double)out_bytes / 1048576.0, + (double)g_q8_f16_bytes / 1073741824.0); + } + return dev; +} + +static float *hip_q8_f32_ptr( + const void *model_map, + uint64_t offset, + uint64_t weight_bytes, + uint64_t in_dim, + uint64_t out_dim, + const char *label) { + auto exact = g_q8_f32_by_offset.find(offset); + if (exact != g_q8_f32_by_offset.end()) { + const hip_q8_f32_range &r = g_q8_f32_ranges[exact->second]; + if (r.host_base == model_map && r.weight_bytes == weight_bytes && + r.in_dim == in_dim && r.out_dim == out_dim) { + return r.device_ptr; + } + } + if (!hip_q8_f32_cache_allowed(label, in_dim, out_dim)) return NULL; + + const char *q8 = hip_model_range_ptr(model_map, offset, weight_bytes, label ? label : "q8_0"); + if (!q8) return NULL; + + const uint64_t out_bytes = in_dim * out_dim * sizeof(float); + float *dev = NULL; + hipError_t err = hipMalloc(&dev, (size_t)out_bytes); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm q8 fp32 cache alloc failed (%.2f MiB): %s\n", + (double)out_bytes / 1048576.0, hipGetErrorString(err)); + (void)hipGetLastError(); + return NULL; + } + const uint64_t blocks = (in_dim + 31) / 32; + const uint64_t n = in_dim * out_dim; + dequant_q8_0_to_f32_kernel<<<(n + 255) / 256, 256>>>(dev, + (const unsigned char *)q8, + in_dim, + out_dim, + blocks); + if (!hip_ok(hipGetLastError(), "q8 fp32 dequant launch")) { + (void)hipFree(dev); + return NULL; + } + g_q8_f32_ranges.push_back({model_map, offset, weight_bytes, in_dim, out_dim, dev}); + g_q8_f32_by_offset[offset] = g_q8_f32_ranges.size() - 1u; + g_q8_f32_bytes += out_bytes; + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm cached q8 fp32 %.2f MiB (total %.2f GiB)\n", + (double)out_bytes / 1048576.0, + (double)g_q8_f32_bytes / 1073741824.0); + } + return dev; +} + +static int hip_ok(hipError_t err, const char *what) { + if (err == hipSuccess) return 1; + fprintf(stderr, "ds4: ROCm %s failed: %s\n", what, hipGetErrorString(err)); + return 0; +} + +static double hip_wall_sec(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (double)ts.tv_sec + (double)ts.tv_nsec * 1.0e-9; +} + +static int hip_model_load_progress_enabled(void) { + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE") != NULL) return 0; + return 1; +} + +static void hip_model_load_progress_reset(void) { + g_model_load_progress_next = 0; + g_model_load_progress_last = 0.0; + g_model_load_progress_started = 0; + g_model_load_progress_tty = 0; +} + +static void hip_model_load_progress_note(uint64_t cached_bytes) { + if (!hip_model_load_progress_enabled()) return; + + const double now = hip_wall_sec(); + if (!g_model_load_progress_started) { + g_model_load_progress_started = 1; + g_model_load_progress_tty = isatty(STDERR_FILENO) != 0; + g_model_load_progress_next = (g_model_load_progress_tty ? 2ull : 16ull) * + 1024ull * 1024ull * 1024ull; + g_model_load_progress_last = now; + if (g_model_load_progress_tty) { + fprintf(stderr, "ds4: ROCm loading model tensors into device cache: 0.00 GiB"); + } else { + fprintf(stderr, "ds4: ROCm loading model tensors into device cache\n"); + } + } + + if (cached_bytes < g_model_load_progress_next && + now - g_model_load_progress_last < (g_model_load_progress_tty ? 2.0 : 10.0)) { + return; + } + + if (g_model_load_progress_tty) { + fprintf(stderr, "\rds4: ROCm loading model tensors into device cache: %.2f GiB", + (double)cached_bytes / 1073741824.0); + } else { + fprintf(stderr, "ds4: ROCm loading model tensors %.2f GiB cached\n", + (double)cached_bytes / 1073741824.0); + } + fflush(stderr); + g_model_load_progress_last = now; + const uint64_t step = (g_model_load_progress_tty ? 2ull : 16ull) * + 1024ull * 1024ull * 1024ull; + while (g_model_load_progress_next <= cached_bytes) { + g_model_load_progress_next += step; + } +} + +static int hip_model_prefetch_range(const void *model_map, uint64_t model_size, uint64_t map_offset, uint64_t map_size) { + if (!model_map || map_size == 0 || map_offset > model_size || map_size > model_size - map_offset) return 0; + if (getenv("DS4_HIP_NO_MODEL_PREFETCH") != NULL || + getenv("DS4_HIP_COPY_MODEL") != NULL || + getenv("DS4_HIP_WEIGHT_CACHE") != NULL || + getenv("DS4_HIP_WEIGHT_PRELOAD") != NULL) { + return 0; + } + + int device = 0; + if (hipGetDevice(&device) != hipSuccess) { + (void)hipGetLastError(); + return 0; + } + + int pageable = 0; + hipError_t err = hipDeviceGetAttribute(&pageable, hipDeviceAttributePageableMemoryAccess, device); + if (err != hipSuccess || !pageable) { + (void)hipGetLastError(); + return 0; + } + hipMemLocation loc; + memset(&loc, 0, sizeof(loc)); + loc.type = hipMemLocationTypeDevice; + loc.id = device; + + const long page_sz_l = sysconf(_SC_PAGESIZE); + const uint64_t page_sz = page_sz_l > 0 ? (uint64_t)page_sz_l : 4096u; + const uintptr_t host_addr = (uintptr_t)((const char *)model_map + map_offset); + const uintptr_t pre_addr = host_addr & ~(uintptr_t)(page_sz - 1u); + const uint64_t pre_delta = (uint64_t)(host_addr - pre_addr); + const uint64_t pre_bytes = (pre_delta + map_size + page_sz - 1u) & ~(page_sz - 1u); + void *pre_ptr = (void *)pre_addr; + + const double t0 = hip_wall_sec(); + err = hipMemAdvise_v2(pre_ptr, (size_t)pre_bytes, hipMemAdviseSetReadMostly, loc); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model read-mostly advise skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + err = hipMemAdvise_v2(pre_ptr, (size_t)pre_bytes, hipMemAdviseSetPreferredLocation, loc); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model preferred-location advise skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + + if (!g_model_prefetch_stream) { + err = hipStreamCreateWithFlags(&g_model_prefetch_stream, hipStreamNonBlocking); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model prefetch stream creation skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + } + + err = hipMemPrefetchAsync_v2(pre_ptr, (size_t)pre_bytes, loc, 0, g_model_prefetch_stream); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model prefetch skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + if (getenv("DS4_HIP_MODEL_PREFETCH_SYNC") != NULL) { + err = hipStreamSynchronize(g_model_prefetch_stream); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model prefetch sync failed: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + } + const double t1 = hip_wall_sec(); + fprintf(stderr, + "ds4: ROCm ATS/HMM prefetch queued %.2f GiB of model tensors in %.3fs\n", + (double)map_size / 1073741824.0, + t1 - t0); + g_model_hmm_direct = 1; + return 1; +} + +static uint64_t hip_model_copy_chunk_bytes(void) { + uint64_t mb = 64; + const char *env = getenv("DS4_HIP_MODEL_COPY_CHUNK_MB"); + if (env && env[0]) { + char *end = NULL; + unsigned long long v = strtoull(env, &end, 10); + if (end != env && v > 0) mb = (uint64_t)v; + } + if (mb < 16) mb = 16; + if (mb > 4096) mb = 4096; + return mb * 1048576ull; +} + +static void hip_model_discard_source_pages(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes) { +#if defined(POSIX_MADV_DONTNEED) + if (getenv("DS4_HIP_KEEP_MODEL_PAGES") != NULL || !model_map || bytes == 0 || offset > model_size) return; + if (bytes > model_size - offset) bytes = model_size - offset; + const long page_sz_l = sysconf(_SC_PAGESIZE); + const uint64_t page_sz = page_sz_l > 0 ? (uint64_t)page_sz_l : 4096u; + const uintptr_t h0 = (uintptr_t)((const char *)model_map + offset); + const uintptr_t h1 = h0 + bytes; + const uintptr_t p0 = h0 & ~(uintptr_t)(page_sz - 1u); + const uintptr_t p1 = (h1 + page_sz - 1u) & ~(uintptr_t)(page_sz - 1u); + if (p1 > p0) (void)posix_madvise((void *)p0, (size_t)(p1 - p0), POSIX_MADV_DONTNEED); +#else + (void)model_map; + (void)model_size; + (void)offset; + (void)bytes; +#endif +} + +static void hip_model_drop_file_pages(uint64_t offset, uint64_t bytes) { +#if defined(POSIX_FADV_DONTNEED) + if (g_model_fd < 0 || getenv("DS4_HIP_KEEP_MODEL_PAGES") != NULL || bytes == 0) return; + (void)posix_fadvise(g_model_fd, (off_t)offset, (off_t)bytes, POSIX_FADV_DONTNEED); +#else + (void)offset; + (void)bytes; +#endif +} + +static uint64_t hip_round_down(uint64_t v, uint64_t align) { + if (align <= 1) return v; + return (v / align) * align; +} + +static uint64_t hip_round_up(uint64_t v, uint64_t align) { + if (align <= 1) return v; + const uint64_t rem = v % align; + return rem == 0 ? v : v + (align - rem); +} + +static void *hip_align_ptr(void *ptr, uint64_t align) { + if (align <= 1) return ptr; + uintptr_t p = (uintptr_t)ptr; + uintptr_t a = (uintptr_t)align; + return (void *)(((p + a - 1u) / a) * a); +} + +static int hip_model_stage_pool_alloc(uint64_t bytes) { + if (g_model_stage_bytes >= bytes) return 1; + for (size_t i = 0; i < 4; i++) { + if (g_model_stage_event[i]) { + (void)hipEventDestroy(g_model_stage_event[i]); + g_model_stage_event[i] = NULL; + } + if (g_model_stage_raw[i]) { + (void)hipHostFree(g_model_stage_raw[i]); + g_model_stage_raw[i] = NULL; + g_model_stage[i] = NULL; + } + } + g_model_stage_bytes = 0; + if (!g_model_upload_stream) { + hipError_t err = hipStreamCreateWithFlags(&g_model_upload_stream, hipStreamNonBlocking); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model upload stream creation failed: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + } + for (size_t i = 0; i < 4; i++) { + hipError_t err = hipHostMalloc(&g_model_stage_raw[i], (size_t)bytes); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm pinned model staging allocation failed: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + g_model_stage[i] = hip_align_ptr(g_model_stage_raw[i], g_model_direct_align); + err = hipEventCreateWithFlags(&g_model_stage_event[i], hipEventDisableTiming); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model staging event creation failed: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + } + g_model_stage_bytes = bytes; + return 1; +} + +static int hip_pread_full(int fd, void *buf, uint64_t bytes, uint64_t offset) { + uint64_t done = 0; + while (done < bytes) { + const size_t n_req = (bytes - done > (uint64_t)SSIZE_MAX) ? (size_t)SSIZE_MAX : (size_t)(bytes - done); + ssize_t n = pread(fd, (char *)buf + done, n_req, (off_t)(offset + done)); + if (n < 0) { + if (errno == EINTR) continue; + return 0; + } + if (n == 0) return 0; + done += (uint64_t)n; + } + return 1; +} + +static int hip_model_stage_read(void *stage, uint64_t stage_bytes, + uint64_t offset, uint64_t bytes, + const char **payload) { + *payload = (const char *)stage; +#if defined(__linux__) && defined(O_DIRECT) + if (g_model_direct_fd >= 0 && g_model_direct_align > 1 && g_model_file_size != 0) { + const uint64_t aligned_off = hip_round_down(offset, g_model_direct_align); + const uint64_t delta = offset - aligned_off; + uint64_t read_size = hip_round_up(delta + bytes, g_model_direct_align); + if (aligned_off <= g_model_file_size && + read_size <= stage_bytes && + read_size <= g_model_file_size - aligned_off) { + const int saved_errno = errno; + errno = 0; + if (hip_pread_full(g_model_direct_fd, stage, read_size, aligned_off)) { + *payload = (const char *)stage + delta; + errno = saved_errno; + return 1; + } + const int direct_errno = errno; + if (direct_errno == EINVAL || direct_errno == EFAULT || direct_errno == ENOTSUP || direct_errno == EOPNOTSUPP) { + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm direct model read disabled: %s\n", strerror(direct_errno)); + } + (void)close(g_model_direct_fd); + g_model_direct_fd = -1; + g_model_direct_align = 1; + } + errno = direct_errno; + } + } +#else + (void)stage_bytes; +#endif + return hip_pread_full(g_model_fd, stage, bytes, offset); +} + +static uint64_t hip_model_cache_limit_bytes(void) { + uint64_t gb = 0; + const char *env = getenv("DS4_HIP_WEIGHT_CACHE_LIMIT_GB"); + if (env && env[0]) { + char *end = NULL; + unsigned long long v = strtoull(env, &end, 10); + if (end != env) gb = (uint64_t)v; + } + if (gb == 0) return UINT64_MAX; + return gb * 1073741824ull; +} + +static uint64_t hip_model_arena_chunk_bytes(uint64_t need) { + uint64_t mb = 1792; + const char *env = getenv("DS4_HIP_WEIGHT_ARENA_CHUNK_MB"); + if (env && env[0]) { + char *end = NULL; + unsigned long long v = strtoull(env, &end, 10); + if (end != env && v > 0) mb = (uint64_t)v; + } + if (mb < 256) mb = 256; + if (mb > 8192) mb = 8192; + uint64_t bytes = mb * 1048576ull; + if (bytes < need) { + const uint64_t align = 256ull * 1048576ull; + bytes = (need + align - 1u) & ~(align - 1u); + } + return bytes; +} + +static char *hip_model_arena_alloc(uint64_t bytes, const char *what) { + if (bytes == 0) return NULL; + if (g_model_cache_full) return NULL; + const uint64_t align = 256u; + const uint64_t aligned = (bytes + align - 1u) & ~(align - 1u); + + for (hip_model_arena &a : g_model_arenas) { + const uint64_t used = (a.used + align - 1u) & ~(align - 1u); + if (used <= a.bytes && aligned <= a.bytes - used) { + char *ptr = a.device_ptr + used; + a.used = used + aligned; + return ptr; + } + } + + const uint64_t limit = hip_model_cache_limit_bytes(); + if (g_model_range_bytes > limit || aligned > limit - g_model_range_bytes) return NULL; + + const uint64_t chunk = hip_model_arena_chunk_bytes(aligned); + void *dev = NULL; + hipError_t err = hipMalloc(&dev, (size_t)chunk); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model arena alloc failed for %s (%.2f MiB chunk): %s\n", + what ? what : "weights", + (double)chunk / 1048576.0, + hipGetErrorString(err)); + (void)hipGetLastError(); + g_model_cache_full = 1; + return NULL; + } + g_model_arenas.push_back({(char *)dev, chunk, aligned}); + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + uint64_t arena_bytes = 0; + for (const hip_model_arena &a : g_model_arenas) arena_bytes += a.bytes; + fprintf(stderr, "ds4: ROCm model arena allocated %.2f MiB (arenas %.2f GiB)\n", + (double)chunk / 1048576.0, + (double)arena_bytes / 1073741824.0); + } + return (char *)dev; +} + +static const char *hip_model_range_ptr_from_fd( + const void *model_map, + uint64_t offset, + uint64_t bytes, + const char *what) { + if (g_model_fd < 0 || bytes == 0) return NULL; + const uint64_t limit = hip_model_cache_limit_bytes(); + if (g_model_range_bytes > limit || bytes > limit - g_model_range_bytes) { + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm direct %s %.2f MiB (cache budget %.2f GiB exhausted)\n", + what ? what : "weights", + (double)bytes / 1048576.0, + (double)limit / 1073741824.0); + } + return hip_model_ptr(model_map, offset); + } + + char *dev = hip_model_arena_alloc(bytes, what); + if (!dev) { + if (getenv("DS4_HIP_STRICT_WEIGHT_CACHE") != NULL) return NULL; + return hip_model_ptr(model_map, offset); + } + hipError_t err = hipSuccess; + + const uint64_t chunk = hip_model_copy_chunk_bytes(); + const uint64_t stage_bytes = chunk + (g_model_direct_align > 1 ? g_model_direct_align : 1); + if (!hip_model_stage_pool_alloc(stage_bytes)) return NULL; + + uint64_t copied = 0; + uint64_t chunk_idx = 0; + while (copied < bytes) { + const uint64_t n = (bytes - copied < chunk) ? (bytes - copied) : chunk; + const uint64_t bi = chunk_idx % 4u; + if (chunk_idx >= 4u) { + err = hipEventSynchronize(g_model_stage_event[bi]); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model staging wait failed for %s: %s\n", + what ? what : "weights", hipGetErrorString(err)); + (void)hipGetLastError(); + return NULL; + } + } + const char *payload = NULL; + if (!hip_model_stage_read(g_model_stage[bi], g_model_stage_bytes, + offset + copied, n, &payload)) { + fprintf(stderr, "ds4: ROCm model range read failed for %s at %.2f MiB: %s\n", + what ? what : "weights", + (double)copied / 1048576.0, + strerror(errno)); + return NULL; + } + err = hipMemcpyAsync(dev + copied, payload, (size_t)n, + hipMemcpyHostToDevice, g_model_upload_stream); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model range copy failed for %s at %.2f MiB: %s\n", + what ? what : "weights", + (double)copied / 1048576.0, + hipGetErrorString(err)); + (void)hipGetLastError(); + return NULL; + } + err = hipEventRecord(g_model_stage_event[bi], g_model_upload_stream); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model staging record failed for %s: %s\n", + what ? what : "weights", hipGetErrorString(err)); + (void)hipGetLastError(); + return NULL; + } + hip_model_drop_file_pages(offset + copied, n); + hip_model_discard_source_pages(model_map, g_model_registered_size, offset + copied, n); + copied += n; + hip_model_load_progress_note(g_model_range_bytes + copied); + chunk_idx++; + } + err = hipStreamSynchronize(g_model_upload_stream); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model range upload sync failed for %s: %s\n", + what ? what : "weights", hipGetErrorString(err)); + (void)hipGetLastError(); + return NULL; + } + + g_model_ranges.push_back({model_map, offset, bytes, dev, NULL, NULL, 0, 0, 1}); + g_model_range_by_offset[offset] = g_model_ranges.size() - 1u; + g_model_range_bytes += bytes; + hip_model_load_progress_note(g_model_range_bytes); + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm fd-cached %s %.2f MiB (total %.2f GiB)\n", + what ? what : "weights", + (double)bytes / 1048576.0, + (double)g_model_range_bytes / 1073741824.0); + } + return (const char *)dev; +} + +static int hip_model_copy_chunked(const void *model_map, uint64_t model_size, uint64_t map_offset, uint64_t map_size) { + if (!model_map || model_size == 0 || map_offset > model_size || map_size > model_size - map_offset) return 0; + if (getenv("DS4_HIP_NO_MODEL_COPY") != NULL || + getenv("DS4_HIP_DIRECT_MODEL") != NULL || + getenv("DS4_HIP_WEIGHT_CACHE") != NULL || + getenv("DS4_HIP_WEIGHT_PRELOAD") != NULL) { + return 0; + } + if (g_model_device_owned || g_model_registered) return 1; + + void *dev = NULL; + const double t0 = hip_wall_sec(); + hipError_t err = hipMalloc(&dev, (size_t)model_size); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model allocation skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + return 0; + } + + fprintf(stderr, "ds4: ROCm chunk-copying %.2f GiB model image\n", + (double)model_size / 1073741824.0); + + const uint64_t chunk = hip_model_copy_chunk_bytes(); + void *stage = NULL; + err = hipHostMalloc(&stage, (size_t)chunk); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm pinned model staging allocation failed: %s\n", hipGetErrorString(err)); + (void)hipFree(dev); + (void)hipGetLastError(); + return 0; + } + + if (map_offset > 0) { + uint64_t copied_header = 0; + while (copied_header < map_offset) { + const uint64_t n = (map_offset - copied_header < chunk) ? (map_offset - copied_header) : chunk; + memcpy(stage, (const char *)model_map + copied_header, (size_t)n); + err = hipMemcpy((char *)dev + copied_header, stage, (size_t)n, hipMemcpyHostToDevice); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model header copy failed: %s\n", hipGetErrorString(err)); + (void)hipHostFree(stage); + (void)hipFree(dev); + (void)hipGetLastError(); + return 0; + } + copied_header += n; + } + } + + uint64_t copied = 0; + double last_report = t0; + while (copied < map_size) { + const uint64_t n = (map_size - copied < chunk) ? (map_size - copied) : chunk; + const uint64_t off = map_offset + copied; + memcpy(stage, (const char *)model_map + off, (size_t)n); + err = hipMemcpy((char *)dev + off, stage, (size_t)n, hipMemcpyHostToDevice); + if (err != hipSuccess) { + fprintf(stderr, "ds4: ROCm model chunk copy failed at %.2f GiB: %s\n", + (double)copied / 1073741824.0, hipGetErrorString(err)); + (void)hipHostFree(stage); + (void)hipFree(dev); + (void)hipGetLastError(); + return 0; + } + hip_model_discard_source_pages(model_map, model_size, off, n); + copied += n; + const double now = hip_wall_sec(); + if (getenv("DS4_HIP_MODEL_COPY_VERBOSE") != NULL && now - last_report >= 2.0) { + fprintf(stderr, "ds4: ROCm model chunk copy %.2f/%.2f GiB\n", + (double)copied / 1073741824.0, + (double)map_size / 1073741824.0); + last_report = now; + } + } + + (void)hipHostFree(stage); + g_model_device_base = (const char *)dev; + g_model_device_owned = 1; + g_model_hmm_direct = 0; + const double t1 = hip_wall_sec(); + fprintf(stderr, + "ds4: ROCm model chunk copy complete in %.3fs (%.2f GiB tensors)\n", + t1 - t0, + (double)map_size / 1073741824.0); + return 1; +} + +static void hip_model_range_release_all(void) { + for (const hip_model_range &r : g_model_ranges) { + if (r.host_registered && r.registered_base) { + (void)hipHostUnregister(r.registered_base); + } else if (r.device_ptr && !r.arena_allocated) { + (void)hipFree(r.device_ptr); + } + } + for (const hip_model_arena &a : g_model_arenas) { + if (a.device_ptr) (void)hipFree(a.device_ptr); + } + g_model_arenas.clear(); + g_model_ranges.clear(); + g_model_range_by_offset.clear(); + g_model_range_bytes = 0; + hip_model_load_progress_reset(); +} + +static int hipblas_ok(hipblasStatus_t st, const char *what) { + if (st == HIPBLAS_STATUS_SUCCESS) return 1; + fprintf(stderr, "ds4: hipBLAS %s failed: status %d\n", what, (int)st); + return 0; +} + +extern "C" int ds4_gpu_init(void) { + int dev = 0; + if (!hip_ok(hipSetDevice(dev), "set device")) return 0; + hipDeviceProp_t prop; + if (hipGetDeviceProperties(&prop, dev) == hipSuccess) { + fprintf(stderr, "ds4: ROCm backend initialized on %s (sm_%d%d)\n", + prop.name, prop.major, prop.minor); + } + if (!g_hipblas_ready) { + if (!hipblas_ok(hipblasCreate(&g_hipblas), "create handle")) return 0; + const hipblasMath_t math_mode = + (g_quality_mode || getenv("DS4_HIP_NO_TF32") != NULL) + ? HIPBLAS_DEFAULT_MATH + : HIPBLAS_DEFAULT_MATH; + (void)hipblasSetMathMode(g_hipblas, math_mode); + g_hipblas_ready = 1; + } + return 1; +} + +extern "C" void ds4_gpu_cleanup(void) { + (void)hipDeviceSynchronize(); + if (g_hipblas_ready) { + (void)hipblasDestroy(g_hipblas); + g_hipblas_ready = 0; + g_hipblas = NULL; + } + hip_model_range_release_all(); + hip_q8_f16_cache_release_all(); + g_q8_f16_disabled_after_oom = 0; + g_q8_f16_budget_notice_printed = 0; + for (const hip_q8_f32_range &r : g_q8_f32_ranges) { + (void)hipFree(r.device_ptr); + } + g_q8_f32_ranges.clear(); + g_q8_f32_by_offset.clear(); + g_q8_f32_bytes = 0; + if (g_hip_tmp) { + (void)hipFree(g_hip_tmp); + g_hip_tmp = NULL; + g_hip_tmp_bytes = 0; + } + for (size_t i = 0; i < 4; i++) { + if (g_model_stage_event[i]) { + (void)hipEventDestroy(g_model_stage_event[i]); + g_model_stage_event[i] = NULL; + } + if (g_model_stage_raw[i]) { + (void)hipHostFree(g_model_stage_raw[i]); + g_model_stage_raw[i] = NULL; + g_model_stage[i] = NULL; + } + } + g_model_stage_bytes = 0; + if (g_model_upload_stream) { + (void)hipStreamDestroy(g_model_upload_stream); + g_model_upload_stream = NULL; + } + if (g_model_device_owned && g_model_device_base) { + (void)hipFree((void *)g_model_device_base); + } + if (g_model_registered && g_model_host_base) { + (void)hipHostUnregister((void *)g_model_host_base); + } + g_model_host_base = NULL; + g_model_device_base = NULL; + g_model_registered_size = 0; + g_model_registered = 0; + g_model_device_owned = 0; + g_model_range_mapping_supported = 1; + g_model_hmm_direct = 0; + g_model_fd = -1; + if (g_model_direct_fd >= 0) { + (void)close(g_model_direct_fd); + g_model_direct_fd = -1; + } + g_model_direct_align = 1; + g_model_file_size = 0; + g_model_cache_full = 0; + if (g_model_prefetch_stream) { + (void)hipStreamDestroy(g_model_prefetch_stream); + g_model_prefetch_stream = NULL; + } +} + +extern "C" ds4_gpu_tensor *ds4_gpu_tensor_alloc(uint64_t bytes) { + if (bytes == 0) bytes = 1; + ds4_gpu_tensor *t = (ds4_gpu_tensor *)calloc(1, sizeof(*t)); + if (!t) return NULL; + if (!hip_ok(hipMallocManaged(&t->ptr, (size_t)bytes), "tensor alloc")) { + free(t); + return NULL; + } + t->bytes = bytes; + t->owner = 1; + return t; +} + +extern "C" ds4_gpu_tensor *ds4_gpu_tensor_view(const ds4_gpu_tensor *base, uint64_t offset, uint64_t bytes) { + if (!base || offset > base->bytes || bytes > base->bytes - offset) return NULL; + ds4_gpu_tensor *t = (ds4_gpu_tensor *)calloc(1, sizeof(*t)); + if (!t) return NULL; + t->ptr = (char *)base->ptr + offset; + t->bytes = bytes; + t->owner = 0; + return t; +} + +extern "C" void ds4_gpu_tensor_free(ds4_gpu_tensor *tensor) { + if (!tensor) return; + if (tensor->owner && tensor->ptr) (void)hipFree(tensor->ptr); + free(tensor); +} + +extern "C" uint64_t ds4_gpu_tensor_bytes(const ds4_gpu_tensor *tensor) { + return tensor ? tensor->bytes : 0; +} + +extern "C" void *ds4_gpu_tensor_contents(ds4_gpu_tensor *tensor) { + if (!tensor) return NULL; + (void)hipDeviceSynchronize(); + return tensor->ptr; +} + +extern "C" int ds4_gpu_tensor_write(ds4_gpu_tensor *tensor, uint64_t offset, const void *data, uint64_t bytes) { + if (!tensor || !data || offset > tensor->bytes || bytes > tensor->bytes - offset) return 0; + return hip_ok(hipMemcpy((char *)tensor->ptr + offset, data, (size_t)bytes, hipMemcpyHostToDevice), "tensor write"); +} + +extern "C" int ds4_gpu_tensor_read(const ds4_gpu_tensor *tensor, uint64_t offset, void *data, uint64_t bytes) { + if (!tensor || !data || offset > tensor->bytes || bytes > tensor->bytes - offset) return 0; + return hip_ok(hipMemcpy(data, (const char *)tensor->ptr + offset, (size_t)bytes, hipMemcpyDeviceToHost), "tensor read"); +} + +extern "C" int ds4_gpu_tensor_copy(ds4_gpu_tensor *dst, uint64_t dst_offset, + const ds4_gpu_tensor *src, uint64_t src_offset, + uint64_t bytes) { + if (!dst || !src || dst_offset > dst->bytes || src_offset > src->bytes || + bytes > dst->bytes - dst_offset || bytes > src->bytes - src_offset) { + return 0; + } + if (bytes == 0) return 1; + return hip_ok(hipMemcpy((char *)dst->ptr + dst_offset, + (const char *)src->ptr + src_offset, + (size_t)bytes, + hipMemcpyDeviceToDevice), + "tensor copy"); +} + +extern "C" int ds4_gpu_begin_commands(void) { return 1; } +extern "C" int ds4_gpu_flush_commands(void) { return hip_ok(hipDeviceSynchronize(), "flush"); } +extern "C" int ds4_gpu_end_commands(void) { return hip_ok(hipDeviceSynchronize(), "end commands"); } +extern "C" int ds4_gpu_synchronize(void) { return hip_ok(hipDeviceSynchronize(), "synchronize"); } + +extern "C" int ds4_gpu_set_model_map(const void *model_map, uint64_t model_size) { + if (!model_map || model_size == 0) return 0; + if (g_model_host_base == model_map && g_model_registered_size == model_size) return 1; + hip_model_range_release_all(); + hip_q8_f16_cache_release_all(); + g_q8_f16_disabled_after_oom = 0; + g_q8_f16_budget_notice_printed = 0; + for (const hip_q8_f32_range &r : g_q8_f32_ranges) { + (void)hipFree(r.device_ptr); + } + g_q8_f32_ranges.clear(); + g_q8_f32_by_offset.clear(); + g_q8_f32_bytes = 0; + if (g_model_device_owned && g_model_device_base) { + (void)hipFree((void *)g_model_device_base); + g_model_device_owned = 0; + } + if (g_model_registered && g_model_host_base) { + (void)hipHostUnregister((void *)g_model_host_base); + g_model_registered = 0; + } + g_model_host_base = model_map; + g_model_device_base = (const char *)model_map; + g_model_registered_size = model_size; + g_model_range_mapping_supported = 1; + g_model_hmm_direct = 0; + g_model_cache_full = 0; + + const char *copy_env = getenv("DS4_HIP_COPY_MODEL"); + if (copy_env && copy_env[0]) { + void *dev = NULL; + const double t0 = clock() / (double)CLOCKS_PER_SEC; + hipError_t err = hipMalloc(&dev, (size_t)model_size); + if (err == hipSuccess) { + fprintf(stderr, "ds4: ROCm copying %.2f GiB model to device memory\n", + (double)model_size / 1073741824.0); + err = hipMemcpy(dev, model_map, (size_t)model_size, hipMemcpyHostToDevice); + if (err == hipSuccess) { + g_model_device_base = (const char *)dev; + g_model_device_owned = 1; + const double t1 = clock() / (double)CLOCKS_PER_SEC; + fprintf(stderr, "ds4: ROCm model copy complete in %.3fs\n", t1 - t0); + return 1; + } + fprintf(stderr, "ds4: ROCm model copy failed: %s\n", hipGetErrorString(err)); + (void)hipFree(dev); + (void)hipGetLastError(); + } else { + fprintf(stderr, "ds4: ROCm model allocation skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + } + } + + /* Try with ReadOnly flag first; fall back to plain Mapped if not supported. + * hipHostRegisterReadOnly can fail with hipErrorInvalidValue on some ROCm + * builds even though the operation itself is valid. */ + hipError_t err = hipHostRegister((void *)model_map, (size_t)model_size, + hipHostRegisterMapped | hipHostRegisterReadOnly); + if (err == hipErrorInvalidValue || err == hipErrorNotSupported) { + (void)hipGetLastError(); + err = hipHostRegister((void *)model_map, (size_t)model_size, hipHostRegisterMapped); + } + if (err == hipSuccess) { + void *dev = NULL; + err = hipHostGetDevicePointer(&dev, (void *)model_map, 0); + if (err == hipSuccess && dev) { + g_model_device_base = (const char *)dev; + g_model_registered = 1; + int dev_id = 0; + (void)hipGetDevice(&dev_id); + (void)hipMemAdvise((void *)model_map, (size_t)model_size, hipMemAdviseSetReadMostly, dev_id); + (void)hipMemAdvise((void *)model_map, (size_t)model_size, hipMemAdviseSetCoarseGrain, dev_id); + fprintf(stderr, "ds4: ROCm registered %.2f GiB model mapping (cached coarse-grained) for device access\n", + (double)model_size / 1073741824.0); + } else { + fprintf(stderr, "ds4: ROCm host registration pointer lookup failed: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + } + } else { + fprintf(stderr, "ds4: ROCm host registration skipped: %s\n", hipGetErrorString(err)); + (void)hipGetLastError(); + /* On HSA unified memory (e.g., Strix Halo), the CPU mapping is already + * device-accessible without prior registration. Try the pointer lookup + * directly; if it works we can skip the VRAM copy path entirely. */ + void *dev = NULL; + hipError_t hsa_err = hipHostGetDevicePointer(&dev, (void *)model_map, 0); + if (hsa_err == hipSuccess && dev) { + g_model_device_base = (const char *)dev; + g_model_registered = 1; + int dev_id = 0; + (void)hipGetDevice(&dev_id); + (void)hipMemAdvise((void *)model_map, (size_t)model_size, hipMemAdviseSetReadMostly, dev_id); + (void)hipMemAdvise((void *)model_map, (size_t)model_size, hipMemAdviseSetCoarseGrain, dev_id); + fprintf(stderr, "ds4: HSA direct model access enabled (cached coarse-grained) (%.2f GiB)\n", + (double)model_size / 1073741824.0); + } else { + (void)hipGetLastError(); + } + } + return 1; +} + +extern "C" int ds4_gpu_set_model_map_range(const void *model_map, uint64_t model_size, uint64_t map_offset, uint64_t map_size) { + if (!ds4_gpu_set_model_map(model_map, model_size)) return 0; + if (getenv("DS4_HIP_COPY_MODEL_CHUNKED") != NULL && + !hip_model_copy_chunked(model_map, model_size, map_offset, map_size)) { + (void)hip_model_prefetch_range(model_map, model_size, map_offset, map_size); + } + return 1; +} + +extern "C" int ds4_gpu_set_model_fd(int fd) { + g_model_fd = fd; + g_model_file_size = 0; + if (g_model_direct_fd >= 0) { + (void)close(g_model_direct_fd); + g_model_direct_fd = -1; + } + g_model_direct_align = 1; + if (fd >= 0) { + struct stat st; + if (fstat(fd, &st) == 0 && st.st_size > 0) { + g_model_file_size = (uint64_t)st.st_size; + if (st.st_blksize > 1) g_model_direct_align = (uint64_t)st.st_blksize; + } +#if defined(__linux__) && defined(O_DIRECT) + if (getenv("DS4_HIP_NO_DIRECT_IO") == NULL) { + char proc_path[64]; + snprintf(proc_path, sizeof(proc_path), "/proc/self/fd/%d", fd); + int direct_fd = open(proc_path, O_RDONLY | O_DIRECT); + if (direct_fd >= 0) { + g_model_direct_fd = direct_fd; + if (g_model_direct_align < 512) g_model_direct_align = 512; + if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm model direct I/O enabled (align=%llu)\n", + (unsigned long long)g_model_direct_align); + } + } else if (getenv("DS4_HIP_WEIGHT_CACHE_VERBOSE")) { + fprintf(stderr, "ds4: ROCm model direct I/O unavailable: %s\n", strerror(errno)); + } + } +#endif + } + return 1; +} + +extern "C" int ds4_gpu_cache_model_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, const char *label) { + if (!model_map || bytes == 0) return 1; + if (offset > model_size || bytes > model_size - offset) return 0; + if (!hip_model_range_ptr(model_map, offset, bytes, label ? label : "model_tensor")) return 0; + return hip_model_range_is_cached(model_map, offset, bytes); +} + +extern "C" int ds4_gpu_cache_q8_f16_range(const void *model_map, uint64_t model_size, uint64_t offset, uint64_t bytes, uint64_t in_dim, uint64_t out_dim, const char *label) { + if (!model_map || bytes == 0) return 1; + if (offset > model_size || bytes > model_size - offset) return 0; + static int optional_q8_preload_disabled = 0; + if (optional_q8_preload_disabled) return 1; + const char *cache_label = label ? label : "q8_0"; + if (getenv("DS4_HIP_Q8_F32_PRELOAD") != NULL && + hip_q8_f32_cache_allowed(cache_label, in_dim, out_dim)) { + if (hip_q8_f32_ptr(model_map, offset, bytes, in_dim, out_dim, cache_label)) return 1; + optional_q8_preload_disabled = 1; + return 1; + } + if (!hip_q8_f16_preload_allowed(cache_label, in_dim, out_dim)) return 1; + if (hip_q8_f16_ptr(model_map, offset, bytes, in_dim, out_dim, cache_label)) return 1; + optional_q8_preload_disabled = 1; + return 1; +} + +extern "C" void ds4_gpu_print_memory_report(const char *label) { + size_t free_b = 0, total_b = 0; + (void)hipMemGetInfo(&free_b, &total_b); + fprintf(stderr, "ds4: ROCm memory report %s: free %.2f MiB total %.2f MiB\n", + label ? label : "", (double)free_b / 1048576.0, (double)total_b / 1048576.0); +} + +extern "C" void ds4_gpu_set_quality(bool quality) { + g_quality_mode = quality ? 1 : 0; + if (g_hipblas_ready) { + const hipblasMath_t math_mode = + (g_quality_mode || getenv("DS4_HIP_NO_TF32") != NULL) + ? HIPBLAS_DEFAULT_MATH + : HIPBLAS_DEFAULT_MATH; + (void)hipblasSetMathMode(g_hipblas, math_mode); + } +} + +__global__ static void embed_token_hc_kernel(float *out, const unsigned short *w, uint32_t token, uint32_t n_embd, uint32_t n_hc) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t n = n_embd * n_hc; + if (i >= n) return; + uint32_t e = i % n_embd; + out[i] = __half2float(reinterpret_cast(w)[(uint64_t)token * n_embd + e]); +} + +__global__ static void embed_tokens_hc_kernel( + float *out, + const int32_t *tokens, + const __half *w, + uint32_t n_vocab, + uint32_t n_tokens, + uint32_t n_embd, + uint32_t n_hc) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_tokens * n_hc * n_embd; + if (gid >= n) return; + uint32_t d = gid % n_embd; + uint64_t tmp = gid / n_embd; + uint32_t t = tmp / n_hc; + int32_t tok_i = tokens[t]; + uint32_t tok = tok_i < 0 ? 0u : (uint32_t)tok_i; + if (tok >= n_vocab) tok = 0; + out[gid] = __half2float(w[(uint64_t)tok * n_embd + d]); +} + +__global__ static void matmul_f16_kernel( + float *out, + const __half *w, + const float *x, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out_dim || tok >= n_tok) return; + + float sum = 0.0f; + const __half *wr = w + row * in_dim; + const float *xr = x + tok * in_dim; + for (uint64_t i = threadIdx.x; i < in_dim; i += blockDim.x) { + sum += __half2float(wr[i]) * xr[i]; + } + + __shared__ float partial[256]; + partial[threadIdx.x] = sum; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) out[tok * out_dim + row] = partial[0]; +} + +__global__ static void matmul_f16_serial_kernel( + float *out, + const __half *w, + const float *x, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out_dim || tok >= n_tok || threadIdx.x != 0) return; + + float sum = 0.0f; + const __half *wr = w + row * in_dim; + const float *xr = x + tok * in_dim; + for (uint64_t i = 0; i < in_dim; i++) { + sum += __half2float(wr[i]) * xr[i]; + } + out[tok * out_dim + row] = sum; +} + +__global__ static void matmul_f16_ordered_chunks_kernel( + float *out, + const __half *w, + const float *x, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out_dim || tok >= n_tok) return; + + const uint32_t tid = threadIdx.x; + float sum = 0.0f; + const uint64_t h2_count = in_dim >> 1; + const __half2 *wr2 = (const __half2 *)(w + row * in_dim); + const float2 *xr2 = (const float2 *)(x + tok * in_dim); + + // Interleaved (coalesced) access: all threads in the warp read contiguous memory + for (uint64_t i = tid; i < h2_count; i += blockDim.x) { + __half2 wv = wr2[i]; + float2 xv = xr2[i]; + sum += __half2float(wv.x) * xv.x + __half2float(wv.y) * xv.y; + } + // Scalar tail for odd in_dim + if (tid == 0 && (in_dim & 1u)) { + sum += __half2float(w[row * in_dim + in_dim - 1]) * x[tok * in_dim + in_dim - 1]; + } + for (uint32_t offset = 16u; offset > 0u; offset >>= 1) + sum += __shfl_down(sum, offset); + if (tid == 0) out[tok * out_dim + row] = sum; +} + +__global__ static void matmul_f16_pair_ordered_chunks_kernel( + float *out0, + float *out1, + const __half *w0, + const __half *w1, + const float *x, + uint64_t in_dim, + uint64_t out0_dim, + uint64_t out1_dim, + uint64_t n_tok) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out0_dim && row >= out1_dim) return; + if (tok >= n_tok) return; + + const uint32_t tid = threadIdx.x; + float sum0 = 0.0f; + float sum1 = 0.0f; + const uint64_t h2_count = in_dim >> 1; + const __half2 *wr2_0 = row < out0_dim ? (const __half2 *)(w0 + row * in_dim) : NULL; + const __half2 *wr2_1 = row < out1_dim ? (const __half2 *)(w1 + row * in_dim) : NULL; + const float2 *xr2 = (const float2 *)(x + tok * in_dim); + + for (uint64_t i = tid; i < h2_count; i += blockDim.x) { + float2 xv = xr2[i]; + if (wr2_0) { + __half2 wv = wr2_0[i]; + sum0 += __half2float(wv.x) * xv.x + __half2float(wv.y) * xv.y; + } + if (wr2_1) { + __half2 wv = wr2_1[i]; + sum1 += __half2float(wv.x) * xv.x + __half2float(wv.y) * xv.y; + } + } + if (tid == 0 && (in_dim & 1u)) { + float xv = x[tok * in_dim + in_dim - 1]; + if (row < out0_dim) sum0 += __half2float(w0[row * in_dim + in_dim - 1]) * xv; + if (row < out1_dim) sum1 += __half2float(w1[row * in_dim + in_dim - 1]) * xv; + } + for (uint32_t offset = 16u; offset > 0u; offset >>= 1) { + sum0 += __shfl_down(sum0, offset); + sum1 += __shfl_down(sum1, offset); + } + if (tid == 0) { + if (row < out0_dim) out0[tok * out0_dim + row] = sum0; + if (row < out1_dim) out1[tok * out1_dim + row] = sum1; + } +} + +__global__ static void matmul_f32_kernel( + float *out, + const float *w, + const float *x, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out_dim || tok >= n_tok) return; + + float sum = 0.0f; + const float *wr = w + row * in_dim; + const float *xr = x + tok * in_dim; + for (uint64_t i = threadIdx.x; i < in_dim; i += blockDim.x) + sum += wr[i] * xr[i]; + // Reduce within warp via shuffles, then accumulate 8 warp sums via shared memory. + for (uint32_t offset = 16u; offset > 0u; offset >>= 1) + sum += __shfl_down(sum, offset); + __shared__ float warp_sums[8]; + if ((threadIdx.x & 31u) == 0u) + warp_sums[threadIdx.x >> 5] = sum; + __syncthreads(); + if (threadIdx.x == 0) { + float total = warp_sums[0]; + for (uint32_t i = 1u; i < 8u; i++) total += warp_sums[i]; + out[tok * out_dim + row] = total; + } +} + +__global__ static void repeat_hc_kernel(float *out, const float *row, uint32_t n_embd, uint32_t n_hc) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_embd * n_hc; + if (i >= n) return; + out[i] = row[i % n_embd]; +} + +__global__ static void f32_to_f16_kernel(__half *out, const float *x, uint64_t n) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) out[i] = __float2half(x[i]); +} + +__device__ static float warp_sum_f32(float v) { + for (int offset = 16; offset > 0; offset >>= 1) { + v += __shfl_down(v, offset); + } + return v; +} + +__device__ static float warp_max_f32(float v) { + for (int offset = 16; offset > 0; offset >>= 1) { + v = fmaxf(v, __shfl_down(v, offset)); + } + return v; +} + +__device__ static float dot4_f32(float4 a, float4 b) { + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +__device__ __forceinline__ static int32_t load_i8x4_i32_aligned(const int8_t *p) { + return *(const int32_t *)p; +} + +__device__ __forceinline__ static int32_t load_i8x4_i32_unaligned(const int8_t *p) { + const uint8_t *u = (const uint8_t *)p; + return (int32_t)((uint32_t)u[0] | + ((uint32_t)u[1] << 8) | + ((uint32_t)u[2] << 16) | + ((uint32_t)u[3] << 24)); +} + +__device__ __forceinline__ static int32_t dot_i8x32_dp4a(const int8_t *a, const int8_t *b) { + int32_t dot = 0; +#pragma unroll + for (uint32_t i = 0; i < 32u; i += 4u) { + dot = __dp4a(load_i8x4_i32_unaligned(a + i), load_i8x4_i32_aligned(b + i), dot); + } + return dot; +} + +__device__ __forceinline__ static int32_t dot_i8_block(const int8_t *a, const int8_t *b, uint64_t n, int use_dp4a) { + if (use_dp4a && n == 32u) return dot_i8x32_dp4a(a, b); + int32_t dot = 0; + for (uint64_t i = 0; i < n; i++) dot += (int32_t)a[i] * (int32_t)b[i]; + return dot; +} + +__global__ static DS4_HIP_UNUSED void matmul_q8_0_kernel( + float *out, + const unsigned char *w, + const float *x, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out_dim || tok >= n_tok) return; + const uint64_t blocks = (in_dim + 31) / 32; + const unsigned char *wr = w + row * blocks * 34; + const float *xr = x + tok * in_dim; + float acc = 0.0f; + + for (uint64_t b = threadIdx.x; b < blocks; b += blockDim.x) { + uint64_t i0 = b * 32; + uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + float amax = 0.0f; + for (uint64_t i = 0; i < bn; i++) amax = fmaxf(amax, fabsf(xr[i0 + i])); + float d = amax / 127.0f; + float id = d != 0.0f ? 1.0f / d : 0.0f; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + int dot = 0; + for (uint64_t i = 0; i < bn; i++) { + int q = (int)lrintf(xr[i0 + i] * id); + q = q > 127 ? 127 : (q < -128 ? -128 : q); + dot += (int)qs[i] * q; + } + acc += __half2float(*scale_h) * d * (float)dot; + } + + __shared__ float partial[256]; + partial[threadIdx.x] = acc; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) out[tok * out_dim + row] = partial[0]; +} + +__global__ static void quantize_q8_0_f32_kernel( + int8_t *xq, + float *xscale, + const float *x, + uint64_t in_dim, + uint64_t blocks) { + uint64_t b = blockIdx.x; + uint64_t tok = blockIdx.y; + if (b >= blocks) return; + uint64_t i0 = b * 32; + uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const float *xr = x + tok * in_dim + i0; + + float a = 0.0f; + if (threadIdx.x < bn) a = fabsf(xr[threadIdx.x]); + __shared__ float vals[32]; + vals[threadIdx.x] = a; + __syncthreads(); + for (uint32_t stride = 16; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) vals[threadIdx.x] = fmaxf(vals[threadIdx.x], vals[threadIdx.x + stride]); + __syncthreads(); + } + const float d = vals[0] / 127.0f; + const float id = d != 0.0f ? 1.0f / d : 0.0f; + if (threadIdx.x == 0) xscale[tok * blocks + b] = d; + int8_t *dst = xq + (tok * blocks + b) * 32; + if (threadIdx.x < bn) { + int v = (int)lrintf(xr[threadIdx.x] * id); + v = v > 127 ? 127 : (v < -128 ? -128 : v); + dst[threadIdx.x] = (int8_t)v; + } else { + dst[threadIdx.x] = 0; + } +} + +__global__ static void matmul_q8_0_preq_kernel( + float *out, + const unsigned char *w, + const int8_t *xq, + const float *xscale, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok, + uint64_t blocks, + int use_dp4a) { + uint64_t row = (uint64_t)blockIdx.x; + uint64_t tok = (uint64_t)blockIdx.y; + if (row >= out_dim || tok >= n_tok) return; + const unsigned char *wr = w + row * blocks * 34; + const int8_t *xqr = xq + tok * blocks * 32; + const float *xsr = xscale + tok * blocks; + float acc = 0.0f; + for (uint64_t b = threadIdx.x; b < blocks; b += blockDim.x) { + uint64_t i0 = b * 32; + uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + const int8_t *xqb = xqr + b * 32; + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc += __half2float(*scale_h) * xsr[b] * (float)dot; + } + __shared__ float partial[256]; + partial[threadIdx.x] = acc; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) out[tok * out_dim + row] = partial[0]; +} + +__global__ static void matmul_q8_0_preq_warp8_kernel( + float *out, + const unsigned char *w, + const int8_t *xq, + const float *xscale, + uint64_t in_dim, + uint64_t out_dim, + uint64_t blocks, + int use_dp4a) { + uint64_t row = (uint64_t)blockIdx.x * 8u + (threadIdx.x >> 5u); + uint32_t lane = threadIdx.x & 31u; + if (row >= out_dim) return; + const unsigned char *wr = w + row * blocks * 34; + float acc = 0.0f; + for (uint64_t b = lane; b < blocks; b += 32u) { + uint64_t i0 = b * 32; + uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + const int8_t *xqb = xq + b * 32; + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc += __half2float(*scale_h) * xscale[b] * (float)dot; + } + acc = warp_sum_f32(acc); + if (lane == 0) out[row] = acc; +} + +__global__ static void matmul_q8_0_pair_preq_warp8_kernel( + float *out0, + float *out1, + const unsigned char *w0, + const unsigned char *w1, + const int8_t *xq, + const float *xscale, + uint64_t in_dim, + uint64_t out0_dim, + uint64_t out1_dim, + uint64_t blocks, + int use_dp4a) { + uint64_t row = (uint64_t)blockIdx.x * 8u + (threadIdx.x >> 5u); + uint32_t lane = threadIdx.x & 31u; + if (row >= out0_dim && row >= out1_dim) return; + float acc0 = 0.0f; + float acc1 = 0.0f; + const unsigned char *wr0 = row < out0_dim ? w0 + row * blocks * 34 : NULL; + const unsigned char *wr1 = row < out1_dim ? w1 + row * blocks * 34 : NULL; + for (uint64_t b = lane; b < blocks; b += 32u) { + uint64_t i0 = b * 32; + uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const int8_t *xqb = xq + b * 32; + const float xs = xscale[b]; + if (wr0) { + const __half *scale_h = (const __half *)(wr0 + b * 34); + const int8_t *qs = (const int8_t *)(wr0 + b * 34 + 2); + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc0 += __half2float(*scale_h) * xs * (float)dot; + } + if (wr1) { + const __half *scale_h = (const __half *)(wr1 + b * 34); + const int8_t *qs = (const int8_t *)(wr1 + b * 34 + 2); + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc1 += __half2float(*scale_h) * xs * (float)dot; + } + } + acc0 = warp_sum_f32(acc0); + acc1 = warp_sum_f32(acc1); + if (lane == 0) { + if (row < out0_dim) out0[row] = acc0; + if (row < out1_dim) out1[row] = acc1; + } +} + +__global__ static void matmul_q8_0_hc_expand_preq_warp8_kernel( + float *out_hc, + float *block_out, + const float *block_add, + const float *residual_hc, + const float *split, + const unsigned char *w, + const int8_t *xq, + const float *xscale, + uint64_t in_dim, + uint64_t out_dim, + uint32_t n_embd, + uint32_t n_hc, + uint64_t blocks, + int has_add, + int use_dp4a) { + const uint64_t row = (uint64_t)blockIdx.x * 8u + (threadIdx.x >> 5u); + const uint32_t lane = threadIdx.x & 31u; + if (row >= out_dim) return; + const unsigned char *wr = w + row * blocks * 34; + float acc = 0.0f; + for (uint64_t b = lane; b < blocks; b += 32u) { + const uint64_t i0 = b * 32; + const uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + const int8_t *xqb = xq + b * 32; + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc += __half2float(*scale_h) * xscale[b] * (float)dot; + } + acc = warp_sum_f32(acc); + if (lane == 0) { + const uint32_t d = (uint32_t)row; + block_out[d] = acc; + float block_v = acc; + if (has_add) block_v += block_add[d]; + const float *post = split + n_hc; + const float *comb = split + 2u * n_hc; + for (uint32_t dst_hc = 0; dst_hc < n_hc; dst_hc++) { + float hc_acc = block_v * post[dst_hc]; + for (uint32_t src_hc = 0; src_hc < n_hc; src_hc++) { + const float comb_v = comb[dst_hc + (uint64_t)src_hc * n_hc]; + const float res_v = residual_hc[(uint64_t)src_hc * n_embd + d]; + hc_acc += comb_v * res_v; + } + out_hc[(uint64_t)dst_hc * n_embd + d] = hc_acc; + } + } +} + +__global__ static void matmul_q8_0_preq_batch_warp8_kernel( + float *out, + const unsigned char *w, + const int8_t *xq, + const float *xscale, + uint64_t in_dim, + uint64_t out_dim, + uint64_t n_tok, + uint64_t blocks, + int use_dp4a) { + const uint64_t row = (uint64_t)blockIdx.x * 8u + (threadIdx.x >> 5u); + const uint64_t tok = (uint64_t)blockIdx.y; + const uint32_t lane = threadIdx.x & 31u; + if (row >= out_dim || tok >= n_tok) return; + + const unsigned char *wr = w + row * blocks * 34; + const int8_t *xqr = xq + tok * blocks * 32; + const float *xsr = xscale + tok * blocks; + float acc = 0.0f; + for (uint64_t b = lane; b < blocks; b += 32u) { + const uint64_t i0 = b * 32; + const uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + const int8_t *xqb = xqr + b * 32; + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc += __half2float(*scale_h) * xsr[b] * (float)dot; + } + acc = warp_sum_f32(acc); + if (lane == 0) out[tok * out_dim + row] = acc; +} + +__global__ static void dequant_q8_0_to_f16_kernel( + __half *out, + const unsigned char *w, + uint64_t in_dim, + uint64_t out_dim, + uint64_t blocks) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = in_dim * out_dim; + if (gid >= n) return; + uint64_t row = gid / in_dim; + uint64_t i = gid - row * in_dim; + uint64_t b = i / 32; + uint64_t j = i - b * 32; + const unsigned char *blk = w + (row * blocks + b) * 34; + const __half scale = *(const __half *)blk; + const int8_t q = *(const int8_t *)(blk + 2 + j); + out[gid] = __hmul(scale, __float2half((float)q)); +} + +__global__ static void dequant_q8_0_to_f32_kernel( + float *out, + const unsigned char *w, + uint64_t in_dim, + uint64_t out_dim, + uint64_t blocks) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = in_dim * out_dim; + if (gid >= n) return; + uint64_t row = gid / in_dim; + uint64_t i = gid - row * in_dim; + uint64_t b = i / 32; + uint64_t j = i - b * 32; + const unsigned char *blk = w + (row * blocks + b) * 34; + const float scale = __half2float(*(const __half *)blk); + const int8_t q = *(const int8_t *)(blk + 2 + j); + out[gid] = scale * (float)q; +} + +__global__ static void grouped_q8_0_a_preq_warp8_kernel( + float *low, + const unsigned char *w, + const int8_t *xq, + const float *xscale, + uint64_t group_dim, + uint64_t rank, + uint32_t n_groups, + uint32_t n_tokens, + uint64_t blocks, + int use_dp4a) { + const uint64_t row = (uint64_t)blockIdx.x * 8u + (threadIdx.x >> 5u); + const uint64_t tok = (uint64_t)blockIdx.y; + const uint32_t lane = threadIdx.x & 31u; + const uint64_t low_dim = (uint64_t)n_groups * rank; + if (row >= low_dim || tok >= n_tokens) return; + + const uint64_t group = row / rank; + const uint64_t row_in_group = row - group * rank; + const unsigned char *wr = w + (group * rank + row_in_group) * blocks * 34; + const uint64_t xrow = tok * (uint64_t)n_groups + group; + const int8_t *xqr = xq + xrow * blocks * 32; + const float *xsr = xscale + xrow * blocks; + float acc = 0.0f; + + for (uint64_t b = lane; b < blocks; b += 32u) { + const uint64_t i0 = b * 32; + const uint64_t bn = group_dim - i0 < 32 ? group_dim - i0 : 32; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + const int8_t *xqb = xqr + b * 32; + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc += __half2float(*scale_h) * xsr[b] * (float)dot; + } + acc = warp_sum_f32(acc); + if (lane == 0) low[tok * low_dim + row] = acc; +} + +__global__ static void rms_norm_plain_kernel(float *out, const float *x, uint32_t n, uint32_t rows, float eps) { + uint32_t row = blockIdx.x; + if (row >= rows) return; + const float *xr = x + (uint64_t)row * n; + float *orow = out + (uint64_t)row * n; + float sum = 0.0f; + for (uint32_t i = threadIdx.x; i < n; i += blockDim.x) + sum += xr[i] * xr[i]; + for (uint32_t offset = 16u; offset > 0u; offset >>= 1) + sum += __shfl_down(sum, offset); + __shared__ float warp_sums[8]; + if ((threadIdx.x & 31u) == 0u) + warp_sums[threadIdx.x >> 5] = sum; + __syncthreads(); + if (threadIdx.x == 0) { + float total = warp_sums[0]; + for (uint32_t i = 1u; i < 8u; i++) total += warp_sums[i]; + warp_sums[0] = rsqrtf(total / (float)n + eps); + } + __syncthreads(); + float scale = warp_sums[0]; + for (uint32_t i = threadIdx.x; i < n; i += blockDim.x) + orow[i] = xr[i] * scale; +} + +__global__ static void rms_norm_weight_kernel(float *out, const float *x, const float *w, uint32_t n, uint32_t rows, float eps) { + uint32_t row = blockIdx.x; + if (row >= rows) return; + const float *xr = x + (uint64_t)row * n; + float *orow = out + (uint64_t)row * n; + float sum = 0.0f; + for (uint32_t i = threadIdx.x; i < n; i += blockDim.x) + sum += xr[i] * xr[i]; + for (uint32_t offset = 16u; offset > 0u; offset >>= 1) + sum += __shfl_down(sum, offset); + __shared__ float warp_sums[8]; + if ((threadIdx.x & 31u) == 0u) + warp_sums[threadIdx.x >> 5] = sum; + __syncthreads(); + if (threadIdx.x == 0) { + float total = warp_sums[0]; + for (uint32_t i = 1u; i < 8u; i++) total += warp_sums[i]; + warp_sums[0] = rsqrtf(total / (float)n + eps); + } + __syncthreads(); + float scale = warp_sums[0]; + for (uint32_t i = threadIdx.x; i < n; i += blockDim.x) + orow[i] = xr[i] * scale * w[i]; +} + +__global__ static void dsv4_qkv_rms_norm_rows_kernel( + float *q_out, + const float *q, + const float *q_w, + uint32_t q_n, + float *kv_out, + const float *kv, + const float *kv_w, + uint32_t kv_n, + uint32_t rows, + float eps) { + const uint32_t row = blockIdx.x; + const uint32_t which = blockIdx.y; + if (row >= rows || which > 1u) return; + const uint32_t n = which == 0u ? q_n : kv_n; + const float *xr = (which == 0u ? q : kv) + (uint64_t)row * n; + float *orow = (which == 0u ? q_out : kv_out) + (uint64_t)row * n; + const float *w = which == 0u ? q_w : kv_w; + float sum = 0.0f; + for (uint32_t i = threadIdx.x; i < n; i += blockDim.x) { + const float v = xr[i]; + sum += v * v; + } + __shared__ float partial[256]; + partial[threadIdx.x] = sum; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + const float scale = rsqrtf(partial[0] / (float)n + eps); + for (uint32_t i = threadIdx.x; i < n; i += blockDim.x) { + orow[i] = xr[i] * scale * w[i]; + } +} + +__global__ static void head_rms_norm_kernel(float *x, uint32_t n_tok, uint32_t n_head, uint32_t head_dim, float eps) { + uint32_t row = blockIdx.x; + if (row >= n_tok * n_head) return; + float *xr = x + (uint64_t)row * head_dim; + float sum = 0.0f; + for (uint32_t i = threadIdx.x; i < head_dim; i += blockDim.x) { + float v = xr[i]; + sum += v * v; + } + __shared__ float partial[256]; + partial[threadIdx.x] = sum; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + float scale = rsqrtf(partial[0] / (float)head_dim + eps); + for (uint32_t i = threadIdx.x; i < head_dim; i += blockDim.x) xr[i] *= scale; +} + +__device__ static float rope_yarn_ramp_dev(float low, float high, int i0); + +__global__ static void head_rms_norm_rope_tail_kernel( + float *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + uint32_t pos0, + uint32_t n_ctx_orig, + int inverse, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float eps) { + uint32_t row = blockIdx.x; + if (row >= n_tok * n_head) return; + uint32_t t = row / n_head; + float *xr = x + (uint64_t)row * head_dim; + float sum = 0.0f; + for (uint32_t i = threadIdx.x; i < head_dim; i += blockDim.x) { + float v = xr[i]; + sum += v * v; + } + __shared__ float partial[256]; + partial[threadIdx.x] = sum; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + const float scale = rsqrtf(partial[0] / (float)head_dim + eps); + const uint32_t n_nope = head_dim - n_rot; + for (uint32_t i = threadIdx.x; i < n_nope; i += blockDim.x) { + xr[i] *= scale; + } + + float corr0 = 0.0f, corr1 = 0.0f; + if (ext_factor != 0.0f) { + float denom = 2.0f * logf(freq_base); + corr0 = floorf((float)n_rot * logf((float)n_ctx_orig / (beta_fast * 2.0f * (float)M_PI)) / denom); + corr1 = ceilf((float)n_rot * logf((float)n_ctx_orig / (beta_slow * 2.0f * (float)M_PI)) / denom); + corr0 = fmaxf(0.0f, corr0); + corr1 = fminf((float)(n_rot - 1), corr1); + } + for (uint32_t pair = threadIdx.x; pair < n_rot / 2; pair += blockDim.x) { + uint32_t i = pair * 2u; + float theta_extrap = (float)(pos0 + t) * powf(freq_base, -((float)i) / (float)n_rot); + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + float mscale = attn_factor; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp_dev(corr0, corr1, (int)i) * ext_factor; + theta = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + float c = cosf(theta) * mscale; + float s = sinf(theta) * mscale; + if (inverse) s = -s; + float *tail = xr + n_nope; + float x0 = tail[i] * scale; + float x1 = tail[i + 1] * scale; + tail[i] = x0 * c - x1 * s; + tail[i + 1] = x0 * s + x1 * c; + } +} + +__device__ static float rope_yarn_ramp_dev(float low, float high, int i0) { + float y = ((float)(i0 / 2) - low) / fmaxf(0.001f, high - low); + return 1.0f - fminf(1.0f, fmaxf(0.0f, y)); +} + +__global__ static void rope_tail_kernel( + float *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + uint32_t pos0, + uint32_t n_ctx_orig, + int inverse, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t pairs = n_tok * n_head * (n_rot / 2); + if (gid >= pairs) return; + uint32_t pair = gid % (n_rot / 2); + uint32_t tmp = gid / (n_rot / 2); + uint32_t h = tmp % n_head; + uint32_t t = tmp / n_head; + uint32_t n_nope = head_dim - n_rot; + uint32_t i = pair * 2; + + float corr0 = 0.0f, corr1 = 0.0f; + if (ext_factor != 0.0f) { + float denom = 2.0f * logf(freq_base); + corr0 = floorf((float)n_rot * logf((float)n_ctx_orig / (beta_fast * 2.0f * (float)M_PI)) / denom); + corr1 = ceilf((float)n_rot * logf((float)n_ctx_orig / (beta_slow * 2.0f * (float)M_PI)) / denom); + corr0 = fmaxf(0.0f, corr0); + corr1 = fminf((float)(n_rot - 1), corr1); + } + + float theta_extrap = (float)(pos0 + t) * powf(freq_base, -((float)i) / (float)n_rot); + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + float mscale = attn_factor; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp_dev(corr0, corr1, (int)i) * ext_factor; + theta = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + float c = cosf(theta) * mscale; + float s = sinf(theta) * mscale; + if (inverse) s = -s; + + float *tail = x + ((uint64_t)t * n_head + h) * head_dim + n_nope; + float x0 = tail[i]; + float x1 = tail[i + 1]; + tail[i] = x0 * c - x1 * s; + tail[i + 1] = x0 * s + x1 * c; +} + +__device__ static float dsv4_e4m3fn_value_dev(int i) { + int exp = (i >> 3) & 15; + int mant = i & 7; + if (exp == 0) return (float)mant * 0.001953125f; + return (1.0f + (float)mant * 0.125f) * exp2f((float)exp - 7.0f); +} + +__device__ static float dsv4_e4m3fn_dequant_dev(float x) { + float sign = x < 0.0f ? -1.0f : 1.0f; + float ax = fminf(fabsf(x), 448.0f); + int lo = 0, hi = 126; + while (lo < hi) { + int mid = (lo + hi + 1) >> 1; + if (dsv4_e4m3fn_value_dev(mid) <= ax) lo = mid; + else hi = mid - 1; + } + int best = lo; + if (best < 126) { + float bd = fabsf(ax - dsv4_e4m3fn_value_dev(best)); + float nd = fabsf(ax - dsv4_e4m3fn_value_dev(best + 1)); + if (nd < bd || (nd == bd && (((best + 1) & 1) == 0) && ((best & 1) != 0))) best++; + } + return sign * dsv4_e4m3fn_value_dev(best); +} + +__device__ static float model_scalar_dev(const void *base, uint64_t offset, uint32_t type, uint64_t idx) { + const char *p = (const char *)base + offset; + if (type == 1u) return __half2float(((const __half *)p)[idx]); + return ((const float *)p)[idx]; +} + +__device__ static float rope_yarn_ramp_cpu_equiv_dev(float low, float high, int i0) { + float y = ((float)(i0 / 2) - low) / fmaxf(0.001f, high - low); + return 1.0f - fminf(1.0f, fmaxf(0.0f, y)); +} + +__device__ static DS4_HIP_UNUSED void rope_tail_one_dev(float *x, uint32_t head_dim, uint32_t n_rot, uint32_t pos, uint32_t n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow) { + uint32_t n_nope = head_dim - n_rot; + float corr0 = 0.0f, corr1 = 0.0f; + if (ext_factor != 0.0f) { + float denom = 2.0f * logf(freq_base); + corr0 = fmaxf(0.0f, floorf((float)n_rot * logf((float)n_ctx_orig / (beta_fast * 2.0f * (float)M_PI)) / denom)); + corr1 = fminf((float)(n_rot - 1), ceilf((float)n_rot * logf((float)n_ctx_orig / (beta_slow * 2.0f * (float)M_PI)) / denom)); + } + for (uint32_t i = 0; i < n_rot; i += 2) { + float theta_extrap = (float)pos * powf(freq_base, -((float)i) / (float)n_rot); + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + float mscale = attn_factor; + if (ext_factor != 0.0f) { + float mix = rope_yarn_ramp_cpu_equiv_dev(corr0, corr1, (int)i) * ext_factor; + theta = theta_interp * (1.0f - mix) + theta_extrap * mix; + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + float c = cosf(theta) * mscale; + float s = sinf(theta) * mscale; + float x0 = x[n_nope + i]; + float x1 = x[n_nope + i + 1]; + x[n_nope + i] = x0 * c - x1 * s; + x[n_nope + i + 1] = x0 * s + x1 * c; + } +} + +__global__ static void fp8_kv_quantize_kernel(float *x, uint32_t n_tok, uint32_t head_dim, uint32_t n_rot) { + uint32_t row = blockIdx.x; + uint32_t tid = threadIdx.x; + uint32_t n_nope = head_dim - n_rot; + float *xr = x + (uint64_t)row * head_dim; + __shared__ float scratch[64]; + for (uint32_t off = 0; off < n_nope; off += 64) { + float v = 0.0f; + if (off + tid < n_nope) v = xr[off + tid]; + scratch[tid] = off + tid < n_nope ? fabsf(v) : 0.0f; + __syncthreads(); + for (uint32_t stride = 32; stride > 0; stride >>= 1) { + if (tid < stride) scratch[tid] = fmaxf(scratch[tid], scratch[tid + stride]); + __syncthreads(); + } + float scale = exp2f(ceilf(log2f(fmaxf(scratch[0], 1.0e-4f) / 448.0f))); + if (off + tid < n_nope) { + float q = dsv4_e4m3fn_dequant_dev(fminf(448.0f, fmaxf(-448.0f, v / scale))) * scale; + xr[off + tid] = q; + } + __syncthreads(); + } +} + +__global__ static void store_raw_kv_batch_kernel(float *raw, const float *kv, uint32_t raw_cap, uint32_t pos0, uint32_t n_tokens, uint32_t head_dim) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_tokens * head_dim; + if (gid >= n) return; + uint32_t d = gid % head_dim; + uint32_t t = gid / head_dim; + uint32_t row = (pos0 + t) % raw_cap; + raw[(uint64_t)row * head_dim + d] = __half2float(__float2half(kv[(uint64_t)t * head_dim + d])); +} + +__global__ static void attention_prefill_raw_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + uint32_t n_tokens, + uint32_t window, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + uint32_t raw_count = t + 1 < window ? t + 1 : window; + uint32_t raw_start = t + 1 - raw_count; + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + __shared__ float scores[DS4_HIP_ATTENTION_SCORE_CAP]; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + float scale = rsqrtf((float)head_dim); + float local_max = sinks[h]; + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const float *kv = raw_kv + (uint64_t)(raw_start + r) * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kv[d]; + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + if (threadIdx.x == 0) { + float den = expf(sinks[h] - max_s); + for (uint32_t r = 0; r < raw_count; r++) { + scores[r] = expf(scores[r] - max_s); + den += scores[r]; + } + denom = den; + } + __syncthreads(); + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { + float acc = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) { + acc += raw_kv[(uint64_t)(raw_start + r) * head_dim + d] * scores[r]; + } + oh[d] = acc / denom; + } +} + +__global__ static void attention_prefill_mixed_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + const float *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + uint32_t raw_start = (window != 0 && t + 1u > window) ? t + 1u - window : 0u; + uint32_t raw_count = t + 1u - raw_start; + uint32_t visible_comp = (t + 1u) / ratio; + if (visible_comp > n_comp) visible_comp = n_comp; + __shared__ float scores[DS4_HIP_ATTENTION_SCORE_CAP]; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + float scale = rsqrtf((float)head_dim); + float local_max = sinks[h]; + uint32_t n_score = raw_count + visible_comp; + + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const float *kvrow = raw_kv + (uint64_t)(raw_start + r) * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + for (uint32_t c = threadIdx.x; c < visible_comp; c += blockDim.x) { + float add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + float s = -INFINITY; + if (add > -1.0e20f) { + const float *kvrow = comp_kv + (uint64_t)c * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + s = dot * scale + add; + } + scores[raw_count + c] = s; + local_max = fmaxf(local_max, s); + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + scores[i] = expf(scores[i] - max_s); + den_local += scores[i]; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { + float acc = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) acc += raw_kv[(uint64_t)(raw_start + r) * head_dim + d] * scores[r]; + for (uint32_t c = 0; c < visible_comp; c++) acc += comp_kv[(uint64_t)c * head_dim + d] * scores[raw_count + c]; + oh[d] = acc / denom; + } +} + +__global__ static void attention_prefill_raw_softmax_kernel( + float *scores, + const float *sinks, + uint32_t n_tokens, + uint32_t window, + uint32_t n_keys) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens) return; + float *row = scores + ((uint64_t)h * n_tokens + t) * n_keys; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + float local_max = sinks[h]; + for (uint32_t k = threadIdx.x; k < n_keys; k += blockDim.x) { + bool valid = k <= t && (window == 0 || t - k < window); + float s = valid ? row[k] : -INFINITY; + row[k] = s; + local_max = fmaxf(local_max, s); + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t k = threadIdx.x; k < n_keys; k += blockDim.x) { + float p = isfinite(row[k]) ? expf(row[k] - max_s) : 0.0f; + row[k] = p; + den_local += p; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + for (uint32_t k = threadIdx.x; k < n_keys; k += blockDim.x) row[k] /= denom; +} + +__global__ static void attention_prefill_mixed_softmax_kernel( + float *scores, + const float *sinks, + const float *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_keys) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || ratio == 0) return; + float *row = scores + ((uint64_t)h * n_tokens + t) * n_keys; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + float local_max = sinks[h]; + const uint32_t visible_comp = (t + 1u) / ratio; + for (uint32_t k = threadIdx.x; k < n_keys; k += blockDim.x) { + float s = -INFINITY; + if (k < n_tokens) { + if (k <= t && (window == 0 || t - k < window)) s = row[k]; + } else { + uint32_t c = k - n_tokens; + if (c < n_comp && c < visible_comp) { + float add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + if (add > -1.0e20f) s = row[k] + add; + } + } + row[k] = s; + local_max = fmaxf(local_max, s); + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t k = threadIdx.x; k < n_keys; k += blockDim.x) { + float p = isfinite(row[k]) ? expf(row[k] - max_s) : 0.0f; + row[k] = p; + den_local += p; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + for (uint32_t k = threadIdx.x; k < n_keys; k += blockDim.x) row[k] /= denom; +} + +__global__ static void attention_prefill_pack_mixed_kv_kernel( + float *dst, + const float *raw_kv, + const float *comp_kv, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t head_dim) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)(n_tokens + n_comp) * head_dim; + if (gid >= n) return; + uint32_t d = gid % head_dim; + uint32_t r = gid / head_dim; + dst[gid] = r < n_tokens ? raw_kv[(uint64_t)r * head_dim + d] + : comp_kv[(uint64_t)(r - n_tokens) * head_dim + d]; +} + +__global__ static void attention_prefill_unpack_heads_kernel( + float *heads, + const float *tmp, + uint32_t n_tokens, + uint32_t n_head, + uint32_t head_dim) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_tokens * n_head * head_dim; + if (gid >= n) return; + uint32_t d = gid % head_dim; + uint64_t q = gid / head_dim; + uint32_t h = q % n_head; + uint32_t t = q / n_head; + heads[gid] = tmp[((uint64_t)h * n_tokens + t) * head_dim + d]; +} + +__global__ static void attention_pack_group_heads_f16_kernel( + __half *dst, + const float *heads, + uint32_t n_tokens, + uint32_t n_groups, + uint32_t group_dim) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_groups * n_tokens * group_dim; + if (gid >= n) return; + uint32_t d = gid % group_dim; + uint64_t q = gid / group_dim; + uint32_t t = q % n_tokens; + uint32_t g = q / n_tokens; + dst[gid] = __float2half(heads[((uint64_t)t * n_groups + g) * group_dim + d]); +} + +__global__ static void attention_unpack_group_low_kernel( + float *low, + const float *tmp, + uint32_t n_tokens, + uint32_t n_groups, + uint32_t rank) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_groups * n_tokens * rank; + if (gid >= n) return; + uint32_t r = gid % rank; + uint64_t q = gid / rank; + uint32_t t = q % n_tokens; + uint32_t g = q / n_tokens; + uint32_t low_dim = n_groups * rank; + low[(uint64_t)t * low_dim + (uint64_t)g * rank + r] = tmp[gid]; +} + +__global__ static void attention_decode_mixed_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + const float *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + const bool single_all = (n_tokens == 1u && ratio == 0u); + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = single_all ? n_comp : (n_comp ? (qpos + 1u) / ratio : 0u); + if (visible_comp > n_comp) visible_comp = n_comp; + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + __shared__ float scores[DS4_HIP_ATTENTION_SCORE_CAP]; + __shared__ uint32_t raw_rows[256]; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + float scale = rsqrtf((float)head_dim); + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (single_all) { + raw_count = n_raw > 256u ? 256u : n_raw; + } else if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + __syncthreads(); + uint32_t n_score = raw_count + visible_comp; + float local_max = sinks[h]; + if (visible_comp == 0 || n_tokens == 1u) { + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const float *kvrow = raw_kv + (uint64_t)raw_rows[r] * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + for (uint32_t c = threadIdx.x; c < visible_comp; c += blockDim.x) { + float add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + float s = -INFINITY; + if (add > -1.0e20f) { + const float *kvrow = comp_kv + (uint64_t)c * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + s = dot * scale + add; + } + scores[raw_count + c] = s; + local_max = fmaxf(local_max, s); + } + } else { + uint32_t qlane = threadIdx.x & 7u; + uint32_t qgroup = threadIdx.x >> 3u; + for (uint32_t row0 = 0; row0 < n_score; row0 += 32u) { + uint32_t row = row0 + qgroup; + if (row < n_score) { + float add = 0.0f; + const float *kvrow = NULL; + if (row < raw_count) { + kvrow = raw_kv + (uint64_t)raw_rows[row] * head_dim; + } else { + uint32_t c = row - raw_count; + add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + if (add > -1.0e20f) kvrow = comp_kv + (uint64_t)c * head_dim; + } + float s = -INFINITY; + if (kvrow) { + float dot = 0.0f; + for (uint32_t d = qlane; d < head_dim; d += 8u) dot += qh[d] * kvrow[d]; + const uint32_t mask = 0xffu << (threadIdx.x & 24u); + for (uint32_t off = 4u; off > 0u; off >>= 1u) { + dot += __shfl_down(dot, off, 8); + } + s = dot * scale + add; + } + if (qlane == 0) scores[row] = s; + } + } + __syncthreads(); + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + local_max = fmaxf(local_max, scores[i]); + } + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + scores[i] = expf(scores[i] - max_s); + den_local += scores[i]; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + if (head_dim == 512u && blockDim.x == 256u) { + uint32_t d0 = threadIdx.x; + uint32_t d1 = d0 + 256u; + float acc0 = 0.0f; + float acc1 = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) { + float s = scores[r]; + const float *kv = raw_kv + (uint64_t)raw_rows[r] * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + for (uint32_t c = 0; c < visible_comp; c++) { + float s = scores[raw_count + c]; + const float *kv = comp_kv + (uint64_t)c * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + oh[d0] = acc0 / denom; + oh[d1] = acc1 / denom; + } else { + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { + float acc = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) acc += raw_kv[(uint64_t)raw_rows[r] * head_dim + d] * scores[r]; + for (uint32_t c = 0; c < visible_comp; c++) acc += comp_kv[(uint64_t)c * head_dim + d] * scores[raw_count + c]; + oh[d] = acc / denom; + } + } +} + +__global__ static void attention_indexed_mixed_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + const int32_t *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = n_comp; + if (ratio != 0) { + visible_comp = (qpos + 1u) / ratio; + if (visible_comp > n_comp) visible_comp = n_comp; + } + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + __shared__ float scores[768]; + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t comp_rows[512]; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + __shared__ uint32_t comp_count; + float scale = rsqrtf((float)head_dim); + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + comp_count = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + for (uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { + int32_t c = topk[(uint64_t)t * top_k + i]; + if (c >= 0 && (uint32_t)c < visible_comp) { + uint32_t slot = atomicAdd(&comp_count, 1u); + if (slot < 512u) comp_rows[slot] = (uint32_t)c; + } + } + __syncthreads(); + if (threadIdx.x == 0) { + if (comp_count > 512u) comp_count = 512u; + } + __syncthreads(); + uint32_t n_score = raw_count + comp_count; + float local_max = sinks[h]; + if (comp_count == 0) { + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const float *kvrow = raw_kv + (uint64_t)raw_rows[r] * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + } else { + uint32_t qlane = threadIdx.x & 7u; + uint32_t qgroup = threadIdx.x >> 3u; + for (uint32_t row0 = 0; row0 < n_score; row0 += 32u) { + uint32_t row = row0 + qgroup; + if (row < n_score) { + const float *kvrow = row < raw_count + ? raw_kv + (uint64_t)raw_rows[row] * head_dim + : comp_kv + (uint64_t)comp_rows[row - raw_count] * head_dim; + float dot = 0.0f; + for (uint32_t d = qlane; d < head_dim; d += 8u) dot += qh[d] * kvrow[d]; + const uint32_t mask = 0xffu << (threadIdx.x & 24u); + for (uint32_t off = 4u; off > 0u; off >>= 1u) { + dot += __shfl_down(dot, off, 8); + } + if (qlane == 0) scores[row] = dot * scale; + } + } + __syncthreads(); + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + local_max = fmaxf(local_max, scores[i]); + } + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + scores[i] = expf(scores[i] - max_s); + den_local += scores[i]; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + if (head_dim == 512u && blockDim.x == 256u) { + uint32_t d0 = threadIdx.x; + uint32_t d1 = d0 + 256u; + float acc0 = 0.0f; + float acc1 = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) { + float s = scores[r]; + const float *kv = raw_kv + (uint64_t)raw_rows[r] * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + for (uint32_t c = 0; c < comp_count; c++) { + float s = scores[raw_count + c]; + const float *kv = comp_kv + (uint64_t)comp_rows[c] * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + oh[d0] = acc0 / denom; + oh[d1] = acc1 / denom; + } else { + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { + float acc = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) acc += raw_kv[(uint64_t)raw_rows[r] * head_dim + d] * scores[r]; + for (uint32_t s = 0; s < comp_count; s++) acc += comp_kv[(uint64_t)comp_rows[s] * head_dim + d] * scores[raw_count + s]; + oh[d] = acc / denom; + } + } +} + +__global__ static void attention_indexed_mixed_heads8_rb4_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + const int32_t *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t head_group = blockIdx.y; + if (t >= n_tokens || head_dim != 512u) return; + const uint32_t lane = threadIdx.x & 31u; + const uint32_t warp = threadIdx.x >> 5u; + const uint32_t head = head_group * 8u + warp; + const bool valid_head = head < n_head; + + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t comp_rows[512]; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + __shared__ uint32_t comp_count; + __shared__ float4 kv_shared[4 * 128]; + __shared__ float scores[8 * 768]; + + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = n_comp; + if (ratio != 0) { + visible_comp = (qpos + 1u) / ratio; + if (visible_comp > n_comp) visible_comp = n_comp; + } + + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + comp_count = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + if (threadIdx.x == 0) { + for (uint32_t i = 0; i < top_k && comp_count < 512u; i++) { + int32_t c = topk[(uint64_t)t * top_k + i]; + if (c >= 0 && (uint32_t)c < visible_comp) comp_rows[comp_count++] = (uint32_t)c; + } + } + __syncthreads(); + + const uint32_t n_score = raw_count + comp_count; + const float scale = rsqrtf((float)head_dim); + const float4 *q4 = valid_head + ? (const float4 *)(q + ((uint64_t)t * n_head + head) * head_dim) + : NULL; + float4 q0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 q1 = q0, q2 = q0, q3 = q0; + if (valid_head) { + q0 = q4[lane + 0u]; + q1 = q4[lane + 32u]; + q2 = q4[lane + 64u]; + q3 = q4[lane + 96u]; + } + + for (uint32_t row0 = 0; row0 < n_score; row0 += 4u) { + const uint32_t nr = n_score - row0 < 4u ? n_score - row0 : 4u; + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + const float4 *src = sr < raw_count + ? (const float4 *)(raw_kv + (uint64_t)raw_rows[sr] * head_dim) + : (const float4 *)(comp_kv + (uint64_t)comp_rows[sr - raw_count] * head_dim); + kv_shared[off] = src[c4]; + } + __syncthreads(); + if (valid_head) { + for (uint32_t rr = 0; rr < nr; rr++) { + const float4 *kv4 = kv_shared + rr * 128u; + float dot = dot4_f32(q0, kv4[lane + 0u]) + + dot4_f32(q1, kv4[lane + 32u]) + + dot4_f32(q2, kv4[lane + 64u]) + + dot4_f32(q3, kv4[lane + 96u]); + dot = warp_sum_f32(dot); + if (lane == 0) scores[warp * 768u + row0 + rr] = dot * scale; + } + } + __syncthreads(); + } + + float max_s = valid_head ? sinks[head] : -INFINITY; + if (valid_head) { + const float *score_row = scores + warp * 768u; + for (uint32_t i = lane; i < n_score; i += 32u) max_s = fmaxf(max_s, score_row[i]); + max_s = warp_max_f32(max_s); + max_s = __shfl_sync(0xffffffffffffffffULL, max_s, 0); + } + float den = 0.0f; + if (valid_head) { + float *score_row = scores + warp * 768u; + for (uint32_t i = lane; i < n_score; i += 32u) { + float p = expf(score_row[i] - max_s); + score_row[i] = p; + den += p; + } + den = warp_sum_f32(den); + den += expf(sinks[head] - max_s); + den = __shfl_sync(0xffffffffffffffffULL, den, 0); + } + + float4 o0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 o1 = o0, o2 = o0, o3 = o0; + for (uint32_t row0 = 0; row0 < n_score; row0 += 4u) { + const uint32_t nr = n_score - row0 < 4u ? n_score - row0 : 4u; + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + const float4 *src = sr < raw_count + ? (const float4 *)(raw_kv + (uint64_t)raw_rows[sr] * head_dim) + : (const float4 *)(comp_kv + (uint64_t)comp_rows[sr - raw_count] * head_dim); + kv_shared[off] = src[c4]; + } + __syncthreads(); + if (valid_head) { + const float *score_row = scores + warp * 768u; + for (uint32_t rr = 0; rr < nr; rr++) { + const float p = den == 0.0f ? 0.0f : score_row[row0 + rr] / den; + const float4 *kv4 = kv_shared + rr * 128u; + float4 k0 = kv4[lane + 0u]; + float4 k1 = kv4[lane + 32u]; + float4 k2 = kv4[lane + 64u]; + float4 k3 = kv4[lane + 96u]; + o0.x += k0.x * p; o0.y += k0.y * p; o0.z += k0.z * p; o0.w += k0.w * p; + o1.x += k1.x * p; o1.y += k1.y * p; o1.z += k1.z * p; o1.w += k1.w * p; + o2.x += k2.x * p; o2.y += k2.y * p; o2.z += k2.z * p; o2.w += k2.w * p; + o3.x += k3.x * p; o3.y += k3.y * p; o3.z += k3.z * p; o3.w += k3.w * p; + } + } + __syncthreads(); + } + if (valid_head) { + float4 *out4 = (float4 *)(heads + ((uint64_t)t * n_head + head) * head_dim); + out4[lane + 0u] = o0; + out4[lane + 32u] = o1; + out4[lane + 64u] = o2; + out4[lane + 96u] = o3; + } +} + +__global__ static void attention_indexed_mixed_heads8_online_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + const int32_t *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t head_group = blockIdx.y; + if (t >= n_tokens || head_dim != 512u) return; + const uint32_t lane = threadIdx.x & 31u; + const uint32_t warp = threadIdx.x >> 5u; + const uint32_t head = head_group * 8u + warp; + const bool valid_head = head < n_head; + + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t comp_rows[512]; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + __shared__ uint32_t comp_count; + __shared__ float4 kv_shared[4 * 128]; + + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = n_comp; + if (ratio != 0) { + visible_comp = (qpos + 1u) / ratio; + if (visible_comp > n_comp) visible_comp = n_comp; + } + + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + comp_count = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + if (threadIdx.x == 0) { + for (uint32_t i = 0; i < top_k && comp_count < 512u; i++) { + int32_t c = topk[(uint64_t)t * top_k + i]; + if (c >= 0 && (uint32_t)c < visible_comp) comp_rows[comp_count++] = (uint32_t)c; + } + } + __syncthreads(); + + const uint32_t n_score = raw_count + comp_count; + const float scale = rsqrtf((float)head_dim); + const float4 *q4 = valid_head + ? (const float4 *)(q + ((uint64_t)t * n_head + head) * head_dim) + : NULL; + float4 q0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 q1 = q0, q2 = q0, q3 = q0; + if (valid_head) { + q0 = q4[lane + 0u]; + q1 = q4[lane + 32u]; + q2 = q4[lane + 64u]; + q3 = q4[lane + 96u]; + } + + float max_s = -INFINITY; + float sum_s = 0.0f; + float4 o0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 o1 = o0, o2 = o0, o3 = o0; + + for (uint32_t row0 = 0; row0 < n_score; row0 += 4u) { + const uint32_t nr = n_score - row0 < 4u ? n_score - row0 : 4u; + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + const float4 *src = sr < raw_count + ? (const float4 *)(raw_kv + (uint64_t)raw_rows[sr] * head_dim) + : (const float4 *)(comp_kv + (uint64_t)comp_rows[sr - raw_count] * head_dim); + kv_shared[off] = src[c4]; + } + __syncthreads(); + if (valid_head) { + for (uint32_t rr = 0; rr < nr; rr++) { + const float4 *kv4 = kv_shared + rr * 128u; + float4 k0 = kv4[lane + 0u]; + float4 k1 = kv4[lane + 32u]; + float4 k2 = kv4[lane + 64u]; + float4 k3 = kv4[lane + 96u]; + float score = dot4_f32(q0, k0) + + dot4_f32(q1, k1) + + dot4_f32(q2, k2) + + dot4_f32(q3, k3); + score = warp_sum_f32(score) * scale; + score = __shfl_sync(0xffffffffffffffffULL, score, 0); + + const float new_m = fmaxf(max_s, score); + const float old_scale = expf(max_s - new_m); + const float row_scale = expf(score - new_m); + sum_s = sum_s * old_scale + row_scale; + o0.x = o0.x * old_scale + k0.x * row_scale; + o0.y = o0.y * old_scale + k0.y * row_scale; + o0.z = o0.z * old_scale + k0.z * row_scale; + o0.w = o0.w * old_scale + k0.w * row_scale; + o1.x = o1.x * old_scale + k1.x * row_scale; + o1.y = o1.y * old_scale + k1.y * row_scale; + o1.z = o1.z * old_scale + k1.z * row_scale; + o1.w = o1.w * old_scale + k1.w * row_scale; + o2.x = o2.x * old_scale + k2.x * row_scale; + o2.y = o2.y * old_scale + k2.y * row_scale; + o2.z = o2.z * old_scale + k2.z * row_scale; + o2.w = o2.w * old_scale + k2.w * row_scale; + o3.x = o3.x * old_scale + k3.x * row_scale; + o3.y = o3.y * old_scale + k3.y * row_scale; + o3.z = o3.z * old_scale + k3.z * row_scale; + o3.w = o3.w * old_scale + k3.w * row_scale; + max_s = new_m; + } + } + __syncthreads(); + } + + if (valid_head) { + const float sink = sinks[head]; + const float new_m = fmaxf(max_s, sink); + const float old_scale = expf(max_s - new_m); + const float sink_scale = expf(sink - new_m); + sum_s = sum_s * old_scale + sink_scale; + o0.x *= old_scale; o0.y *= old_scale; o0.z *= old_scale; o0.w *= old_scale; + o1.x *= old_scale; o1.y *= old_scale; o1.z *= old_scale; o1.w *= old_scale; + o2.x *= old_scale; o2.y *= old_scale; o2.z *= old_scale; o2.w *= old_scale; + o3.x *= old_scale; o3.y *= old_scale; o3.z *= old_scale; o3.w *= old_scale; + + const float inv_s = sum_s == 0.0f ? 0.0f : 1.0f / sum_s; + o0.x *= inv_s; o0.y *= inv_s; o0.z *= inv_s; o0.w *= inv_s; + o1.x *= inv_s; o1.y *= inv_s; o1.z *= inv_s; o1.w *= inv_s; + o2.x *= inv_s; o2.y *= inv_s; o2.z *= inv_s; o2.w *= inv_s; + o3.x *= inv_s; o3.y *= inv_s; o3.z *= inv_s; o3.w *= inv_s; + float4 *out4 = (float4 *)(heads + ((uint64_t)t * n_head + head) * head_dim); + out4[lane + 0u] = o0; + out4[lane + 32u] = o1; + out4[lane + 64u] = o2; + out4[lane + 96u] = o3; + } +} + +__global__ static void attention_static_mixed_heads8_online_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t head_group = blockIdx.y; + if (t >= n_tokens || head_dim != 512u) return; + const uint32_t lane = threadIdx.x & 31u; + const uint32_t warp = threadIdx.x >> 5u; + const uint32_t head = head_group * 8u + warp; + const bool valid_head = head < n_head; + + __shared__ float4 kv_shared[4 * 128]; + + const uint32_t raw_count = window != 0u && t + 1u > window ? window : t + 1u; + const uint32_t raw_start = t + 1u - raw_count; + uint32_t comp_count = 0; + if (n_comp != 0u && ratio != 0u) { + comp_count = (t + 1u) / ratio; + if (comp_count > n_comp) comp_count = n_comp; + } + const uint32_t n_score = raw_count + comp_count; + const float scale = rsqrtf((float)head_dim); + const float4 *q4 = valid_head + ? (const float4 *)(q + ((uint64_t)t * n_head + head) * head_dim) + : NULL; + float4 q0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 q1 = q0, q2 = q0, q3 = q0; + if (valid_head) { + q0 = q4[lane + 0u]; + q1 = q4[lane + 32u]; + q2 = q4[lane + 64u]; + q3 = q4[lane + 96u]; + } + + float max_s = -INFINITY; + float sum_s = 0.0f; + float4 o0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 o1 = o0, o2 = o0, o3 = o0; + + for (uint32_t row0 = 0; row0 < n_score; row0 += 4u) { + const uint32_t nr = n_score - row0 < 4u ? n_score - row0 : 4u; + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + const float4 *src = sr < raw_count + ? (const float4 *)(raw_kv + (uint64_t)(raw_start + sr) * head_dim) + : (const float4 *)(comp_kv + (uint64_t)(sr - raw_count) * head_dim); + kv_shared[off] = src[c4]; + } + __syncthreads(); + if (valid_head) { + for (uint32_t rr = 0; rr < nr; rr++) { + const float4 *kv4 = kv_shared + rr * 128u; + float4 k0 = kv4[lane + 0u]; + float4 k1 = kv4[lane + 32u]; + float4 k2 = kv4[lane + 64u]; + float4 k3 = kv4[lane + 96u]; + float score = dot4_f32(q0, k0) + + dot4_f32(q1, k1) + + dot4_f32(q2, k2) + + dot4_f32(q3, k3); + score = warp_sum_f32(score) * scale; + score = __shfl_sync(0xffffffffffffffffULL, score, 0); + + const float new_m = fmaxf(max_s, score); + const float old_scale = expf(max_s - new_m); + const float row_scale = expf(score - new_m); + sum_s = sum_s * old_scale + row_scale; + o0.x = o0.x * old_scale + k0.x * row_scale; + o0.y = o0.y * old_scale + k0.y * row_scale; + o0.z = o0.z * old_scale + k0.z * row_scale; + o0.w = o0.w * old_scale + k0.w * row_scale; + o1.x = o1.x * old_scale + k1.x * row_scale; + o1.y = o1.y * old_scale + k1.y * row_scale; + o1.z = o1.z * old_scale + k1.z * row_scale; + o1.w = o1.w * old_scale + k1.w * row_scale; + o2.x = o2.x * old_scale + k2.x * row_scale; + o2.y = o2.y * old_scale + k2.y * row_scale; + o2.z = o2.z * old_scale + k2.z * row_scale; + o2.w = o2.w * old_scale + k2.w * row_scale; + o3.x = o3.x * old_scale + k3.x * row_scale; + o3.y = o3.y * old_scale + k3.y * row_scale; + o3.z = o3.z * old_scale + k3.z * row_scale; + o3.w = o3.w * old_scale + k3.w * row_scale; + max_s = new_m; + } + } + __syncthreads(); + } + + if (valid_head) { + const float sink = sinks[head]; + const float new_m = fmaxf(max_s, sink); + const float old_scale = expf(max_s - new_m); + const float sink_scale = expf(sink - new_m); + sum_s = sum_s * old_scale + sink_scale; + o0.x *= old_scale; o0.y *= old_scale; o0.z *= old_scale; o0.w *= old_scale; + o1.x *= old_scale; o1.y *= old_scale; o1.z *= old_scale; o1.w *= old_scale; + o2.x *= old_scale; o2.y *= old_scale; o2.z *= old_scale; o2.w *= old_scale; + o3.x *= old_scale; o3.y *= old_scale; o3.z *= old_scale; o3.w *= old_scale; + + const float inv_s = sum_s == 0.0f ? 0.0f : 1.0f / sum_s; + o0.x *= inv_s; o0.y *= inv_s; o0.z *= inv_s; o0.w *= inv_s; + o1.x *= inv_s; o1.y *= inv_s; o1.z *= inv_s; o1.w *= inv_s; + o2.x *= inv_s; o2.y *= inv_s; o2.z *= inv_s; o2.w *= inv_s; + o3.x *= inv_s; o3.y *= inv_s; o3.z *= inv_s; o3.w *= inv_s; + float4 *out4 = (float4 *)(heads + ((uint64_t)t * n_head + head) * head_dim); + out4[lane + 0u] = o0; + out4[lane + 32u] = o1; + out4[lane + 64u] = o2; + out4[lane + 96u] = o3; + } +} + +__global__ static void attention_decode_mixed_heads8_online_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t head_group = blockIdx.y; + if (t >= n_tokens || head_dim != 512u) return; + const uint32_t lane = threadIdx.x & 31u; + const uint32_t warp = threadIdx.x >> 5u; + const uint32_t head = head_group * 8u + warp; + const bool valid_head = head < n_head; + + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t raw_count_s; + __shared__ uint32_t raw_first_idx_s; + __shared__ float4 kv_shared[4 * 128]; + + const uint32_t qpos = pos0 + t; + const uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t comp_count = 0; + if (n_comp != 0u) { + if (n_tokens == 1u && ratio == 0u) { + comp_count = n_comp; + } else if (ratio != 0u) { + comp_count = (qpos + 1u) / ratio; + if (comp_count > n_comp) comp_count = n_comp; + } + } + if (threadIdx.x == 0) { + uint32_t raw_count = 0; + uint32_t raw_first_idx = 0; + if (n_raw != 0u) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0u && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + raw_count_s = raw_count; + raw_first_idx_s = raw_first_idx; + } + __syncthreads(); + const uint32_t raw_count = raw_count_s; + const uint32_t raw_first_idx = raw_first_idx_s; + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + __syncthreads(); + + const uint32_t n_score = raw_count + comp_count; + const float scale = rsqrtf((float)head_dim); + const float4 *q4 = valid_head + ? (const float4 *)(q + ((uint64_t)t * n_head + head) * head_dim) + : NULL; + float4 q0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 q1 = q0, q2 = q0, q3 = q0; + if (valid_head) { + q0 = q4[lane + 0u]; + q1 = q4[lane + 32u]; + q2 = q4[lane + 64u]; + q3 = q4[lane + 96u]; + } + + float max_s = -INFINITY; + float sum_s = 0.0f; + float4 o0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 o1 = o0, o2 = o0, o3 = o0; + + for (uint32_t row0 = 0; row0 < n_score; row0 += 4u) { + const uint32_t nr = n_score - row0 < 4u ? n_score - row0 : 4u; + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + const float4 *src = sr < raw_count + ? (const float4 *)(raw_kv + (uint64_t)raw_rows[sr] * head_dim) + : (const float4 *)(comp_kv + (uint64_t)(sr - raw_count) * head_dim); + kv_shared[off] = src[c4]; + } + __syncthreads(); + if (valid_head) { + for (uint32_t rr = 0; rr < nr; rr++) { + const float4 *kv4 = kv_shared + rr * 128u; + float4 k0 = kv4[lane + 0u]; + float4 k1 = kv4[lane + 32u]; + float4 k2 = kv4[lane + 64u]; + float4 k3 = kv4[lane + 96u]; + float score = dot4_f32(q0, k0) + + dot4_f32(q1, k1) + + dot4_f32(q2, k2) + + dot4_f32(q3, k3); + score = warp_sum_f32(score) * scale; + score = __shfl_sync(0xffffffffffffffffULL, score, 0); + + const float new_m = fmaxf(max_s, score); + const float old_scale = expf(max_s - new_m); + const float row_scale = expf(score - new_m); + sum_s = sum_s * old_scale + row_scale; + o0.x = o0.x * old_scale + k0.x * row_scale; + o0.y = o0.y * old_scale + k0.y * row_scale; + o0.z = o0.z * old_scale + k0.z * row_scale; + o0.w = o0.w * old_scale + k0.w * row_scale; + o1.x = o1.x * old_scale + k1.x * row_scale; + o1.y = o1.y * old_scale + k1.y * row_scale; + o1.z = o1.z * old_scale + k1.z * row_scale; + o1.w = o1.w * old_scale + k1.w * row_scale; + o2.x = o2.x * old_scale + k2.x * row_scale; + o2.y = o2.y * old_scale + k2.y * row_scale; + o2.z = o2.z * old_scale + k2.z * row_scale; + o2.w = o2.w * old_scale + k2.w * row_scale; + o3.x = o3.x * old_scale + k3.x * row_scale; + o3.y = o3.y * old_scale + k3.y * row_scale; + o3.z = o3.z * old_scale + k3.z * row_scale; + o3.w = o3.w * old_scale + k3.w * row_scale; + max_s = new_m; + } + } + __syncthreads(); + } + + if (valid_head) { + const float sink = sinks[head]; + const float new_m = fmaxf(max_s, sink); + const float old_scale = expf(max_s - new_m); + const float sink_scale = expf(sink - new_m); + sum_s = sum_s * old_scale + sink_scale; + o0.x *= old_scale; o0.y *= old_scale; o0.z *= old_scale; o0.w *= old_scale; + o1.x *= old_scale; o1.y *= old_scale; o1.z *= old_scale; o1.w *= old_scale; + o2.x *= old_scale; o2.y *= old_scale; o2.z *= old_scale; o2.w *= old_scale; + o3.x *= old_scale; o3.y *= old_scale; o3.z *= old_scale; o3.w *= old_scale; + + const float inv_s = sum_s == 0.0f ? 0.0f : 1.0f / sum_s; + o0.x *= inv_s; o0.y *= inv_s; o0.z *= inv_s; o0.w *= inv_s; + o1.x *= inv_s; o1.y *= inv_s; o1.z *= inv_s; o1.w *= inv_s; + o2.x *= inv_s; o2.y *= inv_s; o2.z *= inv_s; o2.w *= inv_s; + o3.x *= inv_s; o3.y *= inv_s; o3.z *= inv_s; o3.w *= inv_s; + float4 *out4 = (float4 *)(heads + ((uint64_t)t * n_head + head) * head_dim); + out4[lane + 0u] = o0; + out4[lane + 32u] = o1; + out4[lane + 64u] = o2; + out4[lane + 96u] = o3; + } +} + +__device__ static void hc4_split_one(float *out, const float *mix, const float *scale, const float *base, uint32_t sinkhorn_iters, float epsv) { + const float pre_scale = scale[0]; + const float post_scale = scale[1]; + const float comb_scale = scale[2]; + for (int i = 0; i < 4; i++) { + float z = mix[i] * pre_scale + base[i]; + out[i] = 1.0f / (1.0f + expf(-z)) + epsv; + } + for (int i = 0; i < 4; i++) { + float z = mix[4 + i] * post_scale + base[4 + i]; + out[4 + i] = 2.0f / (1.0f + expf(-z)); + } + float c[16]; + for (int r = 0; r < 4; r++) { + float m = -INFINITY; + for (int col = 0; col < 4; col++) { + float v = mix[8 + r * 4 + col] * comb_scale + base[8 + r * 4 + col]; + c[r * 4 + col] = v; + m = fmaxf(m, v); + } + float s = 0.0f; + for (int col = 0; col < 4; col++) { + float v = expf(c[r * 4 + col] - m); + c[r * 4 + col] = v; + s += v; + } + for (int col = 0; col < 4; col++) c[r * 4 + col] = c[r * 4 + col] / s + epsv; + } + for (int col = 0; col < 4; col++) { + float s = epsv; + for (int r = 0; r < 4; r++) s += c[r * 4 + col]; + for (int r = 0; r < 4; r++) c[r * 4 + col] /= s; + } + for (uint32_t iter = 1; iter < sinkhorn_iters; iter++) { + for (int r = 0; r < 4; r++) { + float s = epsv; + for (int col = 0; col < 4; col++) s += c[r * 4 + col]; + for (int col = 0; col < 4; col++) c[r * 4 + col] /= s; + } + for (int col = 0; col < 4; col++) { + float s = epsv; + for (int r = 0; r < 4; r++) s += c[r * 4 + col]; + for (int r = 0; r < 4; r++) c[r * 4 + col] /= s; + } + } + for (int i = 0; i < 16; i++) out[8 + i] = c[i]; +} + +__global__ static void hc_split_sinkhorn_kernel(float *out, const float *mix, const float *scale, const float *base, uint32_t n_rows, uint32_t sinkhorn_iters, float epsv) { + uint32_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + hc4_split_one(out + (uint64_t)row * 24, mix + (uint64_t)row * 24, scale, base, sinkhorn_iters, epsv); +} + +__global__ static void hc_weighted_sum_kernel(float *out, const float *x, const float *w, uint32_t n_embd, uint32_t n_hc, uint32_t n_tokens, uint32_t weight_stride_f32) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_embd * n_tokens; + if (gid >= n) return; + uint32_t d = gid % n_embd; + uint32_t t = gid / n_embd; + float acc = 0.0f; + for (uint32_t h = 0; h < n_hc; h++) { + acc += x[(uint64_t)t * n_hc * n_embd + (uint64_t)h * n_embd + d] * + w[(uint64_t)t * weight_stride_f32 + h]; + } + out[(uint64_t)t * n_embd + d] = acc; +} + +__global__ static void hc_expand_kernel( + float *out_hc, + const float *block_out, + const float *block_add, + const float *residual_hc, + const float *post, + const float *comb, + uint32_t n_embd, + uint32_t n_hc, + uint32_t n_tokens, + uint32_t post_stride, + uint32_t comb_stride, + int has_add) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_elem = (uint64_t)n_tokens * n_hc * n_embd; + if (gid >= n_elem) return; + uint32_t d = gid % n_embd; + uint64_t tmp = gid / n_embd; + uint32_t dst_hc = tmp % n_hc; + uint32_t t = tmp / n_hc; + + float block_v = block_out[(uint64_t)t * n_embd + d]; + if (has_add) block_v += block_add[(uint64_t)t * n_embd + d]; + float acc = block_v * post[(uint64_t)t * post_stride + dst_hc]; + for (uint32_t src_hc = 0; src_hc < n_hc; src_hc++) { + float comb_v = comb[(uint64_t)t * comb_stride + dst_hc + (uint64_t)src_hc * n_hc]; + float res_v = residual_hc[(uint64_t)t * n_hc * n_embd + (uint64_t)src_hc * n_embd + d]; + acc += comb_v * res_v; + } + out_hc[(uint64_t)t * n_hc * n_embd + (uint64_t)dst_hc * n_embd + d] = acc; +} + +__global__ static void hc_split_weighted_sum_fused_kernel( + float *out, + float *split, + const float *mix, + const float *residual_hc, + const float *scale, + const float *base, + uint32_t n_embd, + uint32_t n_hc, + uint32_t n_rows, + uint32_t sinkhorn_iters, + float epsv) { + uint32_t t = blockIdx.x; + uint32_t d = threadIdx.x; + if (t >= n_rows || n_hc != 4) return; + const uint32_t mix_hc = 24; + float *sp = split + (uint64_t)t * mix_hc; + if (d == 0) hc4_split_one(sp, mix + (uint64_t)t * mix_hc, scale, base, sinkhorn_iters, epsv); + __syncthreads(); + for (uint32_t col = d; col < n_embd; col += blockDim.x) { + float acc = 0.0f; + for (uint32_t h = 0; h < 4; h++) { + acc += residual_hc[(uint64_t)t * 4u * n_embd + (uint64_t)h * n_embd + col] * sp[h]; + } + out[(uint64_t)t * n_embd + col] = acc; + } +} + +__global__ static void hc_split_weighted_sum_norm_fused_kernel( + float *out, + float *norm_out, + float *split, + const float *mix, + const float *residual_hc, + const float *scale, + const float *base, + const float *norm_w, + uint32_t n_embd, + uint32_t n_hc, + uint32_t n_rows, + uint32_t sinkhorn_iters, + float epsv, + float norm_eps) { + const uint32_t t = blockIdx.x; + const uint32_t d = threadIdx.x; + if (t >= n_rows || n_hc != 4) return; + const uint32_t mix_hc = 24; + float *sp = split + (uint64_t)t * mix_hc; + if (d == 0) hc4_split_one(sp, mix + (uint64_t)t * mix_hc, scale, base, sinkhorn_iters, epsv); + __syncthreads(); + + float sum = 0.0f; + for (uint32_t col = d; col < n_embd; col += blockDim.x) { + float acc = 0.0f; + for (uint32_t h = 0; h < 4; h++) { + acc += residual_hc[(uint64_t)t * 4u * n_embd + (uint64_t)h * n_embd + col] * sp[h]; + } + out[(uint64_t)t * n_embd + col] = acc; + sum += acc * acc; + } + + __shared__ float partial[256]; + partial[d] = sum; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (d < stride) partial[d] += partial[d + stride]; + __syncthreads(); + } + const float norm_scale = rsqrtf(partial[0] / (float)n_embd + norm_eps); + for (uint32_t col = d; col < n_embd; col += blockDim.x) { + const float v = out[(uint64_t)t * n_embd + col]; + norm_out[(uint64_t)t * n_embd + col] = v * norm_scale * norm_w[col]; + } +} + +__global__ static void output_hc_weights_kernel( + float *out, + const float *pre, + const float *scale, + const float *base, + uint32_t n_hc, + uint32_t n_tokens, + float epsv) { + uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t n = n_tokens * n_hc; + if (gid >= n) return; + uint32_t h = gid % n_hc; + float z = pre[gid] * scale[0] + base[h]; + out[gid] = 1.0f / (1.0f + expf(-z)) + epsv; +} + +__global__ static void fill_f32_kernel(float *x, uint64_t n, float v) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) x[i] = v; +} + +__global__ static void compressor_store_kernel( + const float *kv, + const float *sc, + float *state_kv, + float *state_score, + const void *model_map, + uint64_t ape_offset, + uint32_t ape_type, + uint32_t head_dim, + uint32_t ratio, + uint32_t pos0, + uint32_t n_tokens) { + uint32_t coff = ratio == 4u ? 2u : 1u; + uint32_t width = coff * head_dim; + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_tokens * width; + if (gid >= n) return; + uint32_t t = gid / width; + uint32_t j = gid - (uint64_t)t * width; + uint32_t pos_mod = (pos0 + t) % ratio; + uint32_t dst_row = ratio == 4u ? ratio + pos_mod : pos_mod; + state_kv[(uint64_t)dst_row * width + j] = kv[(uint64_t)t * width + j]; + state_score[(uint64_t)dst_row * width + j] = + sc[(uint64_t)t * width + j] + model_scalar_dev(model_map, ape_offset, ape_type, (uint64_t)pos_mod * width + j); +} + +__global__ static void compressor_set_rows_kernel( + float *state_kv, + float *state_score, + const float *kv, + const float *sc, + const void *model_map, + uint64_t ape_offset, + uint32_t ape_type, + uint32_t width, + uint32_t ratio, + uint32_t pos0, + uint32_t src0, + uint32_t dst0, + uint32_t rows) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)rows * width; + if (gid >= n) return; + uint32_t r = gid / width; + uint32_t j = gid - (uint64_t)r * width; + uint32_t src = src0 + r; + uint32_t dst = dst0 + r; + uint32_t phase = (pos0 + src) % ratio; + state_kv[(uint64_t)dst * width + j] = kv[(uint64_t)src * width + j]; + state_score[(uint64_t)dst * width + j] = + sc[(uint64_t)src * width + j] + model_scalar_dev(model_map, ape_offset, ape_type, (uint64_t)phase * width + j); +} + +__global__ static void compressor_prefill_pool_kernel( + float *comp, + const float *kv, + const float *sc, + const float *state_kv, + const float *state_score, + const void *model_map, + uint64_t ape_offset, + uint32_t ape_type, + uint32_t head_dim, + uint32_t ratio, + uint32_t pos0, + uint32_t n_comp, + uint32_t replay) { + uint32_t d = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t c = blockIdx.y; + if (d >= head_dim || c >= n_comp) return; + uint32_t coff = ratio == 4u ? 2u : 1u; + uint32_t width = coff * head_dim; + float vals[128]; + float scores[128]; + float max_s = -INFINITY; + uint32_t n_cand = 0; + if (ratio == 4u) { + if (replay && c == 0) { + for (uint32_t r = 0; r < 4; r++) { + vals[n_cand] = state_kv[(uint64_t)r * width + d]; + scores[n_cand] = state_score[(uint64_t)r * width + d]; + max_s = fmaxf(max_s, scores[n_cand++]); + } + } else if (c > 0) { + uint32_t base = (c - 1u) * ratio; + for (uint32_t r = 0; r < 4; r++) { + uint32_t t = base + r; + float ape = model_scalar_dev(model_map, ape_offset, ape_type, (uint64_t)((pos0 + t) % ratio) * width + d); + vals[n_cand] = kv[(uint64_t)t * width + d]; + scores[n_cand] = sc[(uint64_t)t * width + d] + ape; + max_s = fmaxf(max_s, scores[n_cand++]); + } + } + uint32_t base = c * ratio; + for (uint32_t r = 0; r < 4; r++) { + uint32_t t = base + r; + float ape = model_scalar_dev(model_map, ape_offset, ape_type, (uint64_t)((pos0 + t) % ratio) * width + head_dim + d); + vals[n_cand] = kv[(uint64_t)t * width + head_dim + d]; + scores[n_cand] = sc[(uint64_t)t * width + head_dim + d] + ape; + max_s = fmaxf(max_s, scores[n_cand++]); + } + } else { + uint32_t base = c * ratio; + for (uint32_t r = 0; r < ratio; r++) { + uint32_t t = base + r; + float ape = model_scalar_dev(model_map, ape_offset, ape_type, (uint64_t)((pos0 + t) % ratio) * width + d); + vals[n_cand] = kv[(uint64_t)t * width + d]; + scores[n_cand] = sc[(uint64_t)t * width + d] + ape; + max_s = fmaxf(max_s, scores[n_cand++]); + } + } + float den = 0.0f, acc = 0.0f; + for (uint32_t i = 0; i < n_cand; i++) { + float w = expf(scores[i] - max_s); + den += w; + acc += vals[i] * w; + } + comp[(uint64_t)c * head_dim + d] = den != 0.0f ? acc / den : 0.0f; +} + +__global__ static void compressor_update_pool_kernel( + float *row, + const float *state_kv, + const float *state_score, + uint32_t head_dim, + uint32_t ratio) { + uint32_t d = blockIdx.x * blockDim.x + threadIdx.x; + if (d >= head_dim) return; + uint32_t coff = ratio == 4u ? 2u : 1u; + uint32_t width = coff * head_dim; + float vals[128]; + float scores[128]; + float max_s = -INFINITY; + uint32_t n_cand = 0; + if (ratio == 4u) { + for (uint32_t r = 0; r < 4; r++) { + vals[n_cand] = state_kv[(uint64_t)r * width + d]; + scores[n_cand] = state_score[(uint64_t)r * width + d]; + max_s = fmaxf(max_s, scores[n_cand++]); + } + for (uint32_t r = 0; r < 4; r++) { + vals[n_cand] = state_kv[(uint64_t)(ratio + r) * width + head_dim + d]; + scores[n_cand] = state_score[(uint64_t)(ratio + r) * width + head_dim + d]; + max_s = fmaxf(max_s, scores[n_cand++]); + } + } else { + for (uint32_t r = 0; r < ratio; r++) { + vals[n_cand] = state_kv[(uint64_t)r * width + d]; + scores[n_cand] = state_score[(uint64_t)r * width + d]; + max_s = fmaxf(max_s, scores[n_cand++]); + } + } + float den = 0.0f, acc = 0.0f; + for (uint32_t i = 0; i < n_cand; i++) { + float w = expf(scores[i] - max_s); + den += w; + acc += vals[i] * w; + } + row[d] = den != 0.0f ? acc / den : 0.0f; +} + +__global__ static void compressor_shift_ratio4_kernel(float *state_kv, float *state_score, uint32_t width) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t half = 4ull * width; + if (i >= half) return; + float v = state_kv[half + i]; + float s = state_score[half + i]; + state_kv[i] = v; + state_score[i] = s; + state_kv[half + i] = v; + state_score[half + i] = s; +} + +__device__ static float softplus_dev(float x) { + if (x > 20.0f) return x; + if (x < -20.0f) return expf(x); + return log1pf(expf(x)); +} + +__global__ static void router_select_kernel( + int32_t *selected, + float *weights, + float *probs, + const float *bias, + const int32_t *hash, + const float *logits, + const int32_t *tokens, + int32_t token_scalar, + uint32_t hash_rows, + uint32_t n_tokens, + int has_bias, + int hash_mode) { + uint32_t t = blockIdx.x; + if (t >= n_tokens || threadIdx.x != 0) return; + const float *log = logits + (uint64_t)t * 256; + float *prob = probs + (uint64_t)t * 256; + int32_t *sel = selected + (uint64_t)t * 6; + float *w = weights + (uint64_t)t * 6; + + for (int i = 0; i < 256; i++) prob[i] = sqrtf(softplus_dev(log[i])); + + if (hash_mode) { + int32_t tok = tokens ? tokens[t] : token_scalar; + if (tok < 0 || (uint32_t)tok >= hash_rows) tok = 0; + const int32_t *row = hash + (uint64_t)tok * 6; + for (int i = 0; i < 6; i++) sel[i] = row[i]; + } else { + for (int i = 0; i < 6; i++) sel[i] = -1; + for (int i = 0; i < 256; i++) { + float score = prob[i] + (has_bias ? bias[i] : 0.0f); + for (int j = 0; j < 6; j++) { + if (sel[j] < 0 || score > prob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) { + for (int k = 5; k > j; k--) sel[k] = sel[k - 1]; + sel[j] = i; + break; + } + } + } + } + + float sum = 0.0f; + for (int i = 0; i < 6; i++) { + int e = sel[i]; + float v = (e >= 0 && e < 256) ? prob[e] : 0.0f; + w[i] = v; + sum += v; + } + sum = fmaxf(sum, 6.103515625e-5f); + for (int i = 0; i < 6; i++) w[i] = w[i] / sum * 1.5f; +} + +__global__ static void router_select_parallel_kernel( + int32_t *selected, + float *weights, + float *probs, + const float *bias, + const int32_t *hash, + const float *logits, + const int32_t *tokens, + int32_t token_scalar, + uint32_t hash_rows, + uint32_t n_tokens, + int has_bias, + int hash_mode) { + uint32_t t = blockIdx.x; + uint32_t i = threadIdx.x; + if (t >= n_tokens || i >= 256u) return; + const float *log = logits + (uint64_t)t * 256; + float *prob = probs + (uint64_t)t * 256; + int32_t *sel = selected + (uint64_t)t * 6; + float *w = weights + (uint64_t)t * 6; + __shared__ float sprob[256]; + + const float p = sqrtf(softplus_dev(log[i])); + sprob[i] = p; + prob[i] = p; + __syncthreads(); + + if (i != 0) return; + if (hash_mode) { + int32_t tok = tokens ? tokens[t] : token_scalar; + if (tok < 0 || (uint32_t)tok >= hash_rows) tok = 0; + const int32_t *row = hash + (uint64_t)tok * 6; + for (int j = 0; j < 6; j++) sel[j] = row[j]; + } else { + for (int j = 0; j < 6; j++) sel[j] = -1; + for (int e = 0; e < 256; e++) { + float score = sprob[e] + (has_bias ? bias[e] : 0.0f); + for (int j = 0; j < 6; j++) { + if (sel[j] < 0 || score > sprob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) { + for (int k = 5; k > j; k--) sel[k] = sel[k - 1]; + sel[j] = e; + break; + } + } + } + } + + float sum = 0.0f; + for (int j = 0; j < 6; j++) { + int e = sel[j]; + float v = (e >= 0 && e < 256) ? sprob[e] : 0.0f; + w[j] = v; + sum += v; + } + sum = fmaxf(sum, 6.103515625e-5f); + for (int j = 0; j < 6; j++) w[j] = w[j] / sum * 1.5f; +} + +__device__ __forceinline__ static bool router_score_better(float av, uint32_t ai, float bv, uint32_t bi) { + return av > bv || (av == bv && ai < bi); +} + +__global__ static void router_select_warp_topk_kernel( + int32_t *selected, + float *weights, + float *probs, + const float *bias, + const int32_t *hash, + const float *logits, + const int32_t *tokens, + int32_t token_scalar, + uint32_t hash_rows, + uint32_t n_tokens, + int has_bias, + int hash_mode) { + const uint32_t lane = threadIdx.x; + const uint32_t row_in_block = threadIdx.y; + const uint32_t t = blockIdx.x * blockDim.y + row_in_block; + if (t >= n_tokens || lane >= 32u) return; + + const float *log = logits + (uint64_t)t * 256u; + float *prob = probs + (uint64_t)t * 256u; + int32_t *sel = selected + (uint64_t)t * 6u; + float *w = weights + (uint64_t)t * 6u; + __shared__ float sprob[4][256]; + float local_prob[8]; + float local_score[8]; + + #pragma unroll + for (uint32_t j = 0; j < 8u; j++) { + const uint32_t e = lane + j * 32u; + const float p = sqrtf(softplus_dev(log[e])); + local_prob[j] = p; + local_score[j] = p + (has_bias ? bias[e] : 0.0f); + sprob[row_in_block][e] = p; + prob[e] = p; + } + __syncwarp(); + + if (hash_mode) { + if (lane == 0) { + int32_t tok = tokens ? tokens[t] : token_scalar; + if (tok < 0 || (uint32_t)tok >= hash_rows) tok = 0; + const int32_t *row = hash + (uint64_t)tok * 6u; + float sum = 0.0f; + #pragma unroll + for (uint32_t j = 0; j < 6u; j++) { + const int32_t e = row[j]; + sel[j] = e; + const float v = (e >= 0 && e < 256) ? sprob[row_in_block][(uint32_t)e] : 0.0f; + w[j] = v; + sum += v; + } + sum = fmaxf(sum, 6.103515625e-5f); + #pragma unroll + for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * 1.5f; + } + return; + } + + float out_prob[6] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + uint32_t out_idx[6] = {0, 0, 0, 0, 0, 0}; + #pragma unroll + for (uint32_t k = 0; k < 6u; k++) { + float best_score = -INFINITY; + float best_prob = 0.0f; + uint32_t best_idx = UINT32_MAX; + #pragma unroll + for (uint32_t j = 0; j < 8u; j++) { + const uint32_t e = lane + j * 32u; + const float s = local_score[j]; + if (router_score_better(s, e, best_score, best_idx)) { + best_score = s; + best_prob = local_prob[j]; + best_idx = e; + } + } + #pragma unroll + for (uint32_t mask = 16u; mask > 0u; mask >>= 1u) { + const float other_score = __shfl_xor_sync(0xffffffffffffffffULL, best_score, mask); + const float other_prob = __shfl_xor_sync(0xffffffffffffffffULL, best_prob, mask); + const uint32_t other_idx = __shfl_xor_sync(0xffffffffffffffffULL, best_idx, mask); + if (router_score_better(other_score, other_idx, best_score, best_idx)) { + best_score = other_score; + best_prob = other_prob; + best_idx = other_idx; + } + } + #pragma unroll + for (uint32_t j = 0; j < 8u; j++) { + const uint32_t e = lane + j * 32u; + if (e == best_idx) local_score[j] = -INFINITY; + } + if (lane == 0) { + out_idx[k] = best_idx; + out_prob[k] = best_prob; + } + } + + if (lane == 0) { + float sum = 0.0f; + #pragma unroll + for (uint32_t j = 0; j < 6u; j++) { + sel[j] = (int32_t)out_idx[j]; + w[j] = out_prob[j]; + sum += out_prob[j]; + } + sum = fmaxf(sum, 6.103515625e-5f); + #pragma unroll + for (uint32_t j = 0; j < 6u; j++) w[j] = w[j] / sum * 1.5f; + } +} + +__global__ static void swiglu_kernel(float *out, const float *gate, const float *up, uint32_t n, float clamp, float weight) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n) return; + float g = gate[i]; + float u = up[i]; + if (clamp > 1.0e-6f) { + g = fminf(g, clamp); + u = fminf(fmaxf(u, -clamp), clamp); + } + float s = g / (1.0f + expf(-g)); + out[i] = s * u * weight; +} + +__global__ static void add_kernel(float *out, const float *a, const float *b, uint32_t n) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n) return; + out[i] = a[i] + b[i]; +} + +__global__ static void directional_steering_project_kernel( + float *x, + const float *directions, + uint32_t layer, + uint32_t width, + uint32_t rows, + float scale) { + const uint32_t row = blockIdx.x; + if (row >= rows || width == 0) return; + + float *xr = x + (uint64_t)row * width; + const float *dir = directions + (uint64_t)layer * width; + float sum = 0.0f; + for (uint32_t i = threadIdx.x; i < width; i += blockDim.x) { + sum += xr[i] * dir[i]; + } + + __shared__ float partial[256]; + partial[threadIdx.x] = sum; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + + const float coeff = scale * partial[0]; + for (uint32_t i = threadIdx.x; i < width; i += blockDim.x) { + xr[i] -= coeff * dir[i]; + } +} + +__global__ static void zero_kernel(float *out, uint64_t n) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) out[i] = 0.0f; +} + +__global__ static void indexer_scores_kernel( + float *scores, + const float *q, + const float *weights, + const float *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + int causal) { + uint32_t c = blockIdx.x; + uint32_t t = blockIdx.y; + if (c >= n_comp || t >= n_tokens) return; + if (causal) { + uint32_t n_visible = (pos0 + t + 1u) / ratio; + if (c >= n_visible) { + if (threadIdx.x == 0) scores[(uint64_t)t * n_comp + c] = -INFINITY; + return; + } + } + float total = 0.0f; + for (uint32_t h = 0; h < n_head; h++) { + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + const float *kh = index_comp + (uint64_t)c * head_dim; + float dot = 0.0f; + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) dot += qh[d] * kh[d]; + __shared__ float partial[256]; + partial[threadIdx.x] = dot; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + total += fmaxf(partial[0], 0.0f) * weights[(uint64_t)t * n_head + h]; + __syncthreads(); + } + if (threadIdx.x == 0) scores[(uint64_t)t * n_comp + c] = total * scale; +} + +__global__ static void indexer_score_one_direct_kernel( + float *scores, + const float *q, + const float *weights, + const float *index_comp, + uint32_t n_comp, + uint32_t pos0, + uint32_t ratio, + float scale, + int causal) { + const uint32_t c = blockIdx.x; + const uint32_t tid = threadIdx.x; + const uint32_t lane = tid & 31u; + const uint32_t warp = tid >> 5u; + if (c >= n_comp || tid >= 128u) return; + if (causal) { + const uint32_t visible = ratio ? (pos0 + 1u) / ratio : n_comp; + if (c >= visible) { + if (tid == 0) scores[c] = -INFINITY; + return; + } + } + + __shared__ float krow[128]; + __shared__ float partial[4]; + if (tid < 128u) krow[tid] = index_comp[(uint64_t)c * 128u + tid]; + __syncthreads(); + + float total = 0.0f; + for (uint32_t h0 = 0; h0 < 64u; h0 += 4u) { + const uint32_t h = h0 + warp; + const float4 qv = ((const float4 *)(q + (uint64_t)h * 128u))[lane]; + const float4 kv = ((const float4 *)krow)[lane]; + float dot = qv.x * kv.x + qv.y * kv.y + qv.z * kv.z + qv.w * kv.w; + dot = warp_sum_f32(dot); + if (lane == 0) partial[warp] = fmaxf(dot, 0.0f) * weights[h] * scale; + __syncthreads(); + if (tid == 0) total += partial[0] + partial[1] + partial[2] + partial[3]; + __syncthreads(); + } + if (tid == 0) scores[c] = total; +} + +__global__ static void indexer_scores_wmma_kernel( + float *scores, + const float *q, + const float *weights, + const float *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + int causal) { +#if DS4_HAVE_ROCWMMA + namespace wmma = rocwmma; + const uint32_t tile_c = blockIdx.x * 16u; + const uint32_t tile_t = blockIdx.y * 16u; + const uint32_t tid = threadIdx.x; + if (tid >= 32u || head_dim != 128u) return; + + if (causal) { + const uint32_t last_token = min(tile_t + 16u, n_tokens); + const uint32_t max_visible = last_token > tile_t + ? min((pos0 + last_token) / ratio, n_comp) + : 0u; + if (tile_c >= max_visible) { + for (uint32_t i = tid; i < 16u * 16u; i += 32u) { + const uint32_t r = i >> 4u; + const uint32_t c = i & 15u; + const uint32_t token = tile_t + r; + const uint32_t comp = tile_c + c; + if (token < n_tokens && comp < n_comp) { + scores[(uint64_t)token * n_comp + comp] = -INFINITY; + } + } + return; + } + } + + __shared__ __half a_sh[16 * 128]; + __shared__ __half b_sh[16 * 128]; + __shared__ float c_sh[16 * 16]; + __shared__ float acc_sh[16 * 16]; + + for (uint32_t i = tid; i < 16u * 16u; i += 32u) acc_sh[i] = 0.0f; + __syncthreads(); + + for (uint32_t h = 0; h < n_head; h++) { + for (uint32_t i = tid; i < 16u * 128u; i += 32u) { + const uint32_t r = i >> 7u; + const uint32_t d = i & 127u; + const uint32_t token = tile_t + r; + float v = 0.0f; + if (token < n_tokens) { + v = q[((uint64_t)token * n_head + h) * head_dim + d]; + } + a_sh[i] = __float2half(v); + } + for (uint32_t i = tid; i < 16u * 128u; i += 32u) { + const uint32_t c = i >> 7u; + const uint32_t d = i & 127u; + const uint32_t comp = tile_c + c; + float v = 0.0f; + if (comp < n_comp) v = index_comp[(uint64_t)comp * head_dim + d]; + b_sh[d + c * 128u] = __float2half(v); + } + __syncthreads(); + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + for (uint32_t k0 = 0; k0 < 128u; k0 += 16u) { + wmma::load_matrix_sync(a_frag, a_sh + k0, 128); + wmma::load_matrix_sync(b_frag, b_sh + k0, 128); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + wmma::store_matrix_sync(c_sh, c_frag, 16, wmma::mem_row_major); + __syncthreads(); + + for (uint32_t i = tid; i < 16u * 16u; i += 32u) { + const uint32_t r = i >> 4u; + const uint32_t token = tile_t + r; + if (token < n_tokens) { + const float w = weights[(uint64_t)token * n_head + h]; + acc_sh[i] += fmaxf(c_sh[i], 0.0f) * w; + } + } + __syncthreads(); + } + + for (uint32_t i = tid; i < 16u * 16u; i += 32u) { + const uint32_t r = i >> 4u; + const uint32_t c = i & 15u; + const uint32_t token = tile_t + r; + const uint32_t comp = tile_c + c; + if (token < n_tokens && comp < n_comp) { + float out = acc_sh[i] * scale; + if (causal) { + const uint32_t visible = (pos0 + token + 1u) / ratio; + if (comp >= visible) out = -INFINITY; + } + scores[(uint64_t)token * n_comp + comp] = out; + } + } +#endif +} + +__global__ static void indexer_topk_kernel(uint32_t *selected, const float *scores, uint32_t n_comp, uint32_t n_tokens, uint32_t top_k) { + uint32_t t = blockIdx.x; + if (t >= n_tokens || threadIdx.x != 0) return; + const float *row = scores + (uint64_t)t * n_comp; + uint32_t *sel = selected + (uint64_t)t * top_k; + for (uint32_t k = 0; k < top_k; k++) sel[k] = 0; + for (uint32_t c = 0; c < n_comp; c++) { + float v = row[c]; + for (uint32_t k = 0; k < top_k; k++) { + if ((k >= c) || v > row[sel[k]]) { + for (uint32_t j = top_k - 1; j > k; j--) sel[j] = sel[j - 1]; + sel[k] = c; + break; + } + } + } +} + +__device__ __forceinline__ static bool topk_score_better(float av, uint32_t ai, float bv, uint32_t bi) { + return av > bv || (av == bv && ai < bi); +} + +__global__ static void indexer_topk_1024_kernel( + uint32_t *selected, + const float *scores, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k) { + uint32_t t = blockIdx.x; + uint32_t tid = threadIdx.x; + if (t >= n_tokens || tid >= 1024u) return; + __shared__ float vals[1024]; + __shared__ uint32_t idxs[1024]; + + const float *row = scores + (uint64_t)t * n_comp; + if (tid < n_comp) { + vals[tid] = row[tid]; + idxs[tid] = tid; + } else { + vals[tid] = -INFINITY; + idxs[tid] = UINT32_MAX; + } + __syncthreads(); + + for (uint32_t k = 2u; k <= 1024u; k <<= 1u) { + for (uint32_t j = k >> 1u; j > 0u; j >>= 1u) { + uint32_t other = tid ^ j; + if (other > tid && other < 1024u) { + const float av = vals[tid]; + const float bv = vals[other]; + const uint32_t ai = idxs[tid]; + const uint32_t bi = idxs[other]; + const bool desc_half = (tid & k) == 0u; + const bool swap = desc_half + ? topk_score_better(bv, bi, av, ai) + : topk_score_better(av, ai, bv, bi); + if (swap) { + vals[tid] = bv; + idxs[tid] = bi; + vals[other] = av; + idxs[other] = ai; + } + } + __syncthreads(); + } + } + + if (tid < top_k) selected[(uint64_t)t * top_k + tid] = idxs[tid]; +} + +template +__global__ static void indexer_topk_pow2_kernel( + uint32_t *selected, + const float *scores, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k) { + uint32_t t = blockIdx.x; + uint32_t tid = threadIdx.x; + if (t >= n_tokens) return; + __shared__ float vals[SORT_N]; + __shared__ uint32_t idxs[SORT_N]; + + const float *row = scores + (uint64_t)t * n_comp; + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + if (i < n_comp) { + vals[i] = row[i]; + idxs[i] = i; + } else { + vals[i] = -INFINITY; + idxs[i] = UINT32_MAX; + } + } + __syncthreads(); + + for (uint32_t k = 2u; k <= SORT_N; k <<= 1u) { + for (uint32_t j = k >> 1u; j > 0u; j >>= 1u) { + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + uint32_t other = i ^ j; + if (other > i && other < SORT_N) { + const float av = vals[i]; + const float bv = vals[other]; + const uint32_t ai = idxs[i]; + const uint32_t bi = idxs[other]; + const bool desc_half = (i & k) == 0u; + const bool swap = desc_half + ? topk_score_better(bv, bi, av, ai) + : topk_score_better(av, ai, bv, bi); + if (swap) { + vals[i] = bv; + idxs[i] = bi; + vals[other] = av; + idxs[other] = ai; + } + } + } + __syncthreads(); + } + } + + for (uint32_t i = tid; i < top_k; i += blockDim.x) { + selected[(uint64_t)t * top_k + i] = idxs[i]; + } +} + +template +__global__ static void indexer_topk_chunk_pow2_kernel( + uint32_t *candidates, + const float *scores, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k, + uint32_t candidate_stride) { + uint32_t t = blockIdx.x; + uint32_t chunk = blockIdx.y; + uint32_t tid = threadIdx.x; + if (t >= n_tokens) return; + + const uint32_t chunk_start = chunk * SORT_N; + if (chunk_start >= n_comp) return; + const uint32_t chunk_n = n_comp - chunk_start < SORT_N ? n_comp - chunk_start : SORT_N; + __shared__ float vals[SORT_N]; + __shared__ uint32_t idxs[SORT_N]; + + const float *row = scores + (uint64_t)t * n_comp; + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + if (i < chunk_n) { + vals[i] = row[chunk_start + i]; + idxs[i] = chunk_start + i; + } else { + vals[i] = -INFINITY; + idxs[i] = UINT32_MAX; + } + } + __syncthreads(); + + for (uint32_t k = 2u; k <= SORT_N; k <<= 1u) { + for (uint32_t j = k >> 1u; j > 0u; j >>= 1u) { + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + uint32_t other = i ^ j; + if (other > i && other < SORT_N) { + const float av = vals[i]; + const float bv = vals[other]; + const uint32_t ai = idxs[i]; + const uint32_t bi = idxs[other]; + const bool desc_half = (i & k) == 0u; + const bool swap = desc_half + ? topk_score_better(bv, bi, av, ai) + : topk_score_better(av, ai, bv, bi); + if (swap) { + vals[i] = bv; + idxs[i] = bi; + vals[other] = av; + idxs[other] = ai; + } + } + } + __syncthreads(); + } + } + + uint32_t *out = candidates + (uint64_t)t * candidate_stride + chunk * top_k; + for (uint32_t i = tid; i < top_k; i += blockDim.x) { + out[i] = idxs[i]; + } +} + +template +__global__ static void indexer_topk_merge_pow2_kernel( + uint32_t *selected, + const uint32_t *candidates, + const float *scores, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k, + uint32_t candidate_count, + uint32_t candidate_stride) { + uint32_t t = blockIdx.x; + uint32_t tid = threadIdx.x; + if (t >= n_tokens) return; + __shared__ float vals[SORT_N]; + __shared__ uint32_t idxs[SORT_N]; + + const float *row = scores + (uint64_t)t * n_comp; + const uint32_t *cand = candidates + (uint64_t)t * candidate_stride; + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + uint32_t idx = UINT32_MAX; + float v = -INFINITY; + if (i < candidate_count) { + idx = cand[i]; + if (idx < n_comp) v = row[idx]; + } + vals[i] = v; + idxs[i] = idx; + } + __syncthreads(); + + for (uint32_t k = 2u; k <= SORT_N; k <<= 1u) { + for (uint32_t j = k >> 1u; j > 0u; j >>= 1u) { + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + uint32_t other = i ^ j; + if (other > i && other < SORT_N) { + const float av = vals[i]; + const float bv = vals[other]; + const uint32_t ai = idxs[i]; + const uint32_t bi = idxs[other]; + const bool desc_half = (i & k) == 0u; + const bool swap = desc_half + ? topk_score_better(bv, bi, av, ai) + : topk_score_better(av, ai, bv, bi); + if (swap) { + vals[i] = bv; + idxs[i] = bi; + vals[other] = av; + idxs[other] = ai; + } + } + } + __syncthreads(); + } + } + + for (uint32_t i = tid; i < top_k; i += blockDim.x) { + selected[(uint64_t)t * top_k + i] = idxs[i]; + } +} + +template +__global__ static void indexer_topk_tree_merge_pow2_kernel( + uint32_t *out, + const uint32_t *candidates, + const float *scores, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k, + uint32_t n_sets, + uint32_t merge_group, + uint32_t candidate_stride, + uint32_t out_stride) { + uint32_t t = blockIdx.x; + uint32_t group = blockIdx.y; + uint32_t tid = threadIdx.x; + if (t >= n_tokens) return; + + const uint32_t set0 = group * merge_group; + if (set0 >= n_sets) return; + uint32_t set_count = n_sets - set0; + if (set_count > merge_group) set_count = merge_group; + const uint32_t candidate_count = set_count * top_k; + + __shared__ float vals[SORT_N]; + __shared__ uint32_t idxs[SORT_N]; + + const float *row = scores + (uint64_t)t * n_comp; + const uint32_t *cand = candidates + (uint64_t)t * candidate_stride + set0 * top_k; + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + uint32_t idx = UINT32_MAX; + float v = -INFINITY; + if (i < candidate_count) { + idx = cand[i]; + if (idx < n_comp) v = row[idx]; + } + vals[i] = v; + idxs[i] = idx; + } + __syncthreads(); + + for (uint32_t k = 2u; k <= SORT_N; k <<= 1u) { + for (uint32_t j = k >> 1u; j > 0u; j >>= 1u) { + for (uint32_t i = tid; i < SORT_N; i += blockDim.x) { + uint32_t other = i ^ j; + if (other > i && other < SORT_N) { + const float av = vals[i]; + const float bv = vals[other]; + const uint32_t ai = idxs[i]; + const uint32_t bi = idxs[other]; + const bool desc_half = (i & k) == 0u; + const bool swap = desc_half + ? topk_score_better(bv, bi, av, ai) + : topk_score_better(av, ai, bv, bi); + if (swap) { + vals[i] = bv; + idxs[i] = bi; + vals[other] = av; + idxs[other] = ai; + } + } + } + __syncthreads(); + } + } + + uint32_t *dst = out + (uint64_t)t * out_stride + group * top_k; + for (uint32_t i = tid; i < top_k; i += blockDim.x) { + dst[i] = idxs[i]; + } +} + +__global__ static void topk_mask_kernel(float *mask, const uint32_t *topk, uint32_t n_comp, uint32_t n_tokens, uint32_t top_k) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_tokens * n_comp; + if (gid >= n) return; + uint32_t t = gid / n_comp; + uint32_t c = gid - (uint64_t)t * n_comp; + float v = -INFINITY; + for (uint32_t k = 0; k < top_k; k++) { + if (topk[(uint64_t)t * top_k + k] == c) { + v = 0.0f; + break; + } + } + mask[gid] = v; +} + +extern "C" int ds4_gpu_embed_token_hc_tensor(ds4_gpu_tensor *out_hc, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint32_t n_vocab, uint32_t token, uint32_t n_embd, uint32_t n_hc) { + (void)n_vocab; + if (!out_hc || !model_map || weight_offset >= model_size) return 0; + uint64_t weight_bytes = (uint64_t)n_vocab * n_embd * sizeof(uint16_t); + if (weight_offset > model_size || weight_bytes > model_size - weight_offset) return 0; + const char *wptr = hip_model_range_ptr(model_map, weight_offset, weight_bytes, "token_embd"); + if (!wptr) return 0; + uint32_t n = n_embd * n_hc; + embed_token_hc_kernel<<<(n + 255) / 256, 256>>>((float *)out_hc->ptr, (const unsigned short *)wptr, token, n_embd, n_hc); + return hip_ok(hipGetLastError(), "embed token launch"); +} + +extern "C" int ds4_gpu_embed_tokens_hc_tensor( + ds4_gpu_tensor *out_hc, + const ds4_gpu_tensor *tokens_t, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint32_t n_vocab, + uint32_t n_tokens, + uint32_t n_embd, + uint32_t n_hc) { + if (!out_hc || !tokens_t || !model_map || + weight_offset > model_size || + (uint64_t)n_vocab * n_embd * sizeof(uint16_t) > model_size - weight_offset || + tokens_t->bytes < (uint64_t)n_tokens * sizeof(int32_t) || + out_hc->bytes < (uint64_t)n_tokens * n_hc * n_embd * sizeof(float)) { + return 0; + } + const char *wptr = hip_model_range_ptr(model_map, weight_offset, + (uint64_t)n_vocab * n_embd * sizeof(uint16_t), + "token_embd"); + if (!wptr) return 0; + uint64_t n = (uint64_t)n_tokens * n_hc * n_embd; + embed_tokens_hc_kernel<<<(n + 255) / 256, 256>>>( + (float *)out_hc->ptr, + (const int32_t *)tokens_t->ptr, + (const __half *)wptr, + n_vocab, n_tokens, n_embd, n_hc); + return hip_ok(hipGetLastError(), "embed tokens launch"); +} + +static int indexer_scores_launch( + ds4_gpu_tensor *scores, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + uint32_t causal) { + if (!scores || !q || !weights || !index_comp || + n_comp == 0 || n_tokens == 0 || n_head == 0 || head_dim == 0 || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + weights->bytes < (uint64_t)n_tokens * n_head * sizeof(float) || + index_comp->bytes < (uint64_t)n_comp * head_dim * sizeof(float) || + scores->bytes < (uint64_t)n_tokens * n_comp * sizeof(float)) { + return 0; + } + if (causal && ratio == 0) return 0; + if (n_tokens == 1u && head_dim == 128u && n_head == 64u && + getenv("DS4_HIP_NO_INDEXER_DIRECT_ONE") == NULL) { + indexer_score_one_direct_kernel<<>>((float *)scores->ptr, + (const float *)q->ptr, + (const float *)weights->ptr, + (const float *)index_comp->ptr, + n_comp, pos0, ratio, + scale, causal ? 1 : 0); + return hip_ok(hipGetLastError(), "indexer score one direct launch"); + } + if (!g_quality_mode && head_dim == 128u && n_head == 64u && + getenv("DS4_HIP_NO_INDEXER_WMMA") == NULL) { + dim3 grid((n_comp + 15u) / 16u, (n_tokens + 15u) / 16u, 1); + indexer_scores_wmma_kernel<<>>((float *)scores->ptr, + (const float *)q->ptr, + (const float *)weights->ptr, + (const float *)index_comp->ptr, + n_comp, n_tokens, pos0, n_head, + head_dim, ratio, scale, causal ? 1 : 0); + return hip_ok(hipGetLastError(), "indexer scores wmma launch"); + } + dim3 grid(n_comp, n_tokens, 1); + indexer_scores_kernel<<>>((float *)scores->ptr, + (const float *)q->ptr, + (const float *)weights->ptr, + (const float *)index_comp->ptr, + n_comp, n_tokens, pos0, n_head, + head_dim, ratio, scale, causal ? 1 : 0); + return hip_ok(hipGetLastError(), "indexer scores launch"); +} + +extern "C" int ds4_gpu_indexer_score_one_tensor( + ds4_gpu_tensor *scores, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_head, + uint32_t head_dim, + float scale) { + return indexer_scores_launch(scores, q, weights, index_comp, n_comp, 1, 0, + n_head, head_dim, 1, scale, 0); +} + +extern "C" int ds4_gpu_indexer_scores_prefill_tensor( + ds4_gpu_tensor *scores, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale) { + return indexer_scores_launch(scores, q, weights, index_comp, n_comp, n_tokens, 0, + n_head, head_dim, ratio, scale, 1); +} + +extern "C" int ds4_gpu_indexer_scores_decode_batch_tensor( + ds4_gpu_tensor *scores, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale) { + return indexer_scores_launch(scores, q, weights, index_comp, n_comp, n_tokens, pos0, + n_head, head_dim, ratio, scale, 1); +} + +extern "C" int ds4_gpu_indexer_topk_tensor( + ds4_gpu_tensor *selected, + const ds4_gpu_tensor *scores, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k) { + if (!selected || !scores || n_comp == 0 || n_tokens == 0 || top_k == 0 || + top_k > n_comp || + scores->bytes < (uint64_t)n_tokens * n_comp * sizeof(float) || + selected->bytes < (uint64_t)n_tokens * top_k * sizeof(uint32_t)) { + return 0; + } + if (top_k == 512u && n_comp <= 1024u && + getenv("DS4_HIP_NO_TOPK1024") == NULL) { + indexer_topk_1024_kernel<<>>((uint32_t *)selected->ptr, + (const float *)scores->ptr, + n_comp, n_tokens, top_k); + return hip_ok(hipGetLastError(), "indexer topk 1024 launch"); + } + if (top_k == 512u && n_comp <= 2048u && + getenv("DS4_HIP_NO_TOPK2048") == NULL) { + indexer_topk_pow2_kernel<2048><<>>((uint32_t *)selected->ptr, + (const float *)scores->ptr, + n_comp, n_tokens, top_k); + return hip_ok(hipGetLastError(), "indexer topk 2048 launch"); + } + if (top_k == 512u && n_comp <= 4096u && + getenv("DS4_HIP_NO_TOPK2048") == NULL) { + indexer_topk_pow2_kernel<4096><<>>((uint32_t *)selected->ptr, + (const float *)scores->ptr, + n_comp, n_tokens, top_k); + return hip_ok(hipGetLastError(), "indexer topk 4096 launch"); + } + if (top_k == 512u && getenv("DS4_HIP_NO_TOPK2048") == NULL && + getenv("DS4_HIP_NO_TOPK_CHUNKED") == NULL) { + const uint32_t chunk_n = 4096u; + const uint32_t n_chunks = (n_comp + chunk_n - 1u) / chunk_n; + const uint32_t candidate_stride = n_chunks * top_k; + uint32_t n_sets = n_chunks; + uint64_t scratch_u32_per_token = candidate_stride; + while (n_sets > DS4_HIP_TOPK_MERGE_GROUP) { + n_sets = (n_sets + DS4_HIP_TOPK_MERGE_GROUP - 1u) / DS4_HIP_TOPK_MERGE_GROUP; + scratch_u32_per_token += (uint64_t)n_sets * top_k; + } + if (scratch_u32_per_token > UINT64_MAX / n_tokens / sizeof(uint32_t)) return 0; + const uint64_t tmp_bytes = (uint64_t)n_tokens * scratch_u32_per_token * sizeof(uint32_t); + uint32_t *scratch = (uint32_t *)hip_tmp_alloc(tmp_bytes, "indexer topk tree"); + if (!scratch) return 0; + + uint32_t *cur = scratch; + n_sets = n_chunks; + uint32_t cur_stride = candidate_stride; + dim3 grid_chunks(n_tokens, n_chunks, 1); + indexer_topk_chunk_pow2_kernel<4096><<>>(cur, + (const float *)scores->ptr, + n_comp, + n_tokens, + top_k, + candidate_stride); + if (!hip_ok(hipGetLastError(), "indexer topk chunk launch")) return 0; + + while (n_sets > DS4_HIP_TOPK_MERGE_GROUP) { + const uint32_t next_sets = (n_sets + DS4_HIP_TOPK_MERGE_GROUP - 1u) / DS4_HIP_TOPK_MERGE_GROUP; + const uint32_t next_stride = next_sets * top_k; + uint32_t *next = cur + (uint64_t)n_tokens * cur_stride; + dim3 grid_merge(n_tokens, next_sets, 1); + indexer_topk_tree_merge_pow2_kernel<4096><<>>( + next, + cur, + (const float *)scores->ptr, + n_comp, + n_tokens, + top_k, + n_sets, + DS4_HIP_TOPK_MERGE_GROUP, + cur_stride, + next_stride); + if (!hip_ok(hipGetLastError(), "indexer topk tree merge launch")) return 0; + cur = next; + n_sets = next_sets; + cur_stride = next_stride; + } + + indexer_topk_merge_pow2_kernel<4096><<>>((uint32_t *)selected->ptr, + cur, + (const float *)scores->ptr, + n_comp, + n_tokens, + top_k, + n_sets * top_k, + cur_stride); + return hip_ok(hipGetLastError(), "indexer topk tree final launch"); + } + indexer_topk_kernel<<>>((uint32_t *)selected->ptr, + (const float *)scores->ptr, + n_comp, n_tokens, top_k); + return hip_ok(hipGetLastError(), "indexer topk launch"); +} + +extern "C" int ds4_gpu_dsv4_topk_mask_tensor( + ds4_gpu_tensor *mask, + const ds4_gpu_tensor *topk, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t top_k) { + if (!mask || !topk || n_comp == 0 || n_tokens == 0 || top_k == 0 || + mask->bytes < (uint64_t)n_tokens * n_comp * sizeof(float) || + topk->bytes < (uint64_t)n_tokens * top_k * sizeof(uint32_t)) { + return 0; + } + uint64_t n = (uint64_t)n_tokens * n_comp; + uint64_t nk = (uint64_t)n_tokens * top_k; + uint64_t blocks = ((n > nk ? n : nk) + 255) / 256; + topk_mask_kernel<<>>((float *)mask->ptr, + (const uint32_t *)topk->ptr, + n_comp, n_tokens, top_k); + return hip_ok(hipGetLastError(), "topk mask launch"); +} +static int hip_matmul_q8_0_tensor_labeled(ds4_gpu_tensor *out, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint64_t in_dim, uint64_t out_dim, const ds4_gpu_tensor *x, uint64_t n_tok, const char *label) { + if (!out || !x || !model_map) return 0; + uint64_t blocks = (in_dim + 31) / 32; + if (weight_offset > model_size || out_dim > UINT64_MAX / (blocks * 34)) return 0; + uint64_t weight_bytes = out_dim * blocks * 34; + if (weight_bytes > model_size - weight_offset) return 0; + if (x->bytes < n_tok * in_dim * sizeof(float) || + out->bytes < n_tok * out_dim * sizeof(float)) return 0; + const char *wptr = hip_model_range_ptr(model_map, weight_offset, weight_bytes, "q8_0"); + if (!wptr) return 0; + if (g_hipblas_ready && n_tok > 1) { + const float *w_f32 = hip_q8_f32_ptr(model_map, weight_offset, weight_bytes, in_dim, out_dim, label); + if (w_f32) { + const float alpha = 1.0f; + const float beta = 0.0f; + hipblasStatus_t st = hipblasSgemm(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)out_dim, + (int)n_tok, + (int)in_dim, + &alpha, + w_f32, + (int)in_dim, + (const float *)x->ptr, + (int)in_dim, + &beta, + (float *)out->ptr, + (int)out_dim); + return hipblas_ok(st, "q8 fp32 matmul"); + } + const __half *w_f16 = hip_q8_f16_ptr(model_map, weight_offset, weight_bytes, in_dim, out_dim, label); + if (w_f16) { + const uint64_t xh_count = n_tok * in_dim; + __half *xh = (__half *)hip_tmp_alloc(xh_count * sizeof(__half), "q8 f16 gemm activations"); + if (!xh) return 0; + f32_to_f16_kernel<<<(xh_count + 255) / 256, 256>>>(xh, (const float *)x->ptr, xh_count); + if (!hip_ok(hipGetLastError(), "q8 f16 activation convert launch")) return 0; + const float alpha = 1.0f; + const float beta = 0.0f; + hipblasStatus_t st = hipblasGemmEx(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)out_dim, + (int)n_tok, + (int)in_dim, + &alpha, + w_f16, + HIP_R_16F, + (int)in_dim, + xh, + HIP_R_16F, + (int)in_dim, + &beta, + out->ptr, + HIP_R_32F, + (int)out_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); + if (st == HIPBLAS_STATUS_SUCCESS) return 1; + fprintf(stderr, "ds4: hipBLAS q8 f16 matmul failed: status %d\n", (int)st); + hip_q8_f16_cache_disable_after_failure("hipBLAS f16 matmul failure", + in_dim * out_dim * sizeof(__half)); + /* The F16 expansion cache is only an optimization. If hipBLAS + * rejects the cached path under memory pressure, retry the same + * operation through the native Q8 kernels below. */ + } + } + const uint64_t xq_bytes = n_tok * blocks * 32u; + const uint64_t scale_offset = (xq_bytes + 15u) & ~15ull; + const uint64_t tmp_bytes = scale_offset + n_tok * blocks * sizeof(float); + void *tmp = hip_tmp_alloc(tmp_bytes, "q8_0 prequant"); + if (!tmp) return 0; + int8_t *xq = (int8_t *)tmp; + float *xscale = (float *)((char *)tmp + scale_offset); + const int use_dp4a = hip_q8_use_dp4a(); + dim3 qgrid((unsigned)blocks, (unsigned)n_tok, 1); + quantize_q8_0_f32_kernel<<>>(xq, xscale, (const float *)x->ptr, in_dim, blocks); + if (!hip_ok(hipGetLastError(), "matmul_q8_0 quantize launch")) return 0; + if (n_tok == 1) { + matmul_q8_0_preq_warp8_kernel<<<((unsigned)out_dim + 7u) / 8u, 256>>>( + (float *)out->ptr, + reinterpret_cast(wptr), + xq, + xscale, + in_dim, + out_dim, + blocks, + use_dp4a); + return hip_ok(hipGetLastError(), "matmul_q8_0 warp launch"); + } + if (getenv("DS4_HIP_NO_Q8_BATCH_WARP") == NULL && blocks <= 32u) { + dim3 bgrid(((unsigned)out_dim + 7u) / 8u, (unsigned)n_tok, 1); + matmul_q8_0_preq_batch_warp8_kernel<<>>( + (float *)out->ptr, + reinterpret_cast(wptr), + xq, + xscale, + in_dim, + out_dim, + n_tok, + blocks, + use_dp4a); + return hip_ok(hipGetLastError(), "matmul_q8_0 batch warp launch"); + } + dim3 grid((unsigned)out_dim, (unsigned)n_tok, 1); + matmul_q8_0_preq_kernel<<>>((float *)out->ptr, + reinterpret_cast(wptr), + xq, + xscale, + in_dim, out_dim, n_tok, blocks, + use_dp4a); + return hip_ok(hipGetLastError(), "matmul_q8_0 launch"); +} + +extern "C" int ds4_gpu_matmul_q8_0_tensor(ds4_gpu_tensor *out, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint64_t in_dim, uint64_t out_dim, const ds4_gpu_tensor *x, uint64_t n_tok) { + return hip_matmul_q8_0_tensor_labeled(out, model_map, model_size, weight_offset, + in_dim, out_dim, x, n_tok, "q8_0"); +} + +extern "C" int ds4_gpu_matmul_q8_0_pair_tensor( + ds4_gpu_tensor *out0, + ds4_gpu_tensor *out1, + const void *model_map, + uint64_t model_size, + uint64_t weight0_offset, + uint64_t weight1_offset, + uint64_t in_dim, + uint64_t out0_dim, + uint64_t out1_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + if (!out0 || !out1 || !x || !model_map || in_dim == 0 || out0_dim == 0 || out1_dim == 0 || n_tok == 0) { + return 0; + } + if (n_tok != 1) { + return hip_matmul_q8_0_tensor_labeled(out0, model_map, model_size, weight0_offset, + in_dim, out0_dim, x, n_tok, "q8_0_pair0") && + hip_matmul_q8_0_tensor_labeled(out1, model_map, model_size, weight1_offset, + in_dim, out1_dim, x, n_tok, "q8_0_pair1"); + } + const uint64_t blocks = (in_dim + 31) / 32; + if (weight0_offset > model_size || weight1_offset > model_size || + out0_dim > UINT64_MAX / (blocks * 34) || + out1_dim > UINT64_MAX / (blocks * 34)) { + return 0; + } + const uint64_t weight0_bytes = out0_dim * blocks * 34; + const uint64_t weight1_bytes = out1_dim * blocks * 34; + if (weight0_bytes > model_size - weight0_offset || + weight1_bytes > model_size - weight1_offset || + x->bytes < in_dim * sizeof(float) || + out0->bytes < out0_dim * sizeof(float) || + out1->bytes < out1_dim * sizeof(float)) { + return 0; + } + const char *w0 = hip_model_range_ptr(model_map, weight0_offset, weight0_bytes, "q8_0_pair0"); + const char *w1 = hip_model_range_ptr(model_map, weight1_offset, weight1_bytes, "q8_0_pair1"); + if (!w0 || !w1) return 0; + + const uint64_t xq_bytes = blocks * 32u; + const uint64_t scale_offset = (xq_bytes + 15u) & ~15ull; + const uint64_t tmp_bytes = scale_offset + blocks * sizeof(float); + void *tmp = hip_tmp_alloc(tmp_bytes, "q8_0 pair prequant"); + if (!tmp) return 0; + int8_t *xq = (int8_t *)tmp; + float *xscale = (float *)((char *)tmp + scale_offset); + const int use_dp4a = hip_q8_use_dp4a(); + dim3 qgrid((unsigned)blocks, 1, 1); + quantize_q8_0_f32_kernel<<>>(xq, xscale, (const float *)x->ptr, in_dim, blocks); + if (!hip_ok(hipGetLastError(), "matmul_q8_0 pair quantize launch")) return 0; + const uint64_t max_out = out0_dim > out1_dim ? out0_dim : out1_dim; + matmul_q8_0_pair_preq_warp8_kernel<<<((unsigned)max_out + 7u) / 8u, 256>>>( + (float *)out0->ptr, + (float *)out1->ptr, + reinterpret_cast(w0), + reinterpret_cast(w1), + xq, + xscale, + in_dim, + out0_dim, + out1_dim, + blocks, + use_dp4a); + return hip_ok(hipGetLastError(), "matmul_q8_0 pair warp launch"); +} + +static int hip_matmul_q8_0_hc_expand_tensor_labeled( + ds4_gpu_tensor *out_hc, + ds4_gpu_tensor *block_out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *block_add, + const ds4_gpu_tensor *residual_hc, + const ds4_gpu_tensor *split, + uint32_t n_embd, + uint32_t n_hc, + const char *label) { + if (!out_hc || !block_out || !x || !residual_hc || !split || !model_map || + in_dim == 0 || out_dim == 0 || n_embd == 0 || n_hc == 0 || + out_dim != (uint64_t)n_embd) { + return 0; + } + const uint64_t blocks = (in_dim + 31) / 32; + if (weight_offset > model_size || out_dim > UINT64_MAX / (blocks * 34)) return 0; + const uint64_t weight_bytes = out_dim * blocks * 34; + const uint64_t hc_bytes = (uint64_t)n_hc * n_embd * sizeof(float); + const uint64_t split_bytes = (uint64_t)(2u * n_hc + n_hc * n_hc) * sizeof(float); + if (weight_bytes > model_size - weight_offset || + x->bytes < in_dim * sizeof(float) || + block_out->bytes < out_dim * sizeof(float) || + residual_hc->bytes < hc_bytes || + split->bytes < split_bytes || + out_hc->bytes < hc_bytes || + (block_add && block_add->bytes < out_dim * sizeof(float))) { + return 0; + } + const char *wptr = hip_model_range_ptr(model_map, weight_offset, weight_bytes, label ? label : "q8_0_hc_expand"); + if (!wptr) return 0; + + const uint64_t xq_bytes = blocks * 32u; + const uint64_t scale_offset = (xq_bytes + 15u) & ~15ull; + const uint64_t tmp_bytes = scale_offset + blocks * sizeof(float); + void *tmp = hip_tmp_alloc(tmp_bytes, "q8_0 hc expand prequant"); + if (!tmp) return 0; + int8_t *xq = (int8_t *)tmp; + float *xscale = (float *)((char *)tmp + scale_offset); + const int use_dp4a = hip_q8_use_dp4a(); + quantize_q8_0_f32_kernel<<<(unsigned)blocks, 32>>>(xq, xscale, (const float *)x->ptr, in_dim, blocks); + if (!hip_ok(hipGetLastError(), "matmul_q8_0_hc_expand quantize launch")) return 0; + matmul_q8_0_hc_expand_preq_warp8_kernel<<<((unsigned)out_dim + 7u) / 8u, 256>>>( + (float *)out_hc->ptr, + (float *)block_out->ptr, + block_add ? (const float *)block_add->ptr : (const float *)block_out->ptr, + (const float *)residual_hc->ptr, + (const float *)split->ptr, + reinterpret_cast(wptr), + xq, + xscale, + in_dim, + out_dim, + n_embd, + n_hc, + blocks, + block_add ? 1 : 0, + use_dp4a); + return hip_ok(hipGetLastError(), "matmul_q8_0_hc_expand launch"); +} + +extern "C" int ds4_gpu_matmul_f16_tensor(ds4_gpu_tensor *out, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint64_t in_dim, uint64_t out_dim, const ds4_gpu_tensor *x, uint64_t n_tok) { + if (!out || !x || !model_map) return 0; + if (weight_offset > model_size || out_dim > UINT64_MAX / in_dim) return 0; + uint64_t weight_bytes = out_dim * in_dim * sizeof(uint16_t); + if (weight_bytes > model_size - weight_offset) return 0; + if (x->bytes < n_tok * in_dim * sizeof(float) || + out->bytes < n_tok * out_dim * sizeof(float)) return 0; + const char *wptr = hip_model_range_ptr(model_map, weight_offset, weight_bytes, "f16"); + if (!wptr) return 0; + const __half *w = (const __half *)wptr; + const int serial_f16 = getenv("DS4_HIP_SERIAL_F16_MATMUL") != NULL; + const int router_shape = in_dim == 4096u && out_dim == 256u && n_tok == 1u; + const int serial_router = + !serial_f16 && + router_shape && + getenv("DS4_HIP_SERIAL_ROUTER") != NULL; + const int ordered_router = + !serial_f16 && + !serial_router && + n_tok == 1u && + getenv("DS4_HIP_NO_ORDERED_F16_MATMUL") == NULL; + if (!serial_f16 && g_hipblas_ready && n_tok > 1) { + const uint64_t xh_count = n_tok * in_dim; + __half *xh = (__half *)hip_tmp_alloc(xh_count * sizeof(__half), "f16 gemm activations"); + if (!xh) return 0; + f32_to_f16_kernel<<<(xh_count + 255) / 256, 256>>>(xh, (const float *)x->ptr, xh_count); + if (!hip_ok(hipGetLastError(), "f16 activation convert launch")) return 0; + const float alpha = 1.0f; + const float beta = 0.0f; + hipblasStatus_t st = hipblasGemmEx(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)out_dim, + (int)n_tok, + (int)in_dim, + &alpha, + w, + HIP_R_16F, + (int)in_dim, + xh, + HIP_R_16F, + (int)in_dim, + &beta, + out->ptr, + HIP_R_32F, + (int)out_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); + return hipblas_ok(st, "f16 matmul"); + } + dim3 grid((unsigned)out_dim, (unsigned)n_tok, 1); + if (serial_f16 || serial_router) { + matmul_f16_serial_kernel<<>>((float *)out->ptr, w, (const float *)x->ptr, in_dim, out_dim, n_tok); + return hip_ok(hipGetLastError(), serial_router ? "matmul_f16_router_serial launch" : "matmul_f16_serial launch"); + } + if (ordered_router) { + matmul_f16_ordered_chunks_kernel<<>>((float *)out->ptr, w, (const float *)x->ptr, in_dim, out_dim, n_tok); + return hip_ok(hipGetLastError(), "matmul_f16_ordered_chunks launch"); + } + matmul_f16_kernel<<>>((float *)out->ptr, w, (const float *)x->ptr, in_dim, out_dim, n_tok); + return hip_ok(hipGetLastError(), "matmul_f16 launch"); +} + +extern "C" int ds4_gpu_matmul_f16_pair_tensor( + ds4_gpu_tensor *out0, + ds4_gpu_tensor *out1, + const void *model_map, + uint64_t model_size, + uint64_t weight0_offset, + uint64_t weight1_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + uint64_t n_tok) { + if (!out0 || !out1 || !x || !model_map || in_dim == 0 || out_dim == 0 || n_tok == 0) { + return 0; + } + if (n_tok != 1 || + getenv("DS4_HIP_NO_F16_PAIR_MATMUL") != NULL || + getenv("DS4_HIP_SERIAL_F16_MATMUL") != NULL || + getenv("DS4_HIP_SERIAL_ROUTER") != NULL || + getenv("DS4_HIP_NO_ORDERED_F16_MATMUL") != NULL) { + return ds4_gpu_matmul_f16_tensor(out0, model_map, model_size, weight0_offset, + in_dim, out_dim, x, n_tok) && + ds4_gpu_matmul_f16_tensor(out1, model_map, model_size, weight1_offset, + in_dim, out_dim, x, n_tok); + } + if (weight0_offset > model_size || weight1_offset > model_size || + out_dim > UINT64_MAX / in_dim) { + return 0; + } + const uint64_t weight_bytes = out_dim * in_dim * sizeof(uint16_t); + if (weight_bytes > model_size - weight0_offset || + weight_bytes > model_size - weight1_offset || + x->bytes < in_dim * sizeof(float) || + out0->bytes < out_dim * sizeof(float) || + out1->bytes < out_dim * sizeof(float)) { + return 0; + } + const __half *w0 = (const __half *)hip_model_range_ptr(model_map, weight0_offset, weight_bytes, "f16_pair0"); + const __half *w1 = (const __half *)hip_model_range_ptr(model_map, weight1_offset, weight_bytes, "f16_pair1"); + if (!w0 || !w1) return 0; + matmul_f16_pair_ordered_chunks_kernel<<<(unsigned)out_dim, 32>>>( + (float *)out0->ptr, + (float *)out1->ptr, + w0, + w1, + (const float *)x->ptr, + in_dim, + out_dim, + out_dim, + n_tok); + return hip_ok(hipGetLastError(), "matmul_f16_pair_ordered_chunks launch"); +} + +extern "C" int ds4_gpu_matmul_f32_tensor(ds4_gpu_tensor *out, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint64_t in_dim, uint64_t out_dim, const ds4_gpu_tensor *x, uint64_t n_tok) { + if (!out || !x || !model_map || in_dim == 0 || out_dim == 0 || n_tok == 0) return 0; + if (weight_offset > model_size || out_dim > UINT64_MAX / in_dim) return 0; + uint64_t weight_elems = out_dim * in_dim; + if (weight_elems > UINT64_MAX / sizeof(float)) return 0; + uint64_t weight_bytes = weight_elems * sizeof(float); + if (weight_bytes > model_size - weight_offset) return 0; + if (x->bytes < n_tok * in_dim * sizeof(float) || + out->bytes < n_tok * out_dim * sizeof(float)) return 0; + const char *wptr = hip_model_range_ptr(model_map, weight_offset, weight_bytes, "f32"); + if (!wptr) return 0; + const float *w = (const float *)wptr; + if (g_hipblas_ready && n_tok > 1) { + const float alpha = 1.0f; + const float beta = 0.0f; + hipblasStatus_t st = hipblasSgemm(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)out_dim, + (int)n_tok, + (int)in_dim, + &alpha, + w, + (int)in_dim, + (const float *)x->ptr, + (int)in_dim, + &beta, + (float *)out->ptr, + (int)out_dim); + return hipblas_ok(st, "f32 matmul"); + } + dim3 grid((unsigned)out_dim, (unsigned)n_tok, 1); + matmul_f32_kernel<<>>((float *)out->ptr, w, (const float *)x->ptr, in_dim, out_dim, n_tok); + return hip_ok(hipGetLastError(), "matmul_f32 launch"); +} + +extern "C" int ds4_gpu_repeat_hc_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *row, uint32_t n_embd, uint32_t n_hc) { + if (!out || !row || n_embd == 0 || n_hc == 0 || + row->bytes < (uint64_t)n_embd * sizeof(float) || + out->bytes < (uint64_t)n_embd * n_hc * sizeof(float)) { + return 0; + } + uint64_t n = (uint64_t)n_embd * n_hc; + repeat_hc_kernel<<<(n + 255) / 256, 256>>>((float *)out->ptr, (const float *)row->ptr, n_embd, n_hc); + return hip_ok(hipGetLastError(), "repeat_hc launch"); +} + +extern "C" int ds4_gpu_rms_norm_plain_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *x, uint32_t n, float eps) { + if (!out || !x || out->bytes < (uint64_t)n * sizeof(float) || + x->bytes < (uint64_t)n * sizeof(float)) return 0; + rms_norm_plain_kernel<<<1, 256>>>((float *)out->ptr, (const float *)x->ptr, n, 1, eps); + return hip_ok(hipGetLastError(), "rms_norm_plain launch"); +} +extern "C" int ds4_gpu_rms_norm_plain_rows_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *x, uint32_t n, uint32_t rows, float eps) { + if (!out || !x || out->bytes < (uint64_t)n * rows * sizeof(float) || + x->bytes < (uint64_t)n * rows * sizeof(float)) return 0; + rms_norm_plain_kernel<<>>((float *)out->ptr, (const float *)x->ptr, n, rows, eps); + return hip_ok(hipGetLastError(), "rms_norm_plain launch"); +} +extern "C" int ds4_gpu_rms_norm_weight_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *x, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint32_t n, float eps) { + if (!out || !x || !model_map || weight_offset > model_size || + model_size - weight_offset < (uint64_t)n * sizeof(float) || + out->bytes < (uint64_t)n * sizeof(float) || + x->bytes < (uint64_t)n * sizeof(float)) return 0; + const char *wptr = hip_model_range_ptr(model_map, weight_offset, (uint64_t)n * sizeof(float), "rms_weight"); + if (!wptr) return 0; + const float *w = (const float *)wptr; + rms_norm_weight_kernel<<<1, 256>>>((float *)out->ptr, (const float *)x->ptr, w, n, 1, eps); + return hip_ok(hipGetLastError(), "rms_norm_weight launch"); +} +extern "C" int ds4_gpu_rms_norm_weight_rows_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *x, const void *model_map, uint64_t model_size, uint64_t weight_offset, uint32_t n, uint32_t rows, float eps) { + if (!out || !x || !model_map || weight_offset > model_size || + model_size - weight_offset < (uint64_t)n * sizeof(float) || + out->bytes < (uint64_t)n * rows * sizeof(float) || + x->bytes < (uint64_t)n * rows * sizeof(float)) return 0; + const char *wptr = hip_model_range_ptr(model_map, weight_offset, (uint64_t)n * sizeof(float), "rms_weight"); + if (!wptr) return 0; + const float *w = (const float *)wptr; + rms_norm_weight_kernel<<>>((float *)out->ptr, (const float *)x->ptr, w, n, rows, eps); + return hip_ok(hipGetLastError(), "rms_norm_weight launch"); +} +extern "C" int ds4_gpu_dsv4_qkv_rms_norm_rows_tensor( + ds4_gpu_tensor *q_out, + const ds4_gpu_tensor *q, + const void *model_map, + uint64_t model_size, + uint64_t q_weight_offset, + uint32_t q_n, + ds4_gpu_tensor *kv_out, + const ds4_gpu_tensor *kv, + uint64_t kv_weight_offset, + uint32_t kv_n, + uint32_t rows, + float eps) { + if (getenv("DS4_HIP_DISABLE_QKV_RMS_FUSED") == NULL) { + if (!q_out || !q || !kv_out || !kv || !model_map || + q_weight_offset > model_size || + kv_weight_offset > model_size || + model_size - q_weight_offset < (uint64_t)q_n * sizeof(float) || + model_size - kv_weight_offset < (uint64_t)kv_n * sizeof(float) || + q_out->bytes < (uint64_t)q_n * rows * sizeof(float) || + q->bytes < (uint64_t)q_n * rows * sizeof(float) || + kv_out->bytes < (uint64_t)kv_n * rows * sizeof(float) || + kv->bytes < (uint64_t)kv_n * rows * sizeof(float)) { + return 0; + } + const float *q_w = (const float *)hip_model_range_ptr(model_map, + q_weight_offset, (uint64_t)q_n * sizeof(float), "q_rms_weight"); + const float *kv_w = (const float *)hip_model_range_ptr(model_map, + kv_weight_offset, (uint64_t)kv_n * sizeof(float), "kv_rms_weight"); + if (!q_w || !kv_w) return 0; + dim3 grid(rows, 2u, 1u); + dsv4_qkv_rms_norm_rows_kernel<<>>( + (float *)q_out->ptr, + (const float *)q->ptr, + q_w, + q_n, + (float *)kv_out->ptr, + (const float *)kv->ptr, + kv_w, + kv_n, + rows, + eps); + return hip_ok(hipGetLastError(), "dsv4 qkv rms norm rows launch"); + } + return ds4_gpu_rms_norm_weight_rows_tensor(q_out, q, model_map, model_size, + q_weight_offset, q_n, rows, eps) && + ds4_gpu_rms_norm_weight_rows_tensor(kv_out, kv, model_map, model_size, + kv_weight_offset, kv_n, rows, eps); +} +extern "C" int ds4_gpu_head_rms_norm_tensor(ds4_gpu_tensor *x, uint32_t n_tok, uint32_t n_head, uint32_t head_dim, float eps) { + if (!x || x->bytes < (uint64_t)n_tok * n_head * head_dim * sizeof(float)) return 0; + head_rms_norm_kernel<<>>((float *)x->ptr, n_tok, n_head, head_dim, eps); + return hip_ok(hipGetLastError(), "head_rms_norm launch"); +} +extern "C" int ds4_gpu_head_rms_norm_rope_tail_tensor(ds4_gpu_tensor *x, uint32_t n_tok, uint32_t n_head, uint32_t head_dim, uint32_t n_rot, uint32_t pos0, uint32_t n_ctx_orig, bool inverse, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, float eps) { + if (!x || n_rot > head_dim || (n_rot & 1u) || + x->bytes < (uint64_t)n_tok * n_head * head_dim * sizeof(float)) return 0; + head_rms_norm_rope_tail_kernel<<>>((float *)x->ptr, n_tok, n_head, head_dim, n_rot, pos0, n_ctx_orig, inverse ? 1 : 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, eps); + return hip_ok(hipGetLastError(), "head_rms_norm_rope_tail launch"); +} +extern "C" int ds4_gpu_dsv4_fp8_kv_quantize_tensor(ds4_gpu_tensor *x, uint32_t n_tok, uint32_t head_dim, uint32_t n_rot) { + if (!x || n_rot > head_dim || x->bytes < (uint64_t)n_tok * head_dim * sizeof(float)) return 0; + fp8_kv_quantize_kernel<<>>((float *)x->ptr, n_tok, head_dim, n_rot); + return hip_ok(hipGetLastError(), "fp8_kv_quantize launch"); +} +extern "C" int ds4_gpu_rope_tail_tensor(ds4_gpu_tensor *x, uint32_t n_tok, uint32_t n_head, uint32_t head_dim, uint32_t n_rot, uint32_t pos0, uint32_t n_ctx_orig, bool inverse, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow) { + if (!x || n_rot > head_dim || (n_rot & 1) || x->bytes < (uint64_t)n_tok * n_head * head_dim * sizeof(float)) return 0; + uint32_t pairs = n_tok * n_head * (n_rot / 2); + rope_tail_kernel<<<(pairs + 255) / 256, 256>>>((float *)x->ptr, n_tok, n_head, head_dim, n_rot, pos0, n_ctx_orig, inverse ? 1 : 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + return hip_ok(hipGetLastError(), "rope_tail launch"); +} +extern "C" int ds4_gpu_store_raw_kv_tensor(ds4_gpu_tensor *raw_cache, const ds4_gpu_tensor *kv, uint32_t raw_cap, uint32_t row, uint32_t head_dim); +extern "C" int ds4_gpu_kv_fp8_store_raw_tensor( + ds4_gpu_tensor *kv, + ds4_gpu_tensor *raw_cache, + uint32_t raw_cap, + uint32_t raw_row, + uint32_t head_dim, + uint32_t n_rot) { + return ds4_gpu_dsv4_fp8_kv_quantize_tensor(kv, 1, head_dim, n_rot) && + ds4_gpu_store_raw_kv_tensor(raw_cache, kv, raw_cap, raw_row, head_dim); +} +extern "C" int ds4_gpu_store_raw_kv_tensor(ds4_gpu_tensor *raw_cache, const ds4_gpu_tensor *kv, uint32_t raw_cap, uint32_t row, uint32_t head_dim) { + if (!raw_cache || !kv || raw_cap == 0 || + raw_cache->bytes < (uint64_t)raw_cap * head_dim * sizeof(float) || + kv->bytes < (uint64_t)head_dim * sizeof(float)) return 0; + store_raw_kv_batch_kernel<<<(head_dim + 255) / 256, 256>>>((float *)raw_cache->ptr, (const float *)kv->ptr, raw_cap, row, 1, head_dim); + return hip_ok(hipGetLastError(), "store_raw_kv launch"); +} +extern "C" int ds4_gpu_store_raw_kv_batch_tensor(ds4_gpu_tensor *raw_cache, const ds4_gpu_tensor *kv, uint32_t raw_cap, uint32_t pos0, uint32_t n_tokens, uint32_t head_dim) { + if (!raw_cache || !kv || raw_cap == 0 || + raw_cache->bytes < (uint64_t)raw_cap * head_dim * sizeof(float) || + kv->bytes < (uint64_t)n_tokens * head_dim * sizeof(float)) return 0; + uint64_t n = (uint64_t)n_tokens * head_dim; + store_raw_kv_batch_kernel<<<(n + 255) / 256, 256>>>((float *)raw_cache->ptr, (const float *)kv->ptr, raw_cap, pos0, n_tokens, head_dim); + return hip_ok(hipGetLastError(), "store_raw_kv_batch launch"); +} +extern "C" int ds4_gpu_compressor_store_batch_tensor( + const ds4_gpu_tensor *kv, + const ds4_gpu_tensor *sc, + ds4_gpu_tensor *state_kv, + ds4_gpu_tensor *state_score, + const void *model_map, + uint64_t model_size, + uint64_t ape_offset, + uint32_t ape_type, + uint32_t head_dim, + uint32_t ratio, + uint32_t pos0, + uint32_t n_tokens) { + if (!kv || !sc || !state_kv || !state_score || !model_map || + head_dim == 0 || ratio == 0 || n_tokens == 0 || + (ape_type != 0u && ape_type != 1u)) { + return 0; + } + const uint32_t coff = ratio == 4u ? 2u : 1u; + const uint32_t width = coff * head_dim; + const uint32_t state_rows = coff * ratio; + const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; + const uint64_t kv_bytes = (uint64_t)n_tokens * width * sizeof(float); + const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); + const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + if (ape_offset > model_size || ape_bytes > model_size - ape_offset || + kv->bytes < kv_bytes || sc->bytes < kv_bytes || + state_kv->bytes < state_bytes || state_score->bytes < state_bytes) { + return 0; + } + const char *ape = hip_model_range_ptr(model_map, ape_offset, ape_bytes, "compressor_ape"); + if (!ape) return 0; + uint64_t n = (uint64_t)n_tokens * width; + compressor_store_kernel<<<(n + 255) / 256, 256>>>( + (const float *)kv->ptr, + (const float *)sc->ptr, + (float *)state_kv->ptr, + (float *)state_score->ptr, + ape, + 0, + ape_type, + head_dim, + ratio, + pos0, + n_tokens); + return hip_ok(hipGetLastError(), "compressor store launch"); +} + +extern "C" int ds4_gpu_compressor_update_tensor( + const ds4_gpu_tensor *kv_cur, + const ds4_gpu_tensor *sc_cur, + ds4_gpu_tensor *state_kv, + ds4_gpu_tensor *state_score, + ds4_gpu_tensor *comp_cache, + const void *model_map, + uint64_t model_size, + uint64_t ape_offset, + uint32_t ape_type, + uint64_t norm_offset, + uint32_t norm_type, + uint32_t head_dim, + uint32_t ratio, + uint32_t pos, + uint32_t comp_row, + uint32_t n_rot, + uint32_t n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float rms_eps) { + if (!kv_cur || !sc_cur || !state_kv || !state_score || !comp_cache || + !model_map || head_dim == 0 || ratio == 0 || + n_rot > head_dim || (n_rot & 1u) != 0 || + (ape_type != 0u && ape_type != 1u) || norm_type != 0u) { + return 0; + } + const uint32_t coff = ratio == 4u ? 2u : 1u; + const uint32_t width = coff * head_dim; + const uint32_t state_rows = coff * ratio; + const uint32_t emit = ((pos + 1u) % ratio) == 0u ? 1u : 0u; + const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; + const uint64_t kv_bytes = (uint64_t)width * sizeof(float); + const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); + const uint64_t comp_bytes = (uint64_t)(comp_row + (emit ? 1u : 0u)) * head_dim * sizeof(float); + const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t norm_bytes = (uint64_t)head_dim * sizeof(float); + if (ape_offset > model_size || ape_bytes > model_size - ape_offset || + norm_offset > model_size || norm_bytes > model_size - norm_offset || + kv_cur->bytes < kv_bytes || sc_cur->bytes < kv_bytes || + state_kv->bytes < state_bytes || state_score->bytes < state_bytes || + (emit && comp_cache->bytes < comp_bytes)) { + return 0; + } + if (!ds4_gpu_compressor_store_batch_tensor(kv_cur, sc_cur, state_kv, state_score, + model_map, model_size, ape_offset, ape_type, + head_dim, ratio, pos, 1)) { + return 0; + } + if (!emit) return 1; + ds4_gpu_tensor *comp_row_view = ds4_gpu_tensor_view( + comp_cache, + (uint64_t)comp_row * head_dim * sizeof(float), + (uint64_t)head_dim * sizeof(float)); + if (!comp_row_view) return 0; + compressor_update_pool_kernel<<<(head_dim + 255) / 256, 256>>>( + (float *)comp_row_view->ptr, + (const float *)state_kv->ptr, + (const float *)state_score->ptr, + head_dim, + ratio); + int ok = hip_ok(hipGetLastError(), "compressor update pool launch"); + if (ok) ok = ds4_gpu_rms_norm_weight_rows_tensor(comp_row_view, comp_row_view, + model_map, model_size, norm_offset, + head_dim, 1, rms_eps); + if (ok) ok = ds4_gpu_rope_tail_tensor(comp_row_view, 1, 1, head_dim, n_rot, + pos + 1u - ratio, n_ctx_orig, false, + freq_base, freq_scale, ext_factor, attn_factor, + beta_fast, beta_slow); + ds4_gpu_tensor_free(comp_row_view); + if (ok && ratio == 4u) { + uint64_t half = 4ull * width; + compressor_shift_ratio4_kernel<<<(half + 255) / 256, 256>>>( + (float *)state_kv->ptr, (float *)state_score->ptr, width); + ok = hip_ok(hipGetLastError(), "compressor ratio4 shift launch"); + } + return ok; +} +extern "C" int ds4_gpu_compressor_prefill_tensor( + ds4_gpu_tensor *comp_cache, + ds4_gpu_tensor *state_kv, + ds4_gpu_tensor *state_score, + const ds4_gpu_tensor *kv, + const ds4_gpu_tensor *sc, + const void *model_map, + uint64_t model_size, + uint64_t ape_offset, + uint32_t ape_type, + uint64_t norm_offset, + uint32_t norm_type, + uint32_t head_dim, + uint32_t ratio, + uint32_t pos0, + uint32_t n_tokens, + uint32_t n_rot, + uint32_t n_ctx_orig, + bool quantize_fp8, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float rms_eps) { + if (!comp_cache || !state_kv || !state_score || !kv || !sc || !model_map || + head_dim == 0 || ratio == 0 || n_tokens == 0 || + n_rot > head_dim || (n_rot & 1u) != 0 || + (ape_type != 0u && ape_type != 1u) || norm_type != 0u) { + return 0; + } + + const uint32_t coff = ratio == 4u ? 2u : 1u; + const uint32_t width = coff * head_dim; + const uint32_t state_rows = coff * ratio; + const uint32_t n_comp = n_tokens / ratio; + const uint32_t cutoff = n_comp * ratio; + const uint32_t rem = n_tokens - cutoff; + const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; + const uint64_t kv_bytes = (uint64_t)n_tokens * width * sizeof(float); + const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); + const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * sizeof(float); + const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t norm_bytes = (uint64_t)head_dim * sizeof(float); + + if (ape_offset > model_size || ape_bytes > model_size - ape_offset || + norm_offset > model_size || norm_bytes > model_size - norm_offset || + kv->bytes < kv_bytes || sc->bytes < kv_bytes || + state_kv->bytes < state_bytes || state_score->bytes < state_bytes || + (n_comp && comp_cache->bytes < comp_bytes)) { + return 0; + } + const char *ape = hip_model_range_ptr(model_map, ape_offset, ape_bytes, "compressor_ape"); + if (!ape) return 0; + + uint64_t state_n = (uint64_t)state_rows * width; + fill_f32_kernel<<<(state_n + 255) / 256, 256>>>((float *)state_kv->ptr, state_n, 0.0f); + if (!hip_ok(hipGetLastError(), "compressor state kv fill launch")) return 0; + fill_f32_kernel<<<(state_n + 255) / 256, 256>>>((float *)state_score->ptr, state_n, -INFINITY); + if (!hip_ok(hipGetLastError(), "compressor state score fill launch")) return 0; + + if (ratio == 4u) { + if (cutoff >= ratio) { + uint32_t prev_start = cutoff - ratio; + uint64_t n = (uint64_t)ratio * width; + compressor_set_rows_kernel<<<(n + 255) / 256, 256>>>( + (float *)state_kv->ptr, (float *)state_score->ptr, + (const float *)kv->ptr, (const float *)sc->ptr, + ape, 0, ape_type, width, ratio, pos0, + prev_start, 0, ratio); + if (!hip_ok(hipGetLastError(), "compressor prefill prev state launch")) return 0; + } + if (rem != 0) { + uint64_t n = (uint64_t)rem * width; + compressor_set_rows_kernel<<<(n + 255) / 256, 256>>>( + (float *)state_kv->ptr, (float *)state_score->ptr, + (const float *)kv->ptr, (const float *)sc->ptr, + ape, 0, ape_type, width, ratio, pos0, + cutoff, ratio, rem); + if (!hip_ok(hipGetLastError(), "compressor prefill rem state launch")) return 0; + } + } else if (rem != 0) { + uint64_t n = (uint64_t)rem * width; + compressor_set_rows_kernel<<<(n + 255) / 256, 256>>>( + (float *)state_kv->ptr, (float *)state_score->ptr, + (const float *)kv->ptr, (const float *)sc->ptr, + ape, 0, ape_type, width, ratio, pos0, + cutoff, 0, rem); + if (!hip_ok(hipGetLastError(), "compressor prefill rem state launch")) return 0; + } + if (n_comp != 0) { + dim3 grid((head_dim + 255) / 256, n_comp, 1); + compressor_prefill_pool_kernel<<>>( + (float *)comp_cache->ptr, + (const float *)kv->ptr, + (const float *)sc->ptr, + (const float *)state_kv->ptr, + (const float *)state_score->ptr, + ape, 0, ape_type, head_dim, ratio, pos0, n_comp, 0); + if (!hip_ok(hipGetLastError(), "compressor prefill pool launch")) return 0; + if (!ds4_gpu_rms_norm_weight_rows_tensor(comp_cache, comp_cache, + model_map, model_size, norm_offset, + head_dim, n_comp, rms_eps)) return 0; + if (n_rot != 0 && !ds4_gpu_rope_tail_tensor(comp_cache, n_comp, 1, head_dim, + n_rot, pos0, n_ctx_orig, false, + freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow)) return 0; + if (quantize_fp8 && !ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_cache, n_comp, head_dim, n_rot)) return 0; + } + return 1; +} +extern "C" int ds4_gpu_compressor_prefill_ratio4_replay_tensor( + ds4_gpu_tensor *comp_cache, + ds4_gpu_tensor *state_kv, + ds4_gpu_tensor *state_score, + const ds4_gpu_tensor *kv, + const ds4_gpu_tensor *sc, + const void *model_map, + uint64_t model_size, + uint64_t ape_offset, + uint32_t ape_type, + uint64_t norm_offset, + uint32_t norm_type, + uint32_t head_dim, + uint32_t pos0, + uint32_t n_tokens, + uint32_t n_rot, + uint32_t n_ctx_orig, + bool quantize_fp8, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float rms_eps) { + if (!comp_cache || !state_kv || !state_score || !kv || !sc || !model_map || + head_dim == 0 || n_tokens == 0 || (n_tokens & 3u) != 0 || (pos0 & 3u) != 0 || + n_rot > head_dim || (n_rot & 1u) != 0 || + (ape_type != 0u && ape_type != 1u) || norm_type != 0u) { + return 0; + } + + const uint32_t ratio = 4u; + const uint32_t width = 2u * head_dim; + const uint32_t state_rows = 8u; + const uint32_t n_comp = n_tokens / ratio; + const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; + const uint64_t kv_bytes = (uint64_t)n_tokens * width * sizeof(float); + const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); + const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * sizeof(float); + const uint64_t ape_bytes = (uint64_t)width * ratio * elem_ape; + const uint64_t norm_bytes = (uint64_t)head_dim * sizeof(float); + if (ape_offset > model_size || ape_bytes > model_size - ape_offset || + norm_offset > model_size || norm_bytes > model_size - norm_offset || + kv->bytes < kv_bytes || sc->bytes < kv_bytes || + state_kv->bytes < state_bytes || state_score->bytes < state_bytes || + comp_cache->bytes < comp_bytes) { + return 0; + } + const char *ape = hip_model_range_ptr(model_map, ape_offset, ape_bytes, "compressor_ape"); + if (!ape) return 0; + dim3 grid((head_dim + 255) / 256, n_comp, 1); + compressor_prefill_pool_kernel<<>>( + (float *)comp_cache->ptr, + (const float *)kv->ptr, + (const float *)sc->ptr, + (const float *)state_kv->ptr, + (const float *)state_score->ptr, + ape, 0, ape_type, head_dim, ratio, pos0, n_comp, 1); + if (!hip_ok(hipGetLastError(), "compressor replay pool launch")) return 0; + if (!ds4_gpu_rms_norm_weight_rows_tensor(comp_cache, comp_cache, + model_map, model_size, norm_offset, + head_dim, n_comp, rms_eps)) return 0; + if (n_rot != 0 && !ds4_gpu_rope_tail_tensor(comp_cache, n_comp, 1, head_dim, + n_rot, pos0, n_ctx_orig, false, + freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow)) return 0; + if (quantize_fp8 && !ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_cache, n_comp, head_dim, n_rot)) return 0; + + uint64_t state_n = (uint64_t)state_rows * width; + fill_f32_kernel<<<(state_n + 255) / 256, 256>>>((float *)state_kv->ptr, state_n, 0.0f); + if (!hip_ok(hipGetLastError(), "compressor replay state kv fill launch")) return 0; + fill_f32_kernel<<<(state_n + 255) / 256, 256>>>((float *)state_score->ptr, state_n, -INFINITY); + if (!hip_ok(hipGetLastError(), "compressor replay state score fill launch")) return 0; + uint32_t prev_start = n_tokens - ratio; + uint64_t n = (uint64_t)ratio * width; + compressor_set_rows_kernel<<<(n + 255) / 256, 256>>>( + (float *)state_kv->ptr, (float *)state_score->ptr, + (const float *)kv->ptr, (const float *)sc->ptr, + ape, 0, ape_type, width, ratio, pos0, + prev_start, 0, ratio); + return hip_ok(hipGetLastError(), "compressor replay state launch"); +} +extern "C" int ds4_gpu_compressor_prefill_state_ratio4_tensor( + ds4_gpu_tensor *state_kv, + ds4_gpu_tensor *state_score, + const ds4_gpu_tensor *kv_tail, + const ds4_gpu_tensor *sc_tail, + const void *model_map, + uint64_t model_size, + uint64_t ape_offset, + uint32_t ape_type, + uint32_t head_dim, + uint32_t pos0) { + if (!state_kv || !state_score || !kv_tail || !sc_tail || !model_map || + head_dim == 0 || (ape_type != 0u && ape_type != 1u)) { + return 0; + } + const uint32_t ratio = 4u; + const uint32_t width = 2u * head_dim; + const uint32_t state_rows = 8u; + const uint64_t elem_ape = ape_type == 1u ? 2u : 4u; + const uint64_t tail_bytes = (uint64_t)ratio * width * sizeof(float); + const uint64_t state_bytes = (uint64_t)state_rows * width * sizeof(float); + const uint64_t ape_bytes = (uint64_t)ratio * width * elem_ape; + if (ape_offset > model_size || ape_bytes > model_size - ape_offset || + kv_tail->bytes < tail_bytes || sc_tail->bytes < tail_bytes || + state_kv->bytes < state_bytes || state_score->bytes < state_bytes) { + return 0; + } + const char *ape = hip_model_range_ptr(model_map, ape_offset, ape_bytes, "compressor_ape"); + if (!ape) return 0; + uint64_t state_n = (uint64_t)state_rows * width; + fill_f32_kernel<<<(state_n + 255) / 256, 256>>>((float *)state_kv->ptr, state_n, 0.0f); + if (!hip_ok(hipGetLastError(), "compressor state kv fill launch")) return 0; + fill_f32_kernel<<<(state_n + 255) / 256, 256>>>((float *)state_score->ptr, state_n, -INFINITY); + if (!hip_ok(hipGetLastError(), "compressor state score fill launch")) return 0; + uint64_t n = (uint64_t)ratio * width; + compressor_set_rows_kernel<<<(n + 255) / 256, 256>>>( + (float *)state_kv->ptr, (float *)state_score->ptr, + (const float *)kv_tail->ptr, (const float *)sc_tail->ptr, + ape, 0, ape_type, width, ratio, pos0, + 0, 0, ratio); + return hip_ok(hipGetLastError(), "compressor state set launch"); +} +extern "C" int ds4_gpu_attention_decode_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + const ds4_gpu_tensor *comp_kv, + uint32_t n_comp, + const ds4_gpu_tensor *comp_mask, + uint32_t use_mask, + uint32_t n_head, + uint32_t head_dim) { + if (!heads || !q || !raw_kv || !model_map || n_raw == 0 || raw_cap < n_raw || + raw_start >= raw_cap || (n_comp != 0 && !comp_kv) || (use_mask && !comp_mask) || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_head * head_dim * sizeof(float) || + raw_kv->bytes < (uint64_t)raw_cap * head_dim * sizeof(float) || + (n_comp && comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float)) || + (use_mask && comp_mask->bytes < (uint64_t)n_comp * sizeof(float))) { + return 0; + } + const float *sinks = (const float *)hip_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + if (!hip_attention_score_buffer_fits(n_comp)) { + if (!use_mask && head_dim == 512u && + getenv("DS4_HIP_NO_WINDOW_ATTENTION") == NULL) { + dim3 online_grid(1, (n_head + 7u) / 8u, 1); + attention_decode_mixed_heads8_online_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + 1, + 0, + n_raw, + raw_cap, + raw_start, + n_comp, + 0, + 0, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention decode online launch"); + } + fprintf(stderr, "ds4: ROCm attention score buffer too small for %u compressed rows\n", n_comp); + return 0; + } + dim3 grid(1, n_head, 1); + attention_decode_mixed_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + use_mask ? (const float *)comp_mask->ptr : NULL, + use_mask, + 1, 0, n_raw, raw_cap, raw_start, n_comp, + 0, 0, n_head, head_dim); + return hip_ok(hipGetLastError(), "attention decode launch"); +} +extern "C" int ds4_gpu_attention_prefill_raw_heads_tensor(ds4_gpu_tensor *heads, const void *model_map, uint64_t model_size, uint64_t sinks_offset, const ds4_gpu_tensor *q, const ds4_gpu_tensor *raw_kv, uint32_t n_tokens, uint32_t window, uint32_t n_head, uint32_t head_dim) { + if (!heads || !q || !raw_kv || !model_map || sinks_offset > model_size || + model_size - sinks_offset < (uint64_t)n_head * sizeof(float) || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv->bytes < (uint64_t)n_tokens * head_dim * sizeof(float)) return 0; + const float *sinks = (const float *)hip_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + if (n_tokens > 1 && + getenv("DS4_HIP_NO_WINDOW_ATTENTION") == NULL && + (getenv("DS4_HIP_WINDOW_ATTENTION") != NULL || (!g_quality_mode && n_tokens >= 128u))) { + dim3 grid(n_tokens, (n_head + 7u) / 8u, 1); + attention_static_mixed_heads8_online_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + (const float *)raw_kv->ptr, + n_tokens, + 0, + window, + 1, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention raw window launch"); + } + if (g_hipblas_ready && n_tokens > 1 && + getenv("DS4_HIP_NO_HIPBLAS_ATTENTION") == NULL) { + const uint32_t n_keys = n_tokens; + const uint64_t score_count = (uint64_t)n_head * n_tokens * n_keys; + const uint64_t out_count = (uint64_t)n_head * n_tokens * head_dim; + const uint64_t score_bytes = score_count * sizeof(float); + const uint64_t out_offset = (score_bytes + 255u) & ~255ull; + const uint64_t tmp_bytes = out_offset + out_count * sizeof(float); + float *tmp = (float *)hip_tmp_alloc(tmp_bytes, "attention raw hipblas"); + if (!tmp) return 0; + float *scores = tmp; + float *out_tmp = (float *)((char *)tmp + out_offset); + const float alpha = rsqrtf((float)head_dim); + const float beta = 0.0f; + hipblasStatus_t st = hipblasSgemmStridedBatched(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)n_keys, + (int)n_tokens, + (int)head_dim, + &alpha, + (const float *)raw_kv->ptr, + (int)head_dim, + 0, + (const float *)q->ptr, + (int)(n_head * head_dim), + (long long)head_dim, + &beta, + scores, + (int)n_keys, + (long long)n_keys * n_tokens, + (int)n_head); + if (!hipblas_ok(st, "attention raw score gemm")) return 0; + dim3 sgrid(n_tokens, n_head, 1); + attention_prefill_raw_softmax_kernel<<>>(scores, sinks, n_tokens, window, n_keys); + if (!hip_ok(hipGetLastError(), "attention raw softmax launch")) return 0; + const float one = 1.0f; + st = hipblasSgemmStridedBatched(g_hipblas, + HIPBLAS_OP_N, + HIPBLAS_OP_N, + (int)head_dim, + (int)n_tokens, + (int)n_keys, + &one, + (const float *)raw_kv->ptr, + (int)head_dim, + 0, + scores, + (int)n_keys, + (long long)n_keys * n_tokens, + &beta, + out_tmp, + (int)head_dim, + (long long)head_dim * n_tokens, + (int)n_head); + if (!hipblas_ok(st, "attention raw value gemm")) return 0; + uint64_t n = (uint64_t)n_tokens * n_head * head_dim; + attention_prefill_unpack_heads_kernel<<<(n + 255) / 256, 256>>>((float *)heads->ptr, + out_tmp, + n_tokens, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention raw unpack launch"); + } + dim3 grid(n_tokens, n_head, 1); + attention_prefill_raw_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_tokens, window, n_head, head_dim); + return hip_ok(hipGetLastError(), "attention_prefill_raw launch"); +} +static int attention_decode_batch_launch( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + if (!heads || !q || !raw_kv || !model_map || n_tokens == 0 || + n_raw == 0 || raw_cap < n_raw || raw_start >= raw_cap || + (n_comp != 0 && !comp_kv) || (use_comp_mask && !comp_mask) || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv->bytes < (uint64_t)raw_cap * head_dim * sizeof(float) || + (n_comp && comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float)) || + (use_comp_mask && comp_mask->bytes < (uint64_t)n_tokens * n_comp * sizeof(float))) { + return 0; + } + if (n_comp != 0 && ratio == 0) return 0; + const float *sinks = (const float *)hip_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + if (!hip_attention_score_buffer_fits(n_comp)) { + if (!use_comp_mask && head_dim == 512u && + getenv("DS4_HIP_NO_WINDOW_ATTENTION") == NULL) { + dim3 online_grid(n_tokens, (n_head + 7u) / 8u, 1); + attention_decode_mixed_heads8_online_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + window, + ratio, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention decode online launch"); + } + fprintf(stderr, "ds4: ROCm attention score buffer too small for %u compressed rows\n", n_comp); + return 0; + } + if (!use_comp_mask && n_tokens > 1 && head_dim == 512 && + getenv("DS4_HIP_NO_WINDOW_ATTENTION") == NULL && + (getenv("DS4_HIP_WINDOW_ATTENTION") != NULL || (!g_quality_mode && n_tokens >= 128u))) { + dim3 grid(n_tokens, (n_head + 7u) / 8u, 1); + attention_decode_mixed_heads8_online_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + window, + ratio, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention decode window launch"); + } + dim3 grid(n_tokens, n_head, 1); + attention_decode_mixed_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + use_comp_mask ? (const float *)comp_mask->ptr : NULL, + use_comp_mask, n_tokens, pos0, n_raw, raw_cap, + raw_start, n_comp, window, ratio, n_head, head_dim); + return hip_ok(hipGetLastError(), "attention decode batch launch"); +} + +extern "C" int ds4_gpu_attention_decode_raw_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t window, + uint32_t n_head, + uint32_t head_dim) { + return attention_decode_batch_launch(heads, model_map, model_size, sinks_offset, + q, raw_kv, NULL, NULL, 0, n_tokens, pos0, + n_raw, raw_cap, raw_start, 0, window, 1, + n_head, head_dim); +} + +extern "C" int ds4_gpu_attention_decode_mixed_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + return attention_decode_batch_launch(heads, model_map, model_size, sinks_offset, + q, raw_kv, comp_kv, comp_mask, use_comp_mask, + n_tokens, pos0, n_raw, raw_cap, raw_start, + n_comp, window, ratio, n_head, head_dim); +} + +extern "C" int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + if (!heads || !q || !raw_kv || !comp_kv || !topk || !model_map || + n_tokens == 0 || n_raw == 0 || raw_cap < n_raw || raw_start >= raw_cap || + n_comp == 0 || top_k == 0 || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv->bytes < (uint64_t)raw_cap * head_dim * sizeof(float) || + comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float) || + topk->bytes < (uint64_t)n_tokens * top_k * sizeof(int32_t)) { + return 0; + } + if (top_k > 512u) return 0; + const float *sinks = (const float *)hip_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + if (n_tokens > 1 && head_dim == 512 && top_k <= 512u && + getenv("DS4_HIP_NO_INDEXED_HEADS8") == NULL) { + dim3 grid(n_tokens, (n_head + 7u) / 8u, 1); + if (getenv("DS4_HIP_INDEXED_TWOPASS") == NULL) { + attention_indexed_mixed_heads8_online_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + (const float *)comp_kv->ptr, + (const int32_t *)topk->ptr, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + top_k, + window, + ratio, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention indexed online launch"); + } + attention_indexed_mixed_heads8_rb4_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + (const float *)comp_kv->ptr, + (const int32_t *)topk->ptr, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + top_k, + window, + ratio, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention indexed heads8 launch"); + } + dim3 grid(n_tokens, n_head, 1); + attention_indexed_mixed_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + (const float *)comp_kv->ptr, + (const int32_t *)topk->ptr, + n_tokens, + pos0, + n_raw, + raw_cap, + raw_start, + n_comp, + top_k, + window, + ratio, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention indexed mixed launch"); +} + +static int attention_prefill_mixed_launch( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + if (!heads || !q || !raw_kv || !model_map || n_tokens == 0 || ratio == 0 || + (n_comp != 0 && !comp_kv) || (use_comp_mask && !comp_mask) || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv->bytes < (uint64_t)n_tokens * head_dim * sizeof(float) || + (n_comp && comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float)) || + (use_comp_mask && comp_mask->bytes < (uint64_t)n_tokens * n_comp * sizeof(float))) { + return 0; + } + const float *sinks = (const float *)hip_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + if (!use_comp_mask && n_tokens > 1 && head_dim == 512 && + getenv("DS4_HIP_NO_WINDOW_ATTENTION") == NULL && + (getenv("DS4_HIP_WINDOW_ATTENTION") != NULL || (!g_quality_mode && n_tokens >= 128u))) { + dim3 grid(n_tokens, (n_head + 7u) / 8u, 1); + attention_static_mixed_heads8_online_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + n_tokens, + n_comp, + window, + ratio, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention mixed window launch"); + } + if (g_hipblas_ready && n_tokens > 1 && head_dim == 512 && + getenv("DS4_HIP_NO_HIPBLAS_ATTENTION") == NULL) { + const uint32_t n_keys = n_tokens + n_comp; + const uint64_t kv_count = (uint64_t)n_keys * head_dim; + const uint64_t score_count = (uint64_t)n_head * n_tokens * n_keys; + const uint64_t out_count = (uint64_t)n_head * n_tokens * head_dim; + const uint64_t kv_bytes = kv_count * sizeof(float); + const uint64_t score_offset = (kv_bytes + 255u) & ~255ull; + const uint64_t score_bytes = score_count * sizeof(float); + const uint64_t out_offset = score_offset + ((score_bytes + 255u) & ~255ull); + const uint64_t tmp_bytes = out_offset + out_count * sizeof(float); + float *tmp = (float *)hip_tmp_alloc(tmp_bytes, "attention mixed hipblas"); + if (!tmp) return 0; + float *kv = tmp; + float *scores = (float *)((char *)tmp + score_offset); + float *out_tmp = (float *)((char *)tmp + out_offset); + attention_prefill_pack_mixed_kv_kernel<<<(kv_count + 255) / 256, 256>>>( + kv, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + n_tokens, + n_comp, + head_dim); + if (!hip_ok(hipGetLastError(), "attention mixed kv pack launch")) return 0; + const float alpha = rsqrtf((float)head_dim); + const float beta = 0.0f; + hipblasStatus_t st = hipblasSgemmStridedBatched(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)n_keys, + (int)n_tokens, + (int)head_dim, + &alpha, + kv, + (int)head_dim, + 0, + (const float *)q->ptr, + (int)(n_head * head_dim), + (long long)head_dim, + &beta, + scores, + (int)n_keys, + (long long)n_keys * n_tokens, + (int)n_head); + if (!hipblas_ok(st, "attention mixed score gemm")) return 0; + dim3 sgrid(n_tokens, n_head, 1); + attention_prefill_mixed_softmax_kernel<<>>( + scores, + sinks, + use_comp_mask ? (const float *)comp_mask->ptr : NULL, + use_comp_mask, + n_tokens, + n_comp, + window, + ratio, + n_keys); + if (!hip_ok(hipGetLastError(), "attention mixed softmax launch")) return 0; + const float one = 1.0f; + st = hipblasSgemmStridedBatched(g_hipblas, + HIPBLAS_OP_N, + HIPBLAS_OP_N, + (int)head_dim, + (int)n_tokens, + (int)n_keys, + &one, + kv, + (int)head_dim, + 0, + scores, + (int)n_keys, + (long long)n_keys * n_tokens, + &beta, + out_tmp, + (int)head_dim, + (long long)head_dim * n_tokens, + (int)n_head); + if (!hipblas_ok(st, "attention mixed value gemm")) return 0; + uint64_t n = (uint64_t)n_tokens * n_head * head_dim; + attention_prefill_unpack_heads_kernel<<<(n + 255) / 256, 256>>>((float *)heads->ptr, + out_tmp, + n_tokens, + n_head, + head_dim); + return hip_ok(hipGetLastError(), "attention mixed unpack launch"); + } + dim3 grid(n_tokens, n_head, 1); + attention_prefill_mixed_kernel<<>>((float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const float *)raw_kv->ptr, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv->ptr, + use_comp_mask ? (const float *)comp_mask->ptr : NULL, + use_comp_mask, n_tokens, n_comp, window, ratio, + n_head, head_dim); + return hip_ok(hipGetLastError(), "attention prefill mixed launch"); +} + +extern "C" int ds4_gpu_attention_prefill_static_mixed_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + return attention_prefill_mixed_launch(heads, model_map, model_size, sinks_offset, + q, raw_kv, comp_kv, NULL, 0, n_tokens, + n_comp, window, ratio, n_head, head_dim); +} + +extern "C" int ds4_gpu_attention_prefill_masked_mixed_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv, + const ds4_gpu_tensor *comp_kv, + const ds4_gpu_tensor *comp_mask, + uint32_t n_tokens, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + return attention_prefill_mixed_launch(heads, model_map, model_size, sinks_offset, + q, raw_kv, comp_kv, comp_mask, 1, n_tokens, + n_comp, window, ratio, n_head, head_dim); +} +extern "C" int ds4_gpu_attention_output_q8_batch_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *low, + ds4_gpu_tensor *group_tmp, + ds4_gpu_tensor *low_tmp, + const void *model_map, + uint64_t model_size, + uint64_t out_a_offset, + uint64_t out_b_offset, + uint64_t group_dim, + uint64_t rank, + uint32_t n_groups, + uint64_t out_dim, + const ds4_gpu_tensor *heads, + uint32_t n_tokens) { + (void)group_tmp; + (void)low_tmp; + if (!out || !low || !heads || !model_map || + group_dim == 0 || rank == 0 || n_groups == 0 || out_dim == 0 || n_tokens == 0) { + return 0; + } + const uint64_t low_dim = (uint64_t)n_groups * rank; + const uint64_t blocks_a = (group_dim + 31) / 32; + const uint64_t blocks_b = (low_dim + 31) / 32; + const uint64_t out_a_bytes = (uint64_t)n_groups * rank * blocks_a * 34; + const uint64_t out_b_bytes = out_dim * blocks_b * 34; + if (out_a_offset > model_size || out_b_offset > model_size || + out_a_bytes > model_size - out_a_offset || + out_b_bytes > model_size - out_b_offset || + heads->bytes < (uint64_t)n_tokens * n_groups * group_dim * sizeof(float) || + low->bytes < (uint64_t)n_tokens * low_dim * sizeof(float) || + out->bytes < (uint64_t)n_tokens * out_dim * sizeof(float)) { + return 0; + } + const unsigned char *out_a = reinterpret_cast( + hip_model_range_ptr(model_map, out_a_offset, out_a_bytes, "attn_out_a")); + const unsigned char *out_b = reinterpret_cast( + hip_model_range_ptr(model_map, out_b_offset, out_b_bytes, "attn_out_b")); + if (!out_a || !out_b) return 0; + + const __half *out_a_f16 = NULL; + uint32_t out_a_hipblas_min_tokens = 2u; + const char *out_a_min_env = getenv("DS4_HIP_ATTENTION_OUTPUT_A_HIPBLAS_MIN"); + if (out_a_min_env && out_a_min_env[0]) { + char *endp = NULL; + long v = strtol(out_a_min_env, &endp, 10); + if (endp != out_a_min_env && v > 1 && v < 4096) out_a_hipblas_min_tokens = (uint32_t)v; + } + if (!g_quality_mode && + g_hipblas_ready && + n_tokens >= out_a_hipblas_min_tokens && + getenv("DS4_HIP_NO_HIPBLAS_ATTENTION_OUTPUT_A") == NULL) { + out_a_f16 = hip_q8_f16_ptr(model_map, out_a_offset, out_a_bytes, group_dim, low_dim, "attn_output_a"); + } + if (out_a_f16) { + const uint64_t heads_h_count = (uint64_t)n_groups * n_tokens * group_dim; + const uint64_t low_tmp_count = (uint64_t)n_groups * n_tokens * rank; + const uint64_t heads_h_bytes = heads_h_count * sizeof(__half); + const uint64_t low_tmp_offset = (heads_h_bytes + 255u) & ~255ull; + const uint64_t tmp_bytes = low_tmp_offset + low_tmp_count * sizeof(float); + void *tmp = hip_tmp_alloc(tmp_bytes, "attention output a hipblas"); + if (!tmp) return 0; + __half *heads_h = (__half *)tmp; + float *low_packed = (float *)((char *)tmp + low_tmp_offset); + attention_pack_group_heads_f16_kernel<<<(heads_h_count + 255) / 256, 256>>>( + heads_h, + (const float *)heads->ptr, + n_tokens, + n_groups, + group_dim); + if (!hip_ok(hipGetLastError(), "attention_output_q8_a pack launch")) return 0; + const float alpha = 1.0f; + const float beta = 0.0f; + hipblasStatus_t st = hipblasGemmStridedBatchedEx(g_hipblas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + (int)rank, + (int)n_tokens, + (int)group_dim, + &alpha, + out_a_f16, + HIP_R_16F, + (int)group_dim, + (long long)rank * group_dim, + heads_h, + HIP_R_16F, + (int)group_dim, + (long long)n_tokens * group_dim, + &beta, + low_packed, + HIP_R_32F, + (int)rank, + (long long)rank * n_tokens, + (int)n_groups, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT); + if (!hipblas_ok(st, "attention output a gemm")) return 0; + attention_unpack_group_low_kernel<<<(low_tmp_count + 255) / 256, 256>>>( + (float *)low->ptr, + low_packed, + n_tokens, + n_groups, + rank); + if (!hip_ok(hipGetLastError(), "attention_output_q8_a unpack launch")) return 0; + } else { + const uint64_t x_rows = (uint64_t)n_tokens * n_groups; + const uint64_t xq_bytes = x_rows * blocks_a * 32u; + const uint64_t scale_offset = (xq_bytes + 15u) & ~15ull; + const uint64_t tmp_bytes = scale_offset + x_rows * blocks_a * sizeof(float); + void *tmp = hip_tmp_alloc(tmp_bytes, "attention output a q8 prequant"); + if (!tmp) return 0; + int8_t *xq = (int8_t *)tmp; + float *xscale = (float *)((char *)tmp + scale_offset); + const int use_dp4a = hip_q8_use_dp4a(); + dim3 qgrid((unsigned)blocks_a, (unsigned)x_rows, 1); + quantize_q8_0_f32_kernel<<>>(xq, + xscale, + (const float *)heads->ptr, + group_dim, + blocks_a); + if (!hip_ok(hipGetLastError(), "attention_output_q8_a prequant launch")) return 0; + dim3 grid_a(((unsigned)low_dim + 7u) / 8u, (unsigned)n_tokens, 1); + grouped_q8_0_a_preq_warp8_kernel<<>>((float *)low->ptr, + out_a, + xq, + xscale, + group_dim, + rank, + n_groups, + n_tokens, + blocks_a, + use_dp4a); + if (!hip_ok(hipGetLastError(), "attention_output_q8_a preq launch")) return 0; + } + + (void)out_b; + return hip_matmul_q8_0_tensor_labeled(out, + model_map, + model_size, + out_b_offset, + low_dim, + out_dim, + low, + n_tokens, + "attn_output_b"); +} +extern "C" int ds4_gpu_attention_output_low_q8_tensor( + ds4_gpu_tensor *low, + const void *model_map, + uint64_t model_size, + uint64_t out_a_offset, + uint64_t group_dim, + uint64_t rank, + uint32_t n_groups, + const ds4_gpu_tensor *heads) { + if (!low || !heads || !model_map || group_dim == 0 || rank == 0 || n_groups == 0) { + return 0; + } + const uint64_t low_dim = (uint64_t)n_groups * rank; + const uint64_t blocks_a = (group_dim + 31) / 32; + const uint64_t out_a_bytes = (uint64_t)n_groups * rank * blocks_a * 34; + if (out_a_offset > model_size || + out_a_bytes > model_size - out_a_offset || + heads->bytes < (uint64_t)n_groups * group_dim * sizeof(float) || + low->bytes < low_dim * sizeof(float)) { + return 0; + } + const unsigned char *out_a = reinterpret_cast( + hip_model_range_ptr(model_map, out_a_offset, out_a_bytes, "attn_out_a")); + if (!out_a) return 0; + + const uint64_t x_rows = (uint64_t)n_groups; + const uint64_t xq_bytes = x_rows * blocks_a * 32u; + const uint64_t scale_offset = (xq_bytes + 15u) & ~15ull; + const uint64_t tmp_bytes = scale_offset + x_rows * blocks_a * sizeof(float); + void *tmp = hip_tmp_alloc(tmp_bytes, "attention output low q8 prequant"); + if (!tmp) return 0; + int8_t *xq = (int8_t *)tmp; + float *xscale = (float *)((char *)tmp + scale_offset); + const int use_dp4a = hip_q8_use_dp4a(); + dim3 qgrid((unsigned)blocks_a, (unsigned)x_rows, 1); + quantize_q8_0_f32_kernel<<>>(xq, + xscale, + (const float *)heads->ptr, + group_dim, + blocks_a); + if (!hip_ok(hipGetLastError(), "attention_output_low_q8 prequant launch")) return 0; + dim3 grid_a(((unsigned)low_dim + 7u) / 8u, 1, 1); + grouped_q8_0_a_preq_warp8_kernel<<>>((float *)low->ptr, + out_a, + xq, + xscale, + group_dim, + rank, + n_groups, + 1, + blocks_a, + use_dp4a); + return hip_ok(hipGetLastError(), "attention_output_low_q8 launch"); +} +extern "C" int ds4_gpu_swiglu_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *gate, const ds4_gpu_tensor *up, uint32_t n, float clamp, float weight) { + if (!out || !gate || !up || + out->bytes < (uint64_t)n * sizeof(float) || + gate->bytes < (uint64_t)n * sizeof(float) || + up->bytes < (uint64_t)n * sizeof(float)) return 0; + swiglu_kernel<<<(n + 255) / 256, 256>>>((float *)out->ptr, (const float *)gate->ptr, (const float *)up->ptr, n, clamp, weight); + return hip_ok(hipGetLastError(), "swiglu launch"); +} +extern "C" int ds4_gpu_shared_gate_up_swiglu_q8_0_tensor( + ds4_gpu_tensor *gate, + ds4_gpu_tensor *up, + ds4_gpu_tensor *mid, + const void *model_map, + uint64_t model_size, + uint64_t gate_offset, + uint64_t up_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x) { + if (getenv("DS4_HIP_DISABLE_SHARED_GATE_UP_PAIR") == NULL) { + return ds4_gpu_matmul_q8_0_pair_tensor(gate, up, + model_map, model_size, + gate_offset, up_offset, + in_dim, out_dim, out_dim, + x, 1) && + ds4_gpu_swiglu_tensor(mid, gate, up, (uint32_t)out_dim, 10.0f, 1.0f); + } + return ds4_gpu_matmul_q8_0_tensor(gate, model_map, model_size, + gate_offset, in_dim, out_dim, x, 1) && + ds4_gpu_matmul_q8_0_tensor(up, model_map, model_size, + up_offset, in_dim, out_dim, x, 1) && + ds4_gpu_swiglu_tensor(mid, gate, up, (uint32_t)out_dim, 10.0f, 1.0f); +} +extern "C" int ds4_gpu_add_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *a, const ds4_gpu_tensor *b, uint32_t n) { + if (!out || !a || !b || + out->bytes < (uint64_t)n * sizeof(float) || + a->bytes < (uint64_t)n * sizeof(float) || + b->bytes < (uint64_t)n * sizeof(float)) return 0; + add_kernel<<<(n + 255) / 256, 256>>>((float *)out->ptr, (const float *)a->ptr, (const float *)b->ptr, n); + return hip_ok(hipGetLastError(), "add launch"); +} +extern "C" int ds4_gpu_directional_steering_project_tensor( + ds4_gpu_tensor *x, + const ds4_gpu_tensor *directions, + uint32_t layer, + uint32_t width, + uint32_t rows, + float scale) { + if (!x || !directions || width == 0 || rows == 0 || scale == 0.0f) return 0; + const uint64_t x_bytes = (uint64_t)width * rows * sizeof(float); + const uint64_t dir_bytes = (uint64_t)(layer + 1u) * width * sizeof(float); + if (x->bytes < x_bytes || directions->bytes < dir_bytes) return 0; + + uint32_t nth = 256u; + while (nth > width && nth > 1u) nth >>= 1; + directional_steering_project_kernel<<>>( + (float *)x->ptr, + (const float *)directions->ptr, + layer, + width, + rows, + scale); + return hip_ok(hipGetLastError(), "directional steering launch"); +} +extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t token, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits) { + if (!selected || !weights || !probs || !logits || !model_map || n_expert_groups > 1u || n_group_used > 0u) return 0; + int32_t tok = (int32_t)token; + int ok = 1; + const float *bias = NULL; + const int32_t *hash = NULL; + if (ok && has_bias && !hash_mode) { + if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) ok = 0; + else bias = (const float *)hip_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias"); + if (!bias) ok = 0; + } + if (ok && hash_mode) { + const uint64_t hash_bytes = (uint64_t)hash_rows * 6u * sizeof(int32_t); + if (hash_offset > model_size || hash_bytes > model_size - hash_offset) ok = 0; + else hash = (const int32_t *)hip_model_range_ptr(model_map, hash_offset, hash_bytes, "router_hash"); + if (!hash) ok = 0; + } + if (ok) { + if (getenv("DS4_HIP_NO_WARP_ROUTER_SELECT") == NULL && + getenv("DS4_HIP_NO_PARALLEL_ROUTER_SELECT") == NULL) { + dim3 block(32, 4, 1); + router_select_warp_topk_kernel<<<1, block>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, + bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, + has_bias && !hash_mode, hash_mode); + } else if (getenv("DS4_HIP_NO_PARALLEL_ROUTER_SELECT") == NULL) { + router_select_parallel_kernel<<<1, 256>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, + bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, + has_bias && !hash_mode, hash_mode); + } else { + router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, + bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, + has_bias && !hash_mode, hash_mode); + } + ok = hip_ok(hipGetLastError(), "router_select launch"); + } + return ok; +} +extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits, const ds4_gpu_tensor *tokens, uint32_t n_tokens) { + if (!selected || !weights || !probs || !logits || !tokens || !model_map || n_tokens == 0 || + n_expert_groups > 1u || n_group_used > 0u || + logits->bytes < (uint64_t)n_tokens * 256u * sizeof(float) || + probs->bytes < (uint64_t)n_tokens * 256u * sizeof(float) || + selected->bytes < (uint64_t)n_tokens * 6u * sizeof(int32_t) || + weights->bytes < (uint64_t)n_tokens * 6u * sizeof(float)) { + return 0; + } + const float *bias = NULL; + const int32_t *hash = NULL; + if (has_bias && !hash_mode) { + if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) return 0; + bias = (const float *)hip_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias"); + if (!bias) return 0; + } + if (hash_mode) { + const uint64_t hash_bytes = (uint64_t)hash_rows * 6u * sizeof(int32_t); + if (hash_offset > model_size || hash_bytes > model_size - hash_offset) return 0; + hash = (const int32_t *)hip_model_range_ptr(model_map, hash_offset, hash_bytes, "router_hash"); + if (!hash) return 0; + } + if (getenv("DS4_HIP_NO_WARP_ROUTER_SELECT") == NULL && + getenv("DS4_HIP_NO_PARALLEL_ROUTER_SELECT") == NULL) { + dim3 block(32, 4, 1); + router_select_warp_topk_kernel<<<(n_tokens + 3u) / 4u, block>>>((int32_t *)selected->ptr, + (float *)weights->ptr, + (float *)probs->ptr, + bias, + hash, + (const float *)logits->ptr, + (const int32_t *)tokens->ptr, + 0, + hash_rows, + n_tokens, + has_bias && !hash_mode, + hash_mode); + } else if (getenv("DS4_HIP_NO_PARALLEL_ROUTER_SELECT") == NULL) { + router_select_parallel_kernel<<>>((int32_t *)selected->ptr, + (float *)weights->ptr, + (float *)probs->ptr, + bias, + hash, + (const float *)logits->ptr, + (const int32_t *)tokens->ptr, + 0, + hash_rows, + n_tokens, + has_bias && !hash_mode, + hash_mode); + } else { + router_select_kernel<<>>((int32_t *)selected->ptr, + (float *)weights->ptr, + (float *)probs->ptr, + bias, + hash, + (const float *)logits->ptr, + (const int32_t *)tokens->ptr, + 0, + hash_rows, + n_tokens, + has_bias && !hash_mode, + hash_mode); + } + return hip_ok(hipGetLastError(), "router_select launch"); +} + +__device__ static float dev_f16_to_f32(uint16_t v) { + return __half2float(*reinterpret_cast(&v)); +} + +__device__ __forceinline__ static uint32_t dev_unpack_iq2_signs(uint32_t v) { + const uint32_t p = __popc(v) & 1u; + const uint32_t s = v ^ (p << 7u); + return s * 0x01010101u; +} + +__device__ __forceinline__ static int32_t dev_iq2_dp4a_8(uint64_t grid, uint32_t sign, const int8_t *q8, int32_t acc) { + const uint32_t signs = dev_unpack_iq2_signs(sign); + const int32_t sm0 = __vcmpne4(signs & 0x08040201u, 0); + const int32_t sm1 = __vcmpne4(signs & 0x80402010u, 0); + const int32_t g0 = __vsub4((int32_t)(uint32_t)grid ^ sm0, sm0); + const int32_t g1 = __vsub4((int32_t)(uint32_t)(grid >> 32) ^ sm1, sm1); + acc = __dp4a(g0, *(const int32_t *)(q8 + 0), acc); + acc = __dp4a(g1, *(const int32_t *)(q8 + 4), acc); + return acc; +} + +__device__ static int32_t dev_dot_q2_16(const uint8_t *q2, const int8_t *q8, int shift) { + int32_t sum = 0; + #pragma unroll + for (uint32_t i = 0; i < 16; i += 4) { + const int32_t v = (*(const int32_t *)(q2 + i) >> shift) & 0x03030303; + sum = __dp4a(v, *(const int32_t *)(q8 + i), sum); + } + return sum; +} + +__device__ static int32_t dev_dot_iq2_pair_16(uint8_t grid0, uint32_t sign0, uint8_t grid1, uint32_t sign1, const int8_t *q8) { + int32_t sum = 0; + sum = dev_iq2_dp4a_8(hip_iq2xxs_grid[grid0], hip_ksigns_iq2xs[sign0], q8, sum); + sum = dev_iq2_dp4a_8(hip_iq2xxs_grid[grid1], hip_ksigns_iq2xs[sign1], q8 + 8, sum); + return sum; +} + +__device__ __forceinline__ static void dev_iq2_i8x8_lut( + const uint64_t *grid, + const uint8_t *signs, + uint8_t grid_idx, + uint32_t sign_idx, + int32_t *w0, + int32_t *w1) { + const uint32_t s = dev_unpack_iq2_signs(signs[sign_idx]); + const int32_t sm0 = __vcmpne4(s & 0x08040201u, 0); + const int32_t sm1 = __vcmpne4(s & 0x80402010u, 0); + const uint64_t g = grid[grid_idx]; + *w0 = __vsub4((int32_t)(uint32_t)g ^ sm0, sm0); + *w1 = __vsub4((int32_t)(uint32_t)(g >> 32) ^ sm1, sm1); +} + +__device__ static float dev_dot_iq2_xxs_q8_K_block_lut( + const hip_block_iq2_xxs *x, + const hip_block_q8_K *y, + const uint64_t *grid, + const uint8_t *signs) { + const float xd = dev_f16_to_f32(x->d); + const uint16_t *q2 = x->qs; + const int8_t *q8 = y->qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < HIP_QK_K / 32; ib32++) { + const uint32_t aux0 = (uint32_t)q2[0] | ((uint32_t)q2[1] << 16); + const uint32_t aux1 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16); + q2 += 4; + const int32_t ls = (int32_t)(2u * (aux1 >> 28) + 1u); + int32_t w[8]; + dev_iq2_i8x8_lut(grid, signs, (uint8_t)(aux0 & 0xffu), (aux1 >> 0) & 127u, &w[0], &w[1]); + dev_iq2_i8x8_lut(grid, signs, (uint8_t)((aux0 >> 8) & 0xffu), (aux1 >> 7) & 127u, &w[2], &w[3]); + dev_iq2_i8x8_lut(grid, signs, (uint8_t)((aux0 >> 16) & 0xffu), (aux1 >> 14) & 127u, &w[4], &w[5]); + dev_iq2_i8x8_lut(grid, signs, (uint8_t)((aux0 >> 24) & 0xffu), (aux1 >> 21) & 127u, &w[6], &w[7]); + int32_t sumi = 0; + sumi = __dp4a(w[0], *(const int32_t *)(q8 + ib32 * 32u + 0), sumi); + sumi = __dp4a(w[1], *(const int32_t *)(q8 + ib32 * 32u + 4), sumi); + sumi = __dp4a(w[2], *(const int32_t *)(q8 + ib32 * 32u + 8), sumi); + sumi = __dp4a(w[3], *(const int32_t *)(q8 + ib32 * 32u + 12), sumi); + sumi = __dp4a(w[4], *(const int32_t *)(q8 + ib32 * 32u + 16), sumi); + sumi = __dp4a(w[5], *(const int32_t *)(q8 + ib32 * 32u + 20), sumi); + sumi = __dp4a(w[6], *(const int32_t *)(q8 + ib32 * 32u + 24), sumi); + sumi = __dp4a(w[7], *(const int32_t *)(q8 + ib32 * 32u + 28), sumi); + bsum += sumi * ls; + } + return 0.125f * xd * y->d * (float)bsum; +} + +__device__ static float dev_dot_iq2_xxs_q8_K_block(const hip_block_iq2_xxs *x, const hip_block_q8_K *y) { + const float d = dev_f16_to_f32(x->d) * y->d; + const uint16_t *q2 = x->qs; + const int8_t *q8 = y->qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < HIP_QK_K / 32; ib32++) { + const uint32_t aux0 = (uint32_t)q2[0] | ((uint32_t)q2[1] << 16); + const uint32_t aux1 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16); + q2 += 4; + const uint32_t ls = 2u * (aux1 >> 28) + 1u; + const uint8_t a0 = (uint8_t)(aux0 & 0xffu); + const uint8_t a1 = (uint8_t)((aux0 >> 8) & 0xffu); + const uint8_t a2 = (uint8_t)((aux0 >> 16) & 0xffu); + const uint8_t a3 = (uint8_t)((aux0 >> 24) & 0xffu); + int32_t sumi = 0; + sumi += dev_dot_iq2_pair_16(a0, (aux1 >> 0) & 127u, a1, (aux1 >> 7) & 127u, q8); + q8 += 16; + sumi += dev_dot_iq2_pair_16(a2, (aux1 >> 14) & 127u, a3, (aux1 >> 21) & 127u, q8); + q8 += 16; + bsum += sumi * (int32_t)ls; + } + return 0.125f * d * (float)bsum; +} + +__device__ static void dev_dot_iq2_xxs_q8_K_block8_deq_lut( + const hip_block_iq2_xxs *x, + const hip_block_q8_K *y0, + const hip_block_q8_K *y1, + const hip_block_q8_K *y2, + const hip_block_q8_K *y3, + const hip_block_q8_K *y4, + const hip_block_q8_K *y5, + const hip_block_q8_K *y6, + const hip_block_q8_K *y7, + uint32_t n, + float acc[8], + const uint64_t *grid, + const uint8_t *signs) { + const float xd = dev_f16_to_f32(x->d); + const uint16_t *q2 = x->qs; + int32_t bsum[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const int8_t *q8[8] = { + y0 ? y0->qs : NULL, y1 ? y1->qs : NULL, y2 ? y2->qs : NULL, y3 ? y3->qs : NULL, + y4 ? y4->qs : NULL, y5 ? y5->qs : NULL, y6 ? y6->qs : NULL, y7 ? y7->qs : NULL, + }; + for (int ib32 = 0; ib32 < HIP_QK_K / 32; ib32++) { + const uint32_t aux0 = (uint32_t)q2[0] | ((uint32_t)q2[1] << 16); + const uint32_t aux1 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16); + q2 += 4; + const int32_t ls = (int32_t)(2u * (aux1 >> 28) + 1u); + int32_t w[8]; + dev_iq2_i8x8_lut(grid, signs, (uint8_t)(aux0 & 0xffu), (aux1 >> 0) & 127u, &w[0], &w[1]); + dev_iq2_i8x8_lut(grid, signs, (uint8_t)((aux0 >> 8) & 0xffu), (aux1 >> 7) & 127u, &w[2], &w[3]); + dev_iq2_i8x8_lut(grid, signs, (uint8_t)((aux0 >> 16) & 0xffu), (aux1 >> 14) & 127u, &w[4], &w[5]); + dev_iq2_i8x8_lut(grid, signs, (uint8_t)((aux0 >> 24) & 0xffu), (aux1 >> 21) & 127u, &w[6], &w[7]); + for (uint32_t p = 0; p < n; p++) { + const int8_t *q = q8[p] + ib32 * 32; + int32_t sumi = 0; + sumi = __dp4a(w[0], *(const int32_t *)(q + 0), sumi); + sumi = __dp4a(w[1], *(const int32_t *)(q + 4), sumi); + sumi = __dp4a(w[2], *(const int32_t *)(q + 8), sumi); + sumi = __dp4a(w[3], *(const int32_t *)(q + 12), sumi); + sumi = __dp4a(w[4], *(const int32_t *)(q + 16), sumi); + sumi = __dp4a(w[5], *(const int32_t *)(q + 20), sumi); + sumi = __dp4a(w[6], *(const int32_t *)(q + 24), sumi); + sumi = __dp4a(w[7], *(const int32_t *)(q + 28), sumi); + bsum[p] += sumi * ls; + } + } + const hip_block_q8_K *ys[8] = { y0, y1, y2, y3, y4, y5, y6, y7 }; + for (uint32_t p = 0; p < n; p++) acc[p] += 0.125f * xd * ys[p]->d * (float)bsum[p]; +} + +__device__ static void dev_dot_iq2_xxs_q8_K_block4( + const hip_block_iq2_xxs *x, + const hip_block_q8_K *y0, + const hip_block_q8_K *y1, + const hip_block_q8_K *y2, + const hip_block_q8_K *y3, + uint32_t n, + float acc[4]) { + const float xd = dev_f16_to_f32(x->d); + const uint16_t *q2 = x->qs; + int32_t bsum[4] = {0, 0, 0, 0}; + const int8_t *q8[4] = { + y0 ? y0->qs : NULL, + y1 ? y1->qs : NULL, + y2 ? y2->qs : NULL, + y3 ? y3->qs : NULL, + }; + for (int ib32 = 0; ib32 < HIP_QK_K / 32; ib32++) { + const uint32_t aux0 = (uint32_t)q2[0] | ((uint32_t)q2[1] << 16); + const uint32_t aux1 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16); + q2 += 4; + const uint32_t ls = 2u * (aux1 >> 28) + 1u; + const uint8_t a0 = (uint8_t)(aux0 & 0xffu); + const uint8_t a1 = (uint8_t)((aux0 >> 8) & 0xffu); + const uint8_t a2 = (uint8_t)((aux0 >> 16) & 0xffu); + const uint8_t a3 = (uint8_t)((aux0 >> 24) & 0xffu); + for (uint32_t p = 0; p < n; p++) { + int32_t sumi = 0; + sumi += dev_dot_iq2_pair_16(a0, (aux1 >> 0) & 127u, a1, (aux1 >> 7) & 127u, q8[p] + ib32 * 32); + sumi += dev_dot_iq2_pair_16(a2, (aux1 >> 14) & 127u, a3, (aux1 >> 21) & 127u, q8[p] + ib32 * 32 + 16); + bsum[p] += sumi * (int32_t)ls; + } + } + const hip_block_q8_K *ys[4] = { y0, y1, y2, y3 }; + for (uint32_t p = 0; p < n; p++) acc[p] += 0.125f * xd * ys[p]->d * (float)bsum[p]; +} + +__device__ static DS4_HIP_UNUSED void dev_dot_iq2_xxs_q8_K_block8( + const hip_block_iq2_xxs *x, + const hip_block_q8_K *y0, + const hip_block_q8_K *y1, + const hip_block_q8_K *y2, + const hip_block_q8_K *y3, + const hip_block_q8_K *y4, + const hip_block_q8_K *y5, + const hip_block_q8_K *y6, + const hip_block_q8_K *y7, + uint32_t n, + float acc[8]) { + const float xd = dev_f16_to_f32(x->d); + const uint16_t *q2 = x->qs; + int32_t bsum[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const int8_t *q8[8] = { + y0 ? y0->qs : NULL, y1 ? y1->qs : NULL, y2 ? y2->qs : NULL, y3 ? y3->qs : NULL, + y4 ? y4->qs : NULL, y5 ? y5->qs : NULL, y6 ? y6->qs : NULL, y7 ? y7->qs : NULL, + }; + for (int ib32 = 0; ib32 < HIP_QK_K / 32; ib32++) { + const uint32_t aux0 = (uint32_t)q2[0] | ((uint32_t)q2[1] << 16); + const uint32_t aux1 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16); + q2 += 4; + const uint32_t ls = 2u * (aux1 >> 28) + 1u; + const uint8_t a0 = (uint8_t)(aux0 & 0xffu); + const uint8_t a1 = (uint8_t)((aux0 >> 8) & 0xffu); + const uint8_t a2 = (uint8_t)((aux0 >> 16) & 0xffu); + const uint8_t a3 = (uint8_t)((aux0 >> 24) & 0xffu); + for (uint32_t p = 0; p < n; p++) { + int32_t sumi = 0; + sumi += dev_dot_iq2_pair_16(a0, (aux1 >> 0) & 127u, a1, (aux1 >> 7) & 127u, q8[p] + ib32 * 32); + sumi += dev_dot_iq2_pair_16(a2, (aux1 >> 14) & 127u, a3, (aux1 >> 21) & 127u, q8[p] + ib32 * 32 + 16); + bsum[p] += sumi * (int32_t)ls; + } + } + const hip_block_q8_K *ys[8] = { y0, y1, y2, y3, y4, y5, y6, y7 }; + for (uint32_t p = 0; p < n; p++) acc[p] += 0.125f * xd * ys[p]->d * (float)bsum[p]; +} + +__device__ static float dev_dot_q2_K_q8_K_block(const hip_block_q2_K *x, const hip_block_q8_K *y) { + const uint8_t *q2 = x->qs; + const int8_t *q8 = y->qs; + const uint8_t *sc = x->scales; + int summs = 0; + #pragma unroll + for (int j = 0; j < 16; j++) summs += y->bsums[j] * (sc[j] >> 4); + const float dall = y->d * dev_f16_to_f32(x->d); + const float dmin = y->d * dev_f16_to_f32(x->dmin); + int isum = 0; + int is = 0; + #pragma unroll + for (int k = 0; k < HIP_QK_K / 128; k++) { + int shift = 0; + #pragma unroll + for (int j = 0; j < 4; j++) { + int d0 = sc[is++] & 0x0f; + isum += d0 * dev_dot_q2_16(q2, q8, shift); + int d1 = sc[is++] & 0x0f; + isum += d1 * dev_dot_q2_16(q2 + 16, q8 + 16, shift); + shift += 2; + q8 += 32; + } + q2 += 32; + } + return dall * (float)isum - dmin * (float)summs; +} + +__device__ static void dev_dot_q2_K_q8_K_block4( + const hip_block_q2_K *x, + const hip_block_q8_K *y0, + const hip_block_q8_K *y1, + const hip_block_q8_K *y2, + const hip_block_q8_K *y3, + uint32_t n, + float acc[4]) { + const uint8_t *sc = x->scales; + const float xd = dev_f16_to_f32(x->d); + const float xmin = dev_f16_to_f32(x->dmin); + const hip_block_q8_K *ys[4] = { y0, y1, y2, y3 }; + int isum[4] = {0, 0, 0, 0}; + int summs[4] = {0, 0, 0, 0}; + for (uint32_t p = 0; p < n; p++) { + for (int j = 0; j < 16; j++) summs[p] += ys[p]->bsums[j] * (sc[j] >> 4); + } + for (uint32_t p = 0; p < n; p++) { + const uint8_t *q2 = x->qs; + const int8_t *q8 = ys[p]->qs; + int is = 0; + for (int k = 0; k < HIP_QK_K / 128; k++) { + int shift = 0; + for (int j = 0; j < 4; j++) { + int d = sc[is++] & 0x0f; + isum[p] += d * dev_dot_q2_16(q2, q8, shift); + d = sc[is++] & 0x0f; + isum[p] += d * dev_dot_q2_16(q2 + 16, q8 + 16, shift); + shift += 2; + q8 += 32; + } + q2 += 32; + } + } + for (uint32_t p = 0; p < n; p++) { + const float yd = ys[p]->d; + acc[p] += yd * xd * (float)isum[p] - yd * xmin * (float)summs[p]; + } +} + +__device__ static void dev_dot_q2_K_q8_K_block8( + const hip_block_q2_K *x, + const hip_block_q8_K *y0, + const hip_block_q8_K *y1, + const hip_block_q8_K *y2, + const hip_block_q8_K *y3, + const hip_block_q8_K *y4, + const hip_block_q8_K *y5, + const hip_block_q8_K *y6, + const hip_block_q8_K *y7, + uint32_t n, + float acc[8]) { + const uint8_t *sc = x->scales; + const float xd = dev_f16_to_f32(x->d); + const float xmin = dev_f16_to_f32(x->dmin); + const hip_block_q8_K *ys[8] = { y0, y1, y2, y3, y4, y5, y6, y7 }; + int isum[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + int summs[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + for (uint32_t p = 0; p < n; p++) { + for (int j = 0; j < 16; j++) summs[p] += ys[p]->bsums[j] * (sc[j] >> 4); + } + for (uint32_t p = 0; p < n; p++) { + const uint8_t *q2 = x->qs; + const int8_t *q8 = ys[p]->qs; + int is = 0; + for (int k = 0; k < HIP_QK_K / 128; k++) { + int shift = 0; + for (int j = 0; j < 4; j++) { + int d = sc[is++] & 0x0f; + isum[p] += d * dev_dot_q2_16(q2, q8, shift); + d = sc[is++] & 0x0f; + isum[p] += d * dev_dot_q2_16(q2 + 16, q8 + 16, shift); + shift += 2; + q8 += 32; + } + q2 += 32; + } + } + for (uint32_t p = 0; p < n; p++) { + const float yd = ys[p]->d; + acc[p] += yd * xd * (float)isum[p] - yd * xmin * (float)summs[p]; + } +} + +__device__ static void dev_dot_q2_K_q8_K_block16( + const hip_block_q2_K *x, + const hip_block_q8_K *y0, + const hip_block_q8_K *y1, + const hip_block_q8_K *y2, + const hip_block_q8_K *y3, + const hip_block_q8_K *y4, + const hip_block_q8_K *y5, + const hip_block_q8_K *y6, + const hip_block_q8_K *y7, + const hip_block_q8_K *y8, + const hip_block_q8_K *y9, + const hip_block_q8_K *y10, + const hip_block_q8_K *y11, + const hip_block_q8_K *y12, + const hip_block_q8_K *y13, + const hip_block_q8_K *y14, + const hip_block_q8_K *y15, + uint32_t n, + float acc[16]) { + const uint8_t *sc = x->scales; + const float xd = dev_f16_to_f32(x->d); + const float xmin = dev_f16_to_f32(x->dmin); + const hip_block_q8_K *ys[16] = { + y0, y1, y2, y3, y4, y5, y6, y7, + y8, y9, y10, y11, y12, y13, y14, y15, + }; + int isum[16] = {0}; + int summs[16] = {0}; + for (uint32_t p = 0; p < n; p++) { + #pragma unroll + for (int j = 0; j < 16; j++) summs[p] += ys[p]->bsums[j] * (sc[j] >> 4); + } + + for (uint32_t p = 0; p < n; p++) { + const uint8_t *q2 = x->qs; + const int8_t *q8 = ys[p]->qs; + int is = 0; + for (int k = 0; k < HIP_QK_K / 128; k++) { + int shift = 0; + for (int j = 0; j < 4; j++) { + int d = sc[is++] & 0x0f; + isum[p] += d * dev_dot_q2_16(q2, q8, shift); + d = sc[is++] & 0x0f; + isum[p] += d * dev_dot_q2_16(q2 + 16, q8 + 16, shift); + shift += 2; + q8 += 32; + } + q2 += 32; + } + } + for (uint32_t p = 0; p < n; p++) { + const float yd = ys[p]->d; + acc[p] += yd * xd * (float)isum[p] - yd * xmin * (float)summs[p]; + } +} + +__device__ static float half_warp_sum_f32(float v, uint32_t lane16) { + uint32_t mask = 0xffffu << (threadIdx.x & 16u); + for (int offset = 8; offset > 0; offset >>= 1) { + v += __shfl_down(v, offset, 16); + } + (void)lane16; + return v; +} + +__device__ static float quarter_warp_sum_f32(float v, uint32_t lane8) { + for (int offset = 4; offset > 0; offset >>= 1) { + v += __shfl_down(v, offset, 8); + } + (void)lane8; + return v; +} + +__global__ static void q8_K_quantize_kernel(hip_block_q8_K *out, const float *x, uint32_t in_dim, uint32_t n_rows) { + uint32_t b = blockIdx.x; + uint32_t row = blockIdx.y; + if (row >= n_rows || b >= in_dim / HIP_QK_K) return; + const float *xr = x + (uint64_t)row * in_dim + (uint64_t)b * HIP_QK_K; + hip_block_q8_K *yb = out + (uint64_t)row * (in_dim / HIP_QK_K) + b; + __shared__ float abs_part[256]; + __shared__ float val_part[256]; + __shared__ float maxv_s; + __shared__ float iscale_s; + uint32_t tid = threadIdx.x; + float v = tid < HIP_QK_K ? xr[tid] : 0.0f; + abs_part[tid] = tid < HIP_QK_K ? fabsf(v) : 0.0f; + val_part[tid] = v; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride && abs_part[tid + stride] > abs_part[tid]) { + abs_part[tid] = abs_part[tid + stride]; + val_part[tid] = val_part[tid + stride]; + } + __syncthreads(); + } + float amax = abs_part[0]; + if (amax == 0.0f) { + if (tid == 0) yb->d = 0.0f; + if (tid < HIP_QK_K) yb->qs[tid] = 0; + if (tid < HIP_QK_K / 16) yb->bsums[tid] = 0; + return; + } + if (tid == 0) { + maxv_s = val_part[0]; + iscale_s = -127.0f / maxv_s; + } + __syncthreads(); + if (tid < HIP_QK_K) { + int qv = (int)lrintf(iscale_s * xr[tid]); + if (qv > 127) qv = 127; + if (qv < -128) qv = -128; + yb->qs[tid] = (int8_t)qv; + } + __syncthreads(); + if (tid < HIP_QK_K / 16) { + int sum = 0; + for (int i = 0; i < 16; i++) sum += yb->qs[tid * 16 + i]; + yb->bsums[tid] = (int16_t)sum; + } + if (tid == 0) yb->d = 1.0f / iscale_s; +} + +__global__ static DS4_HIP_UNUSED void moe_gate_up_mid_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t row = blockIdx.x; + uint32_t pair = blockIdx.y; + if (row >= expert_mid_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = threadIdx.x; b < xq_blocks; b += blockDim.x) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + __shared__ float partial_gate[256]; + __shared__ float partial_up[256]; + partial_gate[threadIdx.x] = gate; + partial_up[threadIdx.x] = up; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + partial_gate[threadIdx.x] += partial_gate[threadIdx.x + stride]; + partial_up[threadIdx.x] += partial_up[threadIdx.x + stride]; + } + __syncthreads(); + } + if (threadIdx.x == 0) { + gate = partial_gate[0]; + up = partial_up[0]; + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } +} + +__global__ static DS4_HIP_UNUSED void moe_gate_up_mid_warp8_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t lane = threadIdx.x & 31u; + uint32_t warp = threadIdx.x >> 5u; + uint32_t row = blockIdx.x * 8u + warp; + uint32_t pair = blockIdx.y; + if (row >= expert_mid_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 32u) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + gate = warp_sum_f32(gate); + up = warp_sum_f32(up); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } +} + +__global__ static DS4_HIP_UNUSED void moe_gate_up_mid_hwarp16_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t lane = threadIdx.x & 15u; + uint32_t row = blockIdx.x * 16u + (threadIdx.x >> 4u); + uint32_t pair = blockIdx.y; + if (row >= expert_mid_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 16u) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + gate = half_warp_sum_f32(gate, lane); + up = half_warp_sum_f32(up, lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } +} + +__global__ static void moe_gate_up_mid_qwarp32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t pair = blockIdx.y; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + for (uint32_t rr = 0; rr < 4u; rr++) { + uint32_t row = blockIdx.x * 128u + row_lane + rr * 32u; + if (row >= expert_mid_dim) continue; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + gate = quarter_warp_sum_f32(gate, lane); + up = quarter_warp_sum_f32(up, lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } + } +} + +__global__ static void moe_gate_up_mid_decode_lut_qwarp32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + uint32_t write_aux, + float clamp) { + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t pair = blockIdx.y; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + __shared__ hip_block_q8_K sxq[16]; + __shared__ uint64_t s_iq2_grid[256]; + __shared__ uint8_t s_iq2_signs[128]; + if (xq_blocks <= 16u) { + for (uint32_t i = threadIdx.x; i < xq_blocks; i += blockDim.x) sxq[i] = xqb[i]; + for (uint32_t i = threadIdx.x; i < 256u; i += blockDim.x) s_iq2_grid[i] = hip_iq2xxs_grid[i]; + for (uint32_t i = threadIdx.x; i < 128u; i += blockDim.x) s_iq2_signs[i] = hip_ksigns_iq2xs[i]; + __syncthreads(); + xqb = sxq; + } + for (uint32_t rr = 0; rr < 4u; rr++) { + uint32_t row = blockIdx.x * 128u + row_lane + rr * 32u; + if (row >= expert_mid_dim) continue; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + gate += dev_dot_iq2_xxs_q8_K_block_lut(gr + b, xqb + b, s_iq2_grid, s_iq2_signs); + up += dev_dot_iq2_xxs_q8_K_block_lut(ur + b, xqb + b, s_iq2_grid, s_iq2_signs); + } + gate = quarter_warp_sum_f32(gate, lane); + up = quarter_warp_sum_f32(up, lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + if (write_aux) { + gate_out[off] = gate; + up_out[off] = up; + } + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } + } +} + +__global__ static void moe_count_sorted_pairs_kernel( + uint32_t *counts, + const int32_t *selected, + uint32_t pair_count) { + uint32_t pair = (uint32_t)((uint64_t)blockIdx.x * blockDim.x + threadIdx.x); + if (pair >= pair_count) return; + int32_t expert_i = selected[pair]; + if (expert_i < 0) expert_i = 0; + atomicAdd(counts + (uint32_t)expert_i, 1u); +} + +__global__ static void moe_prefix_sorted_pairs_kernel( + uint32_t *offsets, + uint32_t *cursors, + const uint32_t *counts) { + if (threadIdx.x == 0) { + uint32_t sum = 0; + for (uint32_t e = 0; e < 256u; e++) { + offsets[e] = sum; + cursors[e] = sum; + sum += counts[e]; + } + offsets[256] = sum; + } +} + +__global__ static void moe_scatter_sorted_pairs_kernel( + uint32_t *sorted_pairs, + uint32_t *cursors, + const int32_t *selected, + uint32_t pair_count) { + uint32_t pair = (uint32_t)((uint64_t)blockIdx.x * blockDim.x + threadIdx.x); + if (pair >= pair_count) return; + int32_t expert_i = selected[pair]; + if (expert_i < 0) expert_i = 0; + uint32_t pos = atomicAdd(cursors + (uint32_t)expert_i, 1u); + sorted_pairs[pos] = pair; +} + +__global__ static void moe_build_expert_tile_offsets_kernel( + uint32_t *tile_offsets, + uint32_t *tile_total, + const uint32_t *counts, + uint32_t block_m) { + if (threadIdx.x == 0) { + uint32_t sum = 0; + for (uint32_t e = 0; e < 256u; e++) { + tile_offsets[e] = sum; + sum += (counts[e] + block_m - 1u) / block_m; + } + tile_offsets[256] = sum; + *tile_total = sum; + } +} + +__global__ static void moe_build_expert_tiles_kernel( + uint32_t *tile_experts, + uint32_t *tile_starts, + const uint32_t *tile_offsets, + const uint32_t *counts, + uint32_t block_m) { + uint32_t e = threadIdx.x; + if (e >= 256u) return; + uint32_t ntiles = (counts[e] + block_m - 1u) / block_m; + uint32_t off = tile_offsets[e]; + for (uint32_t t = 0; t < ntiles; t++) { + tile_experts[off + t] = e; + tile_starts[off + t] = t * block_m; + } +} + +__global__ static void moe_gate_up_mid_sorted_qwarp32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t pair = sorted_pairs[blockIdx.y]; + if (row >= expert_mid_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + gate = quarter_warp_sum_f32(gate, lane); + up = quarter_warp_sum_f32(up, lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } +} + +__global__ static DS4_HIP_UNUSED void moe_gate_up_mid_expert_tile8_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t group = threadIdx.x >> 3u; + uint32_t lane = threadIdx.x & 7u; + uint32_t pair_slot = group & 7u; + uint32_t row_lane = group >> 3u; + uint32_t expert = tile_experts[tile]; + uint32_t local_pair = tile_starts[tile] + pair_slot; + if (local_pair >= counts[expert]) return; + uint32_t sorted_idx = offsets[expert] + local_pair; + uint32_t pair = sorted_pairs[sorted_idx]; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + + for (uint32_t rr = 0; rr < 2u; rr++) { + uint32_t row = blockIdx.x * 8u + row_lane + rr * 4u; + if (row >= expert_mid_dim) continue; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + gate = quarter_warp_sum_f32(gate, lane); + up = quarter_warp_sum_f32(up, lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } + } +} + +__global__ static void moe_gate_up_mid_expert_tile4_row32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + uint32_t write_aux, + float clamp) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t expert = tile_experts[tile]; + uint32_t local_start = tile_starts[tile]; + __shared__ hip_block_q8_K sxq[4][16]; + uint32_t pair[4] = {0, 0, 0, 0}; + uint32_t tok[4] = {0, 0, 0, 0}; + uint32_t slot[4] = {0, 0, 0, 0}; + const hip_block_q8_K *xqb[4] = {NULL, NULL, NULL, NULL}; + uint32_t np = 0; + for (; np < 4u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + tok[np] = pair[np] / n_expert; + slot[np] = pair[np] - tok[np] * n_expert; + xqb[np] = xq + (uint64_t)tok[np] * xq_blocks; + } + if (xq_blocks <= 16u) { + for (uint32_t i = threadIdx.x; i < np * xq_blocks; i += blockDim.x) { + uint32_t p = i / xq_blocks; + uint32_t b = i - p * xq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + if (row >= expert_mid_dim) return; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + float up[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + dev_dot_iq2_xxs_q8_K_block4(gr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, np, gate); + dev_dot_iq2_xxs_q8_K_block4(ur + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, np, up); + } + for (uint32_t p = 0; p < np; p++) { + gate[p] = quarter_warp_sum_f32(gate[p], lane); + up[p] = quarter_warp_sum_f32(up[p], lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate[p] > clamp) gate[p] = clamp; + if (up[p] > clamp) up[p] = clamp; + if (up[p] < -clamp) up[p] = -clamp; + } + const uint64_t off = (uint64_t)pair[p] * expert_mid_dim + row; + if (write_aux) { + gate_out[off] = gate[p]; + up_out[off] = up[p]; + } + mid_out[off] = (gate[p] / (1.0f + expf(-gate[p]))) * up[p] * weights[(uint64_t)tok[p] * n_expert + slot[p]]; + } + } +} + +__global__ static void moe_gate_up_mid_expert_tile8_row32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + uint32_t write_aux, + float clamp) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t expert = tile_experts[tile]; + uint32_t local_start = tile_starts[tile]; + __shared__ hip_block_q8_K sxq[8][16]; + __shared__ uint64_t s_iq2_grid[256]; + __shared__ uint8_t s_iq2_signs[128]; + uint32_t pair[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + uint32_t tok[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + uint32_t slot[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const hip_block_q8_K *xqb[8] = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}; + uint32_t np = 0; + for (; np < 8u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + tok[np] = pair[np] / n_expert; + slot[np] = pair[np] - tok[np] * n_expert; + xqb[np] = xq + (uint64_t)tok[np] * xq_blocks; + } + if (xq_blocks <= 16u) { + for (uint32_t i = threadIdx.x; i < np * xq_blocks; i += blockDim.x) { + uint32_t p = i / xq_blocks; + uint32_t b = i - p * xq_blocks; + sxq[p][b] = xqb[p][b]; + } + for (uint32_t i = threadIdx.x; i < 256u; i += blockDim.x) s_iq2_grid[i] = hip_iq2xxs_grid[i]; + for (uint32_t i = threadIdx.x; i < 128u; i += blockDim.x) s_iq2_signs[i] = hip_ksigns_iq2xs[i]; + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + if (row >= expert_mid_dim) return; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + float up[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + dev_dot_iq2_xxs_q8_K_block8_deq_lut(gr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, gate, + s_iq2_grid, s_iq2_signs); + dev_dot_iq2_xxs_q8_K_block8_deq_lut(ur + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, up, + s_iq2_grid, s_iq2_signs); + } + for (uint32_t p = 0; p < np; p++) { + gate[p] = quarter_warp_sum_f32(gate[p], lane); + up[p] = quarter_warp_sum_f32(up[p], lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate[p] > clamp) gate[p] = clamp; + if (up[p] > clamp) up[p] = clamp; + if (up[p] < -clamp) up[p] = -clamp; + } + const uint64_t off = (uint64_t)pair[p] * expert_mid_dim + row; + if (write_aux) { + gate_out[off] = gate[p]; + up_out[off] = up[p]; + } + mid_out[off] = (gate[p] / (1.0f + expf(-gate[p]))) * up[p] * weights[(uint64_t)tok[p] * n_expert + slot[p]]; + } + } +} + +__global__ static void moe_gate_up_mid_expert_tile8_row2048_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + uint32_t write_aux, + float clamp) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t expert = tile_experts[tile]; + uint32_t local_start = tile_starts[tile]; + __shared__ hip_block_q8_K sxq[8][16]; + __shared__ uint64_t s_iq2_grid[256]; + __shared__ uint8_t s_iq2_signs[128]; + uint32_t pair[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + uint32_t tok[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + uint32_t slot[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const hip_block_q8_K *xqb[8] = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}; + uint32_t np = 0; + for (; np < 8u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + tok[np] = pair[np] / n_expert; + slot[np] = pair[np] - tok[np] * n_expert; + xqb[np] = xq + (uint64_t)tok[np] * xq_blocks; + } + if (xq_blocks <= 16u) { + for (uint32_t i = threadIdx.x; i < np * xq_blocks; i += blockDim.x) { + uint32_t p = i / xq_blocks; + uint32_t b = i - p * xq_blocks; + sxq[p][b] = xqb[p][b]; + } + for (uint32_t i = threadIdx.x; i < 256u; i += blockDim.x) s_iq2_grid[i] = hip_iq2xxs_grid[i]; + for (uint32_t i = threadIdx.x; i < 128u; i += blockDim.x) s_iq2_signs[i] = hip_ksigns_iq2xs[i]; + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + for (uint32_t rr = 0; rr < 64u; rr++) { + uint32_t row = blockIdx.x * 2048u + row_lane + rr * 32u; + if (row >= expert_mid_dim) continue; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + float up[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + dev_dot_iq2_xxs_q8_K_block8_deq_lut(gr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, gate, + s_iq2_grid, s_iq2_signs); + dev_dot_iq2_xxs_q8_K_block8_deq_lut(ur + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, up, + s_iq2_grid, s_iq2_signs); + } + for (uint32_t p = 0; p < np; p++) { + gate[p] = quarter_warp_sum_f32(gate[p], lane); + up[p] = quarter_warp_sum_f32(up[p], lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate[p] > clamp) gate[p] = clamp; + if (up[p] > clamp) up[p] = clamp; + if (up[p] < -clamp) up[p] = -clamp; + } + const uint64_t off = (uint64_t)pair[p] * expert_mid_dim + row; + if (write_aux) { + gate_out[off] = gate[p]; + up_out[off] = up[p]; + } + mid_out[off] = (gate[p] / (1.0f + expf(-gate[p]))) * up[p] * weights[(uint64_t)tok[p] * n_expert + slot[p]]; + } + } + } +} + +template +__global__ static void moe_gate_up_mid_expert_tile8_rowspan_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + uint32_t write_aux, + float clamp) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t expert = tile_experts[tile]; + uint32_t local_start = tile_starts[tile]; + __shared__ hip_block_q8_K sxq[8][16]; + __shared__ uint64_t s_iq2_grid[256]; + __shared__ uint8_t s_iq2_signs[128]; + uint32_t pair[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + uint32_t tok[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + uint32_t slot[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const hip_block_q8_K *xqb[8] = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}; + uint32_t np = 0; + for (; np < 8u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + tok[np] = pair[np] / n_expert; + slot[np] = pair[np] - tok[np] * n_expert; + xqb[np] = xq + (uint64_t)tok[np] * xq_blocks; + } + if (xq_blocks <= 16u) { + for (uint32_t i = threadIdx.x; i < np * xq_blocks; i += blockDim.x) { + uint32_t p = i / xq_blocks; + uint32_t b = i - p * xq_blocks; + sxq[p][b] = xqb[p][b]; + } + for (uint32_t i = threadIdx.x; i < 256u; i += blockDim.x) s_iq2_grid[i] = hip_iq2xxs_grid[i]; + for (uint32_t i = threadIdx.x; i < 128u; i += blockDim.x) s_iq2_signs[i] = hip_ksigns_iq2xs[i]; + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + for (uint32_t rr = 0; rr < ROW_SPAN / 32u; rr++) { + uint32_t row = blockIdx.x * ROW_SPAN + row_lane + rr * 32u; + if (row >= expert_mid_dim) continue; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + float gate[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + float up[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + dev_dot_iq2_xxs_q8_K_block8_deq_lut(gr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, gate, + s_iq2_grid, s_iq2_signs); + dev_dot_iq2_xxs_q8_K_block8_deq_lut(ur + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, up, + s_iq2_grid, s_iq2_signs); + } + for (uint32_t p = 0; p < np; p++) { + gate[p] = quarter_warp_sum_f32(gate[p], lane); + up[p] = quarter_warp_sum_f32(up[p], lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate[p] > clamp) gate[p] = clamp; + if (up[p] > clamp) up[p] = clamp; + if (up[p] < -clamp) up[p] = -clamp; + } + const uint64_t off = (uint64_t)pair[p] * expert_mid_dim + row; + if (write_aux) { + gate_out[off] = gate[p]; + up_out[off] = up[p]; + } + mid_out[off] = (gate[p] / (1.0f + expf(-gate[p]))) * up[p] * weights[(uint64_t)tok[p] * n_expert + slot[p]]; + } + } + } +} + +__global__ static void moe_gate_up_mid_sorted_p2_qwarp32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const hip_block_q8_K *xq, + const uint32_t *sorted_pairs, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t xq_blocks, + uint32_t expert_mid_dim, + uint32_t n_expert, + uint32_t pair_count, + float clamp) { + uint32_t lane = threadIdx.x & 7u; + uint32_t pair_lane = (threadIdx.x >> 3u) & 1u; + uint32_t row = blockIdx.x * 16u + (threadIdx.x >> 4u); + uint32_t sorted_idx = blockIdx.y * 2u + pair_lane; + if (row >= expert_mid_dim || sorted_idx >= pair_count) return; + uint32_t pair = sorted_pairs[sorted_idx]; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_q8_K *xqb = xq + (uint64_t)tok * xq_blocks; + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = lane; b < xq_blocks; b += 8u) { + gate += dev_dot_iq2_xxs_q8_K_block(gr + b, xqb + b); + up += dev_dot_iq2_xxs_q8_K_block(ur + b, xqb + b); + } + gate = quarter_warp_sum_f32(gate, lane); + up = quarter_warp_sum_f32(up, lane); + if (lane == 0) { + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } +} + +__global__ static DS4_HIP_UNUSED void moe_down_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert) { + uint32_t row = blockIdx.x; + uint32_t pair = blockIdx.y; + if (row >= out_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const hip_block_q8_K *xq = midq + (uint64_t)pair * midq_blocks; + float acc = 0.0f; + for (uint32_t b = threadIdx.x; b < midq_blocks; b += blockDim.x) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + __shared__ float partial[256]; + partial[threadIdx.x] = acc; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) down_out[(uint64_t)pair * out_dim + row] = partial[0]; +} + +__global__ static DS4_HIP_UNUSED void moe_down_warp8_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert) { + uint32_t lane = threadIdx.x & 31u; + uint32_t warp = threadIdx.x >> 5u; + uint32_t row = blockIdx.x * 8u + warp; + uint32_t pair = blockIdx.y; + if (row >= out_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const hip_block_q8_K *xq = midq + (uint64_t)pair * midq_blocks; + float acc = 0.0f; + for (uint32_t b = lane; b < midq_blocks; b += 32u) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + acc = warp_sum_f32(acc); + if (lane == 0) down_out[(uint64_t)pair * out_dim + row] = acc; +} + +__global__ static DS4_HIP_UNUSED void moe_down_hwarp16_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert) { + uint32_t lane = threadIdx.x & 15u; + uint32_t row = blockIdx.x * 16u + (threadIdx.x >> 4u); + uint32_t pair = blockIdx.y; + if (row >= out_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const hip_block_q8_K *xq = midq + (uint64_t)pair * midq_blocks; + float acc = 0.0f; + for (uint32_t b = lane; b < midq_blocks; b += 16u) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + acc = half_warp_sum_f32(acc, lane); + if (lane == 0) down_out[(uint64_t)pair * out_dim + row] = acc; +} + +__global__ static void moe_down_qwarp32_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert) { + // Shared memory cache for activations. Each block handles one (token, expert) pair. + // All 32 rows in the block share the same activations. + extern __shared__ hip_block_q8_K smem_xq[]; + + uint32_t pair = blockIdx.y; + const hip_block_q8_K *xq_glob = midq + (uint64_t)pair * midq_blocks; + + // Collaborative load into shared memory + for (uint32_t i = threadIdx.x; i < midq_blocks; i += blockDim.x) { + smem_xq[i] = xq_glob[i]; + } + __syncthreads(); + + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + if (row >= out_dim) return; + + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc = 0.0f; + + // Use cached activations from shared memory + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + acc += dev_dot_q2_K_q8_K_block(wr + b, smem_xq + b); + } + + acc = quarter_warp_sum_f32(acc, lane); + if (lane == 0) down_out[(uint64_t)pair * out_dim + row] = acc; +} + +__global__ static void moe_down_sum6_qwarp32_kernel( + float *out, + const char *down_base, + const hip_block_q8_K *midq, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim) { + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + if (row >= out_dim) return; + float total = 0.0f; + #pragma unroll + for (uint32_t slot = 0; slot < 6u; slot++) { + int32_t expert_i = selected[slot]; + if (expert_i < 0) expert_i = 0; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const hip_block_q8_K *xq = midq + (uint64_t)slot * midq_blocks; + float acc = 0.0f; + for (uint32_t b = lane; b < midq_blocks; b += 8u) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + acc = quarter_warp_sum_f32(acc, lane); + if (lane == 0) total += acc; + } + if (lane == 0) out[row] = total; +} + +__global__ static void moe_down_sorted_qwarp32_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert) { + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t pair = sorted_pairs[blockIdx.y]; + if (row >= out_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const hip_block_q8_K *xq = midq + (uint64_t)pair * midq_blocks; + float acc = 0.0f; + for (uint32_t b = lane; b < midq_blocks; b += 8u) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + acc = quarter_warp_sum_f32(acc, lane); + if (lane == 0) down_out[(uint64_t)pair * out_dim + row] = acc; +} + +__global__ static DS4_HIP_UNUSED void moe_down_expert_tile8_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t group = threadIdx.x >> 3u; + uint32_t lane = threadIdx.x & 7u; + uint32_t pair_slot = group & 7u; + uint32_t row_lane = group >> 3u; + uint32_t expert = tile_experts[tile]; + uint32_t local_pair = tile_starts[tile] + pair_slot; + if (local_pair >= counts[expert]) return; + uint32_t sorted_idx = offsets[expert] + local_pair; + uint32_t pair = sorted_pairs[sorted_idx]; + const hip_block_q8_K *xq = midq + (uint64_t)pair * midq_blocks; + + for (uint32_t rr = 0; rr < 2u; rr++) { + uint32_t row = blockIdx.x * 8u + row_lane + rr * 4u; + if (row >= out_dim) continue; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc = 0.0f; + for (uint32_t b = lane; b < midq_blocks; b += 8u) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + acc = quarter_warp_sum_f32(acc, lane); + if (lane == 0) down_out[(uint64_t)pair * out_dim + row] = acc; + } +} + +__global__ static void moe_down_expert_tile4_row32_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t atomic_out) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t expert = tile_experts[tile]; + uint32_t local_start = tile_starts[tile]; + __shared__ hip_block_q8_K sxq[4][8]; + uint32_t pair[4] = {0, 0, 0, 0}; + const hip_block_q8_K *xqb[4] = {NULL, NULL, NULL, NULL}; + uint32_t np = 0; + for (; np < 4u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + xqb[np] = midq + (uint64_t)pair[np] * midq_blocks; + } + if (midq_blocks <= 8u) { + for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) { + uint32_t p = i / midq_blocks; + uint32_t b = i - p * midq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + if (row >= out_dim) return; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + dev_dot_q2_K_q8_K_block4(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, np, acc); + } + for (uint32_t p = 0; p < np; p++) { + acc[p] = quarter_warp_sum_f32(acc[p], lane); + if (lane == 0) { + if (atomic_out) { + uint32_t tok = pair[p] / n_expert; + atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]); + } else { + down_out[(uint64_t)pair[p] * out_dim + row] = acc[p]; + } + } + } +} + +__global__ static void moe_down_expert_tile8_row32_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t atomic_out) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t expert = tile_experts[tile]; + uint32_t local_start = tile_starts[tile]; + __shared__ hip_block_q8_K sxq[8][8]; + uint32_t pair[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const hip_block_q8_K *xqb[8] = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}; + uint32_t np = 0; + for (; np < 8u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + xqb[np] = midq + (uint64_t)pair[np] * midq_blocks; + } + if (midq_blocks <= 8u) { + for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) { + uint32_t p = i / midq_blocks; + uint32_t b = i - p * midq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + if (row >= out_dim) return; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, acc); + } + for (uint32_t p = 0; p < np; p++) { + acc[p] = quarter_warp_sum_f32(acc[p], lane); + if (lane == 0) { + if (atomic_out) { + uint32_t tok = pair[p] / n_expert; + atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]); + } else { + down_out[(uint64_t)pair[p] * out_dim + row] = acc[p]; + } + } + } +} + +__global__ static void moe_down_expert_tile16_row32_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t atomic_out) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t local_start = tile_starts[tile]; + if (local_start & 8u) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row = blockIdx.x * 32u + (threadIdx.x >> 3u); + uint32_t expert = tile_experts[tile]; + __shared__ hip_block_q8_K sxq[16][8]; + uint32_t pair[16] = {0}; + const hip_block_q8_K *xqb[16] = {NULL}; + uint32_t np = 0; + for (; np < 16u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + xqb[np] = midq + (uint64_t)pair[np] * midq_blocks; + } + if (midq_blocks <= 8u) { + for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) { + uint32_t p = i / midq_blocks; + uint32_t b = i - p * midq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + if (row >= out_dim) return; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc[16] = {0.0f}; + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np < 8u ? np : 8u, acc); + if (np > 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[8] ? xqb[8] + b : NULL, xqb[9] ? xqb[9] + b : NULL, + xqb[10] ? xqb[10] + b : NULL, xqb[11] ? xqb[11] + b : NULL, + xqb[12] ? xqb[12] + b : NULL, xqb[13] ? xqb[13] + b : NULL, + xqb[14] ? xqb[14] + b : NULL, xqb[15] ? xqb[15] + b : NULL, np - 8u, acc + 8); + } + } + for (uint32_t p = 0; p < np; p++) { + acc[p] = quarter_warp_sum_f32(acc[p], lane); + if (lane == 0) { + if (atomic_out) { + uint32_t tok = pair[p] / n_expert; + atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]); + } else { + down_out[(uint64_t)pair[p] * out_dim + row] = acc[p]; + } + } + } +} + +__global__ static void moe_down_expert_tile16_row2048_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t atomic_out) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t local_start = tile_starts[tile]; + if (local_start & 8u) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t expert = tile_experts[tile]; + __shared__ hip_block_q8_K sxq[16][8]; + uint32_t pair[16] = {0}; + const hip_block_q8_K *xqb[16] = {NULL}; + uint32_t np = 0; + for (; np < 16u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + xqb[np] = midq + (uint64_t)pair[np] * midq_blocks; + } + if (midq_blocks <= 8u) { + for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) { + uint32_t p = i / midq_blocks; + uint32_t b = i - p * midq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + for (uint32_t rr = 0; rr < 64u; rr++) { + uint32_t row = blockIdx.x * 2048u + row_lane + rr * 32u; + if (row >= out_dim) continue; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc[16] = {0.0f}; + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np < 8u ? np : 8u, acc); + if (np > 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[8] ? xqb[8] + b : NULL, xqb[9] ? xqb[9] + b : NULL, + xqb[10] ? xqb[10] + b : NULL, xqb[11] ? xqb[11] + b : NULL, + xqb[12] ? xqb[12] + b : NULL, xqb[13] ? xqb[13] + b : NULL, + xqb[14] ? xqb[14] + b : NULL, xqb[15] ? xqb[15] + b : NULL, np - 8u, acc + 8); + } + } + for (uint32_t p = 0; p < np; p++) { + acc[p] = quarter_warp_sum_f32(acc[p], lane); + if (lane == 0) { + if (atomic_out) { + uint32_t tok = pair[p] / n_expert; + atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]); + } else { + down_out[(uint64_t)pair[p] * out_dim + row] = acc[p]; + } + } + } + } +} + +template +__global__ static void moe_down_expert_tile16_rowspan_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t atomic_out) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t local_start = tile_starts[tile]; + if (local_start & 8u) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t expert = tile_experts[tile]; + __shared__ hip_block_q8_K sxq[16][8]; + uint32_t pair[16] = {0}; + const hip_block_q8_K *xqb[16] = {NULL}; + uint32_t np = 0; + for (; np < 16u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + xqb[np] = midq + (uint64_t)pair[np] * midq_blocks; + } + if (midq_blocks <= 8u) { + for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) { + uint32_t p = i / midq_blocks; + uint32_t b = i - p * midq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + for (uint32_t rr = 0; rr < ROW_SPAN / 32u; rr++) { + uint32_t row = blockIdx.x * ROW_SPAN + row_lane + rr * 32u; + if (row >= out_dim) continue; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc[16] = {0.0f}; + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np < 8u ? np : 8u, acc); + if (np > 8u) { + dev_dot_q2_K_q8_K_block8(wr + b, xqb[8] ? xqb[8] + b : NULL, xqb[9] ? xqb[9] + b : NULL, + xqb[10] ? xqb[10] + b : NULL, xqb[11] ? xqb[11] + b : NULL, + xqb[12] ? xqb[12] + b : NULL, xqb[13] ? xqb[13] + b : NULL, + xqb[14] ? xqb[14] + b : NULL, xqb[15] ? xqb[15] + b : NULL, np - 8u, acc + 8); + } + } + for (uint32_t p = 0; p < np; p++) { + acc[p] = quarter_warp_sum_f32(acc[p], lane); + if (lane == 0) { + if (atomic_out) { + uint32_t tok = pair[p] / n_expert; + atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]); + } else { + down_out[(uint64_t)pair[p] * out_dim + row] = acc[p]; + } + } + } + } +} + +template +__global__ static void moe_down_expert_tile16_rowspan_block16_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const uint32_t *offsets, + const uint32_t *counts, + const uint32_t *tile_total, + const uint32_t *tile_experts, + const uint32_t *tile_starts, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t atomic_out) { + uint32_t tile = blockIdx.y; + if (tile >= *tile_total) return; + uint32_t local_start = tile_starts[tile]; + if (local_start & 8u) return; + uint32_t lane = threadIdx.x & 7u; + uint32_t row_lane = threadIdx.x >> 3u; + uint32_t expert = tile_experts[tile]; + __shared__ hip_block_q8_K sxq[16][8]; + uint32_t pair[16] = {0}; + const hip_block_q8_K *xqb[16] = {NULL}; + uint32_t np = 0; + for (; np < 16u; np++) { + uint32_t local_pair = local_start + np; + if (local_pair >= counts[expert]) break; + pair[np] = sorted_pairs[offsets[expert] + local_pair]; + xqb[np] = midq + (uint64_t)pair[np] * midq_blocks; + } + if (midq_blocks <= 8u) { + for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) { + uint32_t p = i / midq_blocks; + uint32_t b = i - p * midq_blocks; + sxq[p][b] = xqb[p][b]; + } + __syncthreads(); + for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p]; + } + for (uint32_t rr = 0; rr < ROW_SPAN / 32u; rr++) { + uint32_t row = blockIdx.x * ROW_SPAN + row_lane + rr * 32u; + if (row >= out_dim) continue; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes); + float acc[16] = {0.0f}; + for (uint32_t b = lane; b < midq_blocks; b += 8u) { + dev_dot_q2_K_q8_K_block16(wr + b, + xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL, + xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL, + xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL, + xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, + xqb[8] ? xqb[8] + b : NULL, xqb[9] ? xqb[9] + b : NULL, + xqb[10] ? xqb[10] + b : NULL, xqb[11] ? xqb[11] + b : NULL, + xqb[12] ? xqb[12] + b : NULL, xqb[13] ? xqb[13] + b : NULL, + xqb[14] ? xqb[14] + b : NULL, xqb[15] ? xqb[15] + b : NULL, + np, acc); + } + for (uint32_t p = 0; p < np; p++) { + acc[p] = quarter_warp_sum_f32(acc[p], lane); + if (lane == 0) { + if (atomic_out) { + uint32_t tok = pair[p] / n_expert; + atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]); + } else { + down_out[(uint64_t)pair[p] * out_dim + row] = acc[p]; + } + } + } + } +} + +__global__ static void moe_down_sorted_p2_qwarp32_kernel( + float *down_out, + const char *down_base, + const hip_block_q8_K *midq, + const uint32_t *sorted_pairs, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t midq_blocks, + uint32_t out_dim, + uint32_t n_expert, + uint32_t pair_count) { + uint32_t lane = threadIdx.x & 7u; + uint32_t pair_lane = (threadIdx.x >> 3u) & 1u; + uint32_t row = blockIdx.x * 16u + (threadIdx.x >> 4u); + uint32_t sorted_idx = blockIdx.y * 2u + pair_lane; + if (row >= out_dim || sorted_idx >= pair_count) return; + uint32_t pair = sorted_pairs[sorted_idx]; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const hip_block_q8_K *xq = midq + (uint64_t)pair * midq_blocks; + float acc = 0.0f; + for (uint32_t b = lane; b < midq_blocks; b += 8u) acc += dev_dot_q2_K_q8_K_block(wr + b, xq + b); + acc = quarter_warp_sum_f32(acc, lane); + if (lane == 0) down_out[(uint64_t)pair * out_dim + row] = acc; +} + +__global__ static void moe_sum_kernel(float *out, const float *down, uint32_t out_dim, uint32_t n_expert, uint32_t n_tokens) { + uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n = (uint64_t)n_tokens * out_dim; + if (gid >= n) return; + uint32_t tok = gid / out_dim; + uint32_t row = gid - (uint64_t)tok * out_dim; + float acc = 0.0f; + for (uint32_t e = 0; e < n_expert; e++) acc += down[((uint64_t)tok * n_expert + e) * out_dim + row]; + out[gid] = acc; +} + +__device__ static float dev_iq2_xxs_dot_f32(const hip_block_iq2_xxs *row, const float *x, uint32_t nb) { + float acc = 0.0f; + for (uint32_t b = 0; b < nb; b++) { + const hip_block_iq2_xxs *xb = row + b; + const float d = dev_f16_to_f32(xb->d); + const uint16_t *q2 = xb->qs; + const float *xf = x + (uint64_t)b * HIP_QK_K; + for (uint32_t ib32 = 0; ib32 < HIP_QK_K / 32; ib32++) { + const uint32_t aux_g = (uint32_t)q2[0] | ((uint32_t)q2[1] << 16); + const uint32_t aux_s = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16); + q2 += 4; + const float dl = d * (0.5f + (float)(aux_s >> 28)) * 0.25f; + const uint8_t grids[4] = { + (uint8_t)(aux_g & 0xffu), + (uint8_t)((aux_g >> 8) & 0xffu), + (uint8_t)((aux_g >> 16) & 0xffu), + (uint8_t)((aux_g >> 24) & 0xffu), + }; + for (uint32_t half = 0; half < 2; half++) { + for (uint32_t g = 0; g < 2; g++) { + const uint32_t gi = half * 2 + g; + const uint64_t grid = hip_iq2xxs_grid[grids[gi]]; + const uint8_t signs = hip_ksigns_iq2xs[(aux_s >> (14u * half + 7u * g)) & 127u]; + for (uint32_t i = 0; i < 8; i++) { + float w = (float)((grid >> (8u * i)) & 0xffu); + if (signs & (1u << i)) w = -w; + acc += dl * w * xf[ib32 * 32u + half * 16u + g * 8u + i]; + } + } + } + } + } + return acc; +} + +__device__ static float dev_q2_K_dot_f32(const hip_block_q2_K *row, const float *x, uint32_t nb) { + float acc = 0.0f; + for (uint32_t b = 0; b < nb; b++) { + const hip_block_q2_K *xb = row + b; + const float d = dev_f16_to_f32(xb->d); + const float dmin = dev_f16_to_f32(xb->dmin); + for (uint32_t il = 0; il < 16; il++) { + const uint32_t chunk = il / 8u; + const uint32_t pair = il & 1u; + const uint32_t shift = ((il / 2u) & 3u) * 2u; + const uint8_t sc = xb->scales[il]; + const float dl = d * (float)(sc & 0x0fu); + const float ml = dmin * (float)(sc >> 4); + const uint8_t *q = xb->qs + 32u * chunk + 16u * pair; + const float *xf = x + (uint64_t)b * HIP_QK_K + chunk * 128u + ((il % 8u) / 2u) * 32u + pair * 16u; + for (uint32_t i = 0; i < 16; i++) { + const float w = dl * (float)((q[i] >> shift) & 3u) - ml; + acc += w * xf[i]; + } + } + } + return acc; +} + +__global__ static void moe_gate_up_mid_f32_kernel( + float *gate_out, + float *up_out, + float *mid_out, + const char *gate_base, + const char *up_base, + const float *x, + const int32_t *selected, + const float *weights, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint32_t expert_in_dim, + uint32_t expert_mid_dim, + uint32_t n_expert, + float clamp) { + uint32_t row = blockIdx.x; + uint32_t pair = blockIdx.y; + if (row >= expert_mid_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + uint32_t expert = (uint32_t)expert_i; + const uint32_t nb = expert_in_dim / HIP_QK_K; + const hip_block_iq2_xxs *gr = (const hip_block_iq2_xxs *)(gate_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const hip_block_iq2_xxs *ur = (const hip_block_iq2_xxs *)(up_base + (uint64_t)expert * gate_expert_bytes + (uint64_t)row * gate_row_bytes); + const float *xr = x + (uint64_t)tok * expert_in_dim; + float gate = 0.0f; + float up = 0.0f; + for (uint32_t b = threadIdx.x; b < nb; b += blockDim.x) { + gate += dev_iq2_xxs_dot_f32(gr + b, xr + (uint64_t)b * HIP_QK_K, 1); + up += dev_iq2_xxs_dot_f32(ur + b, xr + (uint64_t)b * HIP_QK_K, 1); + } + __shared__ float partial_gate[256]; + __shared__ float partial_up[256]; + partial_gate[threadIdx.x] = gate; + partial_up[threadIdx.x] = up; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + partial_gate[threadIdx.x] += partial_gate[threadIdx.x + stride]; + partial_up[threadIdx.x] += partial_up[threadIdx.x + stride]; + } + __syncthreads(); + } + if (threadIdx.x == 0) { + gate = partial_gate[0]; + up = partial_up[0]; + if (clamp > 1.0e-6f) { + if (gate > clamp) gate = clamp; + if (up > clamp) up = clamp; + if (up < -clamp) up = -clamp; + } + const uint64_t off = (uint64_t)pair * expert_mid_dim + row; + gate_out[off] = gate; + up_out[off] = up; + mid_out[off] = (gate / (1.0f + expf(-gate))) * up * weights[(uint64_t)tok * n_expert + slot]; + } +} + +__global__ static void moe_down_f32_kernel( + float *down_out, + const char *down_base, + const float *mid, + const int32_t *selected, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t expert_mid_dim, + uint32_t out_dim, + uint32_t n_expert) { + uint32_t row = blockIdx.x; + uint32_t pair = blockIdx.y; + if (row >= out_dim) return; + uint32_t tok = pair / n_expert; + uint32_t slot = pair - tok * n_expert; + int32_t expert_i = selected[(uint64_t)tok * n_expert + slot]; + if (expert_i < 0) expert_i = 0; + const uint32_t nb = expert_mid_dim / HIP_QK_K; + const hip_block_q2_K *wr = (const hip_block_q2_K *)(down_base + (uint64_t)(uint32_t)expert_i * down_expert_bytes + (uint64_t)row * down_row_bytes); + const float *xr = mid + (uint64_t)pair * expert_mid_dim; + float acc = 0.0f; + for (uint32_t b = threadIdx.x; b < nb; b += blockDim.x) acc += dev_q2_K_dot_f32(wr + b, xr + (uint64_t)b * HIP_QK_K, 1); + __shared__ float partial[256]; + partial[threadIdx.x] = acc; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) down_out[(uint64_t)pair * out_dim + row] = partial[0]; +} + +static int routed_moe_launch( + ds4_gpu_tensor *out, + ds4_gpu_tensor *gate, + ds4_gpu_tensor *up, + ds4_gpu_tensor *mid, + ds4_gpu_tensor *down, + const void *model_map, + uint64_t model_size, + uint64_t gate_offset, + uint64_t up_offset, + uint64_t down_offset, + uint32_t gate_type, + uint32_t down_type, + uint64_t gate_expert_bytes, + uint64_t gate_row_bytes, + uint64_t down_expert_bytes, + uint64_t down_row_bytes, + uint32_t expert_in_dim, + uint32_t expert_mid_dim, + uint32_t out_dim, + const ds4_gpu_tensor *selected, + const ds4_gpu_tensor *weights, + uint32_t n_expert, + float clamp, + const ds4_gpu_tensor *x, + uint32_t n_tokens) { + if (!out || !gate || !up || !mid || !down || !model_map || !selected || !weights || !x || + n_tokens == 0 || n_expert == 0 || + expert_in_dim % HIP_QK_K != 0 || expert_mid_dim % HIP_QK_K != 0 || + gate_offset > model_size || up_offset > model_size || down_offset > model_size || + x->bytes < (uint64_t)n_tokens * expert_in_dim * sizeof(float) || + selected->bytes < (uint64_t)n_tokens * n_expert * sizeof(int32_t) || + weights->bytes < (uint64_t)n_tokens * n_expert * sizeof(float) || + gate->bytes < (uint64_t)n_tokens * n_expert * expert_mid_dim * sizeof(float) || + up->bytes < (uint64_t)n_tokens * n_expert * expert_mid_dim * sizeof(float) || + mid->bytes < (uint64_t)n_tokens * n_expert * expert_mid_dim * sizeof(float) || + down->bytes < (uint64_t)n_tokens * n_expert * out_dim * sizeof(float) || + out->bytes < (uint64_t)n_tokens * out_dim * sizeof(float)) { + return 0; + } + if (gate_type != 16u || down_type != 10u) return 0; + const uint64_t gate_bytes = 256ull * gate_expert_bytes; + const uint64_t down_bytes = 256ull * down_expert_bytes; + if (gate_bytes > model_size - gate_offset || + gate_bytes > model_size - up_offset || + down_bytes > model_size - down_offset) { + return 0; + } + const char *gate_w = hip_model_range_ptr(model_map, gate_offset, gate_bytes, "moe_gate"); + const char *up_w = hip_model_range_ptr(model_map, up_offset, gate_bytes, "moe_up"); + const char *down_w = hip_model_range_ptr(model_map, down_offset, down_bytes, "moe_down"); + if (!gate_w || !up_w || !down_w) return 0; + + int ok = 1; + const uint32_t xq_blocks = expert_in_dim / HIP_QK_K; + const uint32_t midq_blocks = expert_mid_dim / HIP_QK_K; + const uint64_t xq_count = (uint64_t)n_tokens * xq_blocks; + const uint64_t midq_count = (uint64_t)n_tokens * n_expert * midq_blocks; + const uint64_t xq_bytes = xq_count * sizeof(hip_block_q8_K); + const uint64_t midq_bytes = midq_count * sizeof(hip_block_q8_K); + if (down->bytes >= xq_bytes && gate->bytes >= midq_bytes) { + hip_block_q8_K *xq = (hip_block_q8_K *)down->ptr; + hip_block_q8_K *midq = (hip_block_q8_K *)gate->ptr; + const uint32_t profile_moe = getenv("DS4_HIP_MOE_PROFILE") != NULL; + hipEvent_t prof_ev[7] = {NULL, NULL, NULL, NULL, NULL, NULL, NULL}; + if (profile_moe) { + for (uint32_t i = 0; i < 7u; i++) { + if (hipEventCreate(&prof_ev[i]) != hipSuccess) { + for (uint32_t j = 0; j < i; j++) (void)hipEventDestroy(prof_ev[j]); + memset(prof_ev, 0, sizeof(prof_ev)); + break; + } + } + if (prof_ev[0]) (void)hipEventRecord(prof_ev[0], 0); + } + const uint32_t pair_count = n_tokens * n_expert; + const uint32_t use_sorted_pairs = n_tokens > 1u; + const uint32_t use_expert_tiles = use_sorted_pairs && getenv("DS4_HIP_MOE_NO_EXPERT_TILES") == NULL; + const uint32_t expert_tile_m = getenv("DS4_HIP_MOE_TILE4") ? 4u : 8u; + const uint32_t write_gate_up = getenv("DS4_HIP_MOE_WRITE_GATE_UP") != NULL; + const uint32_t use_p2_sorted = use_sorted_pairs && getenv("DS4_HIP_MOE_NO_P2") == NULL; + const uint32_t use_atomic_down = use_expert_tiles && + (getenv("DS4_HIP_MOE_ATOMIC_DOWN") != NULL || + (n_tokens >= 128u && getenv("DS4_HIP_MOE_NO_ATOMIC_DOWN") == NULL)); + const uint32_t use_gate_row2048 = use_expert_tiles && expert_tile_m == 8u && + (getenv("DS4_HIP_MOE_GATE_ROW2048") != NULL || + getenv("DS4_HIP_MOE_GATE_ROW256") != NULL || + getenv("DS4_HIP_MOE_GATE_ROW128") != NULL || + (n_tokens >= 128u && + getenv("DS4_HIP_MOE_NO_GATE_ROW2048") == NULL && + getenv("DS4_HIP_MOE_NO_GATE_ROW256") == NULL && + getenv("DS4_HIP_MOE_NO_GATE_ROW128") == NULL)); + const uint32_t use_down_tile16 = use_atomic_down && expert_tile_m == 8u && + n_tokens >= 128u && getenv("DS4_HIP_MOE_NO_DOWN_TILE16") == NULL; + const uint32_t use_down_block16 = use_down_tile16 && midq_blocks <= 8u && + getenv("DS4_HIP_MOE_NO_DOWN_BLOCK16") == NULL; + const uint32_t use_decode_lut_gate = + n_tokens == 1u && xq_blocks <= 16u && + getenv("DS4_HIP_MOE_NO_DECODE_LUT_GATE") == NULL; + const uint32_t gate_row_span = + getenv("DS4_HIP_MOE_GATE_ROW512") != NULL ? 512u : + getenv("DS4_HIP_MOE_GATE_ROW2048") != NULL ? 2048u : 1024u; + const uint32_t down_row_span = + getenv("DS4_HIP_MOE_DOWN_ROW512") != NULL ? 512u : + getenv("DS4_HIP_MOE_DOWN_ROW1024") != NULL ? 1024u : 2048u; + const uint32_t use_down_row2048 = use_atomic_down && expert_tile_m == 8u && + (getenv("DS4_HIP_MOE_DOWN_ROW2048") != NULL || + getenv("DS4_HIP_MOE_DOWN_ROW256") != NULL || + getenv("DS4_HIP_MOE_DOWN_ROW128") != NULL || + getenv("DS4_HIP_MOE_DOWN_ROW64") != NULL || + (use_down_tile16 && + getenv("DS4_HIP_MOE_NO_DOWN_ROW2048") == NULL && + getenv("DS4_HIP_MOE_NO_DOWN_ROW256") == NULL && + getenv("DS4_HIP_MOE_NO_DOWN_ROW128") == NULL && + getenv("DS4_HIP_MOE_NO_DOWN_ROW64") == NULL)); + const uint32_t use_direct_down_sum6 = + n_tokens == 1u && n_expert == 6u && + getenv("DS4_HIP_MOE_NO_DIRECT_DOWN_SUM6") == NULL; + uint32_t *sorted_pairs = NULL; + uint32_t *sorted_offsets = NULL; + uint32_t *sorted_counts = NULL; + uint32_t *tile_total = NULL; + uint32_t *tile_experts = NULL; + uint32_t *tile_starts = NULL; + uint32_t *tile16_total = NULL; + uint32_t *tile16_experts = NULL; + uint32_t *tile16_starts = NULL; + uint32_t tile_capacity = 0; + uint32_t tile16_capacity = 0; + dim3 xq_grid(xq_blocks, n_tokens, 1); + q8_K_quantize_kernel<<>>(xq, (const float *)x->ptr, expert_in_dim, n_tokens); + ok = hip_ok(hipGetLastError(), "routed_moe x quantize launch"); + if (prof_ev[1]) (void)hipEventRecord(prof_ev[1], 0); + if (ok && use_sorted_pairs) { + const uint64_t counts_bytes = 256ull * sizeof(uint32_t); + const uint64_t offsets_bytes = 257ull * sizeof(uint32_t); + const uint64_t cursors_bytes = 256ull * sizeof(uint32_t); + const uint64_t sorted_bytes = (uint64_t)pair_count * sizeof(uint32_t); + tile_capacity = (pair_count + expert_tile_m - 1u) / expert_tile_m + 256u; + tile16_capacity = use_down_tile16 ? ((pair_count + 15u) / 16u + 256u) : 0u; + const uint64_t tile_offsets_bytes = 257ull * sizeof(uint32_t); + const uint64_t tile_total_bytes = sizeof(uint32_t); + const uint64_t tile_experts_bytes = (uint64_t)tile_capacity * sizeof(uint32_t); + const uint64_t tile_starts_bytes = (uint64_t)tile_capacity * sizeof(uint32_t); + const uint64_t tile16_offsets_bytes = use_down_tile16 ? 257ull * sizeof(uint32_t) : 0u; + const uint64_t tile16_total_bytes = use_down_tile16 ? sizeof(uint32_t) : 0u; + const uint64_t tile16_experts_bytes = (uint64_t)tile16_capacity * sizeof(uint32_t); + const uint64_t tile16_starts_bytes = (uint64_t)tile16_capacity * sizeof(uint32_t); + const uint64_t tile_offsets_off = counts_bytes + offsets_bytes + cursors_bytes + sorted_bytes; + const uint64_t tile_total_off = tile_offsets_off + tile_offsets_bytes; + const uint64_t tile_experts_off = tile_total_off + tile_total_bytes; + const uint64_t tile_starts_off = tile_experts_off + tile_experts_bytes; + const uint64_t tile16_offsets_off = tile_starts_off + tile_starts_bytes; + const uint64_t tile16_total_off = tile16_offsets_off + tile16_offsets_bytes; + const uint64_t tile16_experts_off = tile16_total_off + tile16_total_bytes; + const uint64_t tile16_starts_off = tile16_experts_off + tile16_experts_bytes; + const uint64_t scratch_bytes = tile16_starts_off + tile16_starts_bytes; + uint8_t *scratch = (uint8_t *)hip_tmp_alloc(scratch_bytes, + "routed_moe sorted pairs"); + if (!scratch) { + ok = 0; + } else { + uint32_t *counts = (uint32_t *)scratch; + uint32_t *offsets = (uint32_t *)(scratch + counts_bytes); + uint32_t *cursors = (uint32_t *)(scratch + counts_bytes + offsets_bytes); + sorted_pairs = (uint32_t *)(scratch + counts_bytes + offsets_bytes + cursors_bytes); + sorted_offsets = offsets; + sorted_counts = counts; + uint32_t *tile_offsets = (uint32_t *)(scratch + tile_offsets_off); + tile_total = (uint32_t *)(scratch + tile_total_off); + tile_experts = (uint32_t *)(scratch + tile_experts_off); + tile_starts = (uint32_t *)(scratch + tile_starts_off); + uint32_t *tile16_offsets = use_down_tile16 ? (uint32_t *)(scratch + tile16_offsets_off) : NULL; + tile16_total = use_down_tile16 ? (uint32_t *)(scratch + tile16_total_off) : NULL; + tile16_experts = use_down_tile16 ? (uint32_t *)(scratch + tile16_experts_off) : NULL; + tile16_starts = use_down_tile16 ? (uint32_t *)(scratch + tile16_starts_off) : NULL; + ok = hip_ok(hipMemset(counts, 0, counts_bytes), "routed_moe sorted counts clear"); + if (ok) { + moe_count_sorted_pairs_kernel<<<(pair_count + 255u) / 256u, 256>>>( + counts, + (const int32_t *)selected->ptr, + pair_count); + ok = hip_ok(hipGetLastError(), "routed_moe sorted count launch"); + } + if (ok) { + moe_prefix_sorted_pairs_kernel<<<1, 1>>>(offsets, cursors, counts); + ok = hip_ok(hipGetLastError(), "routed_moe sorted prefix launch"); + } + if (ok) { + moe_scatter_sorted_pairs_kernel<<<(pair_count + 255u) / 256u, 256>>>( + sorted_pairs, + cursors, + (const int32_t *)selected->ptr, + pair_count); + ok = hip_ok(hipGetLastError(), "routed_moe sorted scatter launch"); + } + if (ok && use_expert_tiles) { + moe_build_expert_tile_offsets_kernel<<<1, 1>>>(tile_offsets, tile_total, counts, expert_tile_m); + ok = hip_ok(hipGetLastError(), "routed_moe expert tile offsets launch"); + } + if (ok && use_expert_tiles) { + moe_build_expert_tiles_kernel<<<1, 256>>>(tile_experts, tile_starts, tile_offsets, counts, expert_tile_m); + ok = hip_ok(hipGetLastError(), "routed_moe expert tiles launch"); + } + if (ok && use_expert_tiles && use_down_tile16) { + moe_build_expert_tile_offsets_kernel<<<1, 1>>>(tile16_offsets, tile16_total, counts, 16u); + ok = hip_ok(hipGetLastError(), "routed_moe expert tile16 offsets launch"); + } + if (ok && use_expert_tiles && use_down_tile16) { + moe_build_expert_tiles_kernel<<<1, 256>>>(tile16_experts, tile16_starts, tile16_offsets, counts, 16u); + ok = hip_ok(hipGetLastError(), "routed_moe expert tile16 launch"); + } + } + } + if (prof_ev[2]) (void)hipEventRecord(prof_ev[2], 0); + if (ok) { + dim3 mgrid((expert_mid_dim + 31u) / 32u, n_tokens * n_expert, 1); + if (ok && sorted_pairs && use_expert_tiles && sorted_offsets && sorted_counts && tile_total && tile_experts && tile_starts) { + if (use_gate_row2048) { + if (gate_row_span == 512u) { + dim3 tgrid((expert_mid_dim + 511u) / 512u, tile_capacity, 1); + moe_gate_up_mid_expert_tile8_rowspan_kernel<512><<>>( + (float *)gate->ptr, (float *)up->ptr, (float *)mid->ptr, + gate_w, up_w, xq, sorted_pairs, sorted_offsets, sorted_counts, + tile_total, tile_experts, tile_starts, (const float *)weights->ptr, + gate_expert_bytes, gate_row_bytes, xq_blocks, expert_mid_dim, n_expert, + write_gate_up, clamp); + } else if (gate_row_span == 1024u) { + dim3 tgrid((expert_mid_dim + 1023u) / 1024u, tile_capacity, 1); + moe_gate_up_mid_expert_tile8_rowspan_kernel<1024><<>>( + (float *)gate->ptr, (float *)up->ptr, (float *)mid->ptr, + gate_w, up_w, xq, sorted_pairs, sorted_offsets, sorted_counts, + tile_total, tile_experts, tile_starts, (const float *)weights->ptr, + gate_expert_bytes, gate_row_bytes, xq_blocks, expert_mid_dim, n_expert, + write_gate_up, clamp); + } else { + dim3 tgrid((expert_mid_dim + 2047u) / 2048u, tile_capacity, 1); + moe_gate_up_mid_expert_tile8_row2048_kernel<<>>( + (float *)gate->ptr, (float *)up->ptr, (float *)mid->ptr, + gate_w, up_w, xq, sorted_pairs, sorted_offsets, sorted_counts, + tile_total, tile_experts, tile_starts, (const float *)weights->ptr, + gate_expert_bytes, gate_row_bytes, xq_blocks, expert_mid_dim, n_expert, + write_gate_up, clamp); + } + } else if (expert_tile_m == 8u) { + dim3 tgrid((expert_mid_dim + 31u) / 32u, tile_capacity, 1); + moe_gate_up_mid_expert_tile8_row32_kernel<<>>( + (float *)gate->ptr, (float *)up->ptr, (float *)mid->ptr, + gate_w, up_w, xq, sorted_pairs, sorted_offsets, sorted_counts, + tile_total, tile_experts, tile_starts, (const float *)weights->ptr, + gate_expert_bytes, gate_row_bytes, xq_blocks, expert_mid_dim, n_expert, + write_gate_up, clamp); + } else { + dim3 tgrid((expert_mid_dim + 31u) / 32u, tile_capacity, 1); + moe_gate_up_mid_expert_tile4_row32_kernel<<>>( + (float *)gate->ptr, (float *)up->ptr, (float *)mid->ptr, + gate_w, up_w, xq, sorted_pairs, sorted_offsets, sorted_counts, + tile_total, tile_experts, tile_starts, (const float *)weights->ptr, + gate_expert_bytes, gate_row_bytes, xq_blocks, expert_mid_dim, n_expert, + write_gate_up, clamp); + } + } else if (ok && sorted_pairs && use_p2_sorted) { + dim3 p2_mgrid((expert_mid_dim + 15u) / 16u, (pair_count + 1u) / 2u, 1); + moe_gate_up_mid_sorted_p2_qwarp32_kernel<<>>( + (float *)gate->ptr, + (float *)up->ptr, + (float *)mid->ptr, + gate_w, + up_w, + xq, + sorted_pairs, + (const int32_t *)selected->ptr, + (const float *)weights->ptr, + gate_expert_bytes, + gate_row_bytes, + xq_blocks, + expert_mid_dim, + n_expert, + pair_count, + clamp); + } else if (ok && sorted_pairs) { + moe_gate_up_mid_sorted_qwarp32_kernel<<>>( + (float *)gate->ptr, + (float *)up->ptr, + (float *)mid->ptr, + gate_w, + up_w, + xq, + sorted_pairs, + (const int32_t *)selected->ptr, + (const float *)weights->ptr, + gate_expert_bytes, + gate_row_bytes, + xq_blocks, + expert_mid_dim, + n_expert, + clamp); + } else if (ok) { + dim3 qgrid((expert_mid_dim + 127u) / 128u, n_tokens * n_expert, 1); + if (use_decode_lut_gate) { + moe_gate_up_mid_decode_lut_qwarp32_kernel<<>>( + (float *)gate->ptr, + (float *)up->ptr, + (float *)mid->ptr, + gate_w, + up_w, + xq, + (const int32_t *)selected->ptr, + (const float *)weights->ptr, + gate_expert_bytes, + gate_row_bytes, + xq_blocks, + expert_mid_dim, + n_expert, + write_gate_up, + clamp); + } else { + moe_gate_up_mid_qwarp32_kernel<<>>( + (float *)gate->ptr, + (float *)up->ptr, + (float *)mid->ptr, + gate_w, + up_w, + xq, + (const int32_t *)selected->ptr, + (const float *)weights->ptr, + gate_expert_bytes, + gate_row_bytes, + xq_blocks, + expert_mid_dim, + n_expert, + clamp); + } + } + ok = hip_ok(hipGetLastError(), "routed_moe gate/up launch"); + } + if (prof_ev[3]) (void)hipEventRecord(prof_ev[3], 0); + if (ok) { + dim3 midq_grid(midq_blocks, n_tokens * n_expert, 1); + q8_K_quantize_kernel<<>>(midq, (const float *)mid->ptr, expert_mid_dim, n_tokens * n_expert); + ok = hip_ok(hipGetLastError(), "routed_moe mid quantize launch"); + } + if (prof_ev[4]) (void)hipEventRecord(prof_ev[4], 0); + if (ok) { + dim3 dgrid((out_dim + 31u) / 32u, n_tokens * n_expert, 1); + uint32_t *down_tile_total = tile_total; + uint32_t *down_tile_experts = tile_experts; + uint32_t *down_tile_starts = tile_starts; + uint32_t down_tile_capacity = tile_capacity; + if (use_down_tile16 && tile16_total && tile16_experts && tile16_starts) { + down_tile_total = tile16_total; + down_tile_experts = tile16_experts; + down_tile_starts = tile16_starts; + down_tile_capacity = tile16_capacity; + } + if (use_direct_down_sum6) { + dim3 sgrid((out_dim + 31u) / 32u, 1, 1); + moe_down_sum6_qwarp32_kernel<<>>( + (float *)out->ptr, + down_w, + midq, + (const int32_t *)selected->ptr, + down_expert_bytes, + down_row_bytes, + midq_blocks, + out_dim); + } else if (use_atomic_down) { + uint64_t n = (uint64_t)n_tokens * out_dim; + zero_kernel<<<(n + 255u) / 256u, 256>>>((float *)out->ptr, n); + ok = hip_ok(hipGetLastError(), "routed_moe atomic zero launch"); + } + if (use_direct_down_sum6) { + /* The direct decode kernel writes the final token row. */ + } else if (sorted_pairs && use_expert_tiles && sorted_offsets && sorted_counts && + down_tile_total && down_tile_experts && down_tile_starts) { + if (use_down_row2048) { + if (down_row_span == 512u) { + dim3 tgrid((out_dim + 511u) / 512u, down_tile_capacity, 1); + if (use_down_block16) { + moe_down_expert_tile16_rowspan_block16_kernel<512><<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } else { + moe_down_expert_tile16_rowspan_kernel<512><<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } + } else if (down_row_span == 1024u) { + dim3 tgrid((out_dim + 1023u) / 1024u, down_tile_capacity, 1); + if (use_down_block16) { + moe_down_expert_tile16_rowspan_block16_kernel<1024><<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } else { + moe_down_expert_tile16_rowspan_kernel<1024><<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } + } else { + dim3 tgrid((out_dim + 2047u) / 2048u, down_tile_capacity, 1); + if (use_down_block16) { + moe_down_expert_tile16_rowspan_block16_kernel<2048><<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } else { + moe_down_expert_tile16_row2048_kernel<<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } + } + } else if (use_down_tile16) { + dim3 tgrid((out_dim + 31u) / 32u, down_tile_capacity, 1); + moe_down_expert_tile16_row32_kernel<<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } else if (expert_tile_m == 8u) { + dim3 tgrid((out_dim + 31u) / 32u, down_tile_capacity, 1); + moe_down_expert_tile8_row32_kernel<<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } else { + dim3 tgrid((out_dim + 31u) / 32u, down_tile_capacity, 1); + moe_down_expert_tile4_row32_kernel<<>>( + use_atomic_down ? (float *)out->ptr : (float *)down->ptr, + down_w, midq, sorted_pairs, sorted_offsets, sorted_counts, + down_tile_total, down_tile_experts, down_tile_starts, down_expert_bytes, down_row_bytes, + midq_blocks, out_dim, n_expert, use_atomic_down); + } + } else if (sorted_pairs && use_p2_sorted) { + dim3 p2_dgrid((out_dim + 15u) / 16u, (pair_count + 1u) / 2u, 1); + moe_down_sorted_p2_qwarp32_kernel<<>>( + (float *)down->ptr, + down_w, + midq, + sorted_pairs, + (const int32_t *)selected->ptr, + down_expert_bytes, + down_row_bytes, + midq_blocks, + out_dim, + n_expert, + pair_count); + } else if (sorted_pairs) { + moe_down_sorted_qwarp32_kernel<<>>( + (float *)down->ptr, + down_w, + midq, + sorted_pairs, + (const int32_t *)selected->ptr, + down_expert_bytes, + down_row_bytes, + midq_blocks, + out_dim, + n_expert); + } else { + size_t smem_size = (size_t)midq_blocks * sizeof(hip_block_q8_K); + moe_down_qwarp32_kernel<<>>( + (float *)down->ptr, + down_w, + midq, + (const int32_t *)selected->ptr, + down_expert_bytes, + down_row_bytes, + midq_blocks, + out_dim, + n_expert); + } + ok = hip_ok(hipGetLastError(), "routed_moe down launch"); + } + if (prof_ev[5]) (void)hipEventRecord(prof_ev[5], 0); + if (ok && !use_atomic_down && !use_direct_down_sum6) { + uint64_t n = (uint64_t)n_tokens * out_dim; + moe_sum_kernel<<<(n + 255) / 256, 256>>>((float *)out->ptr, (const float *)down->ptr, out_dim, n_expert, n_tokens); + ok = hip_ok(hipGetLastError(), "routed_moe sum launch"); + } + if (prof_ev[6]) { + (void)hipEventRecord(prof_ev[6], 0); + if (hipEventSynchronize(prof_ev[6]) == hipSuccess) { + float ms_xq = 0.0f, ms_sort = 0.0f, ms_gate = 0.0f, ms_midq = 0.0f, ms_down = 0.0f, ms_sum = 0.0f, ms_total = 0.0f; + (void)hipEventElapsedTime(&ms_xq, prof_ev[0], prof_ev[1]); + (void)hipEventElapsedTime(&ms_sort, prof_ev[1], prof_ev[2]); + (void)hipEventElapsedTime(&ms_gate, prof_ev[2], prof_ev[3]); + (void)hipEventElapsedTime(&ms_midq, prof_ev[3], prof_ev[4]); + (void)hipEventElapsedTime(&ms_down, prof_ev[4], prof_ev[5]); + (void)hipEventElapsedTime(&ms_sum, prof_ev[5], prof_ev[6]); + (void)hipEventElapsedTime(&ms_total, prof_ev[0], prof_ev[6]); + fprintf(stderr, + "ds4: ROCm MoE profile tokens=%u pairs=%u xq=%.3f sort=%.3f gateup=%.3f midq=%.3f down=%.3f sum=%.3f total=%.3f ms\n", + n_tokens, pair_count, ms_xq, ms_sort, ms_gate, ms_midq, ms_down, ms_sum, ms_total); + } + for (uint32_t i = 0; i < 7u; i++) (void)hipEventDestroy(prof_ev[i]); + } + return ok; + } + + if (ok) { + dim3 mgrid(expert_mid_dim, n_tokens * n_expert, 1); + moe_gate_up_mid_f32_kernel<<>>( + (float *)gate->ptr, + (float *)up->ptr, + (float *)mid->ptr, + gate_w, + up_w, + (const float *)x->ptr, + (const int32_t *)selected->ptr, + (const float *)weights->ptr, + gate_expert_bytes, + gate_row_bytes, + expert_in_dim, + expert_mid_dim, + n_expert, + clamp); + ok = hip_ok(hipGetLastError(), "routed_moe gate/up launch"); + } + if (ok) { + dim3 dgrid(out_dim, n_tokens * n_expert, 1); + moe_down_f32_kernel<<>>( + (float *)down->ptr, + down_w, + (const float *)mid->ptr, + (const int32_t *)selected->ptr, + down_expert_bytes, + down_row_bytes, + expert_mid_dim, + out_dim, + n_expert); + ok = hip_ok(hipGetLastError(), "routed_moe down launch"); + } + if (ok) { + uint64_t n = (uint64_t)n_tokens * out_dim; + moe_sum_kernel<<<(n + 255) / 256, 256>>>((float *)out->ptr, (const float *)down->ptr, out_dim, n_expert, n_tokens); + ok = hip_ok(hipGetLastError(), "routed_moe sum launch"); + } + return ok; +} + +extern "C" int ds4_gpu_routed_moe_one_tensor(ds4_gpu_tensor *out, ds4_gpu_tensor *gate, ds4_gpu_tensor *up, ds4_gpu_tensor *mid, ds4_gpu_tensor *down, const void *model_map, uint64_t model_size, uint64_t gate_offset, uint64_t up_offset, uint64_t down_offset, uint32_t gate_type, uint32_t down_type, uint64_t gate_expert_bytes, uint64_t gate_row_bytes, uint64_t down_expert_bytes, uint64_t down_row_bytes, uint32_t expert_in_dim, uint32_t expert_mid_dim, uint32_t out_dim, const ds4_gpu_tensor *selected, const ds4_gpu_tensor *weights, uint32_t n_expert, float clamp, const ds4_gpu_tensor *x) { + return routed_moe_launch(out, gate, up, mid, down, model_map, model_size, + gate_offset, up_offset, down_offset, + gate_type, down_type, + gate_expert_bytes, gate_row_bytes, + down_expert_bytes, down_row_bytes, + expert_in_dim, expert_mid_dim, out_dim, + selected, weights, n_expert, clamp, x, 1); +} +extern "C" int ds4_gpu_routed_moe_batch_tensor(ds4_gpu_tensor *out, ds4_gpu_tensor *gate, ds4_gpu_tensor *up, ds4_gpu_tensor *mid, ds4_gpu_tensor *down, const void *model_map, uint64_t model_size, uint64_t gate_offset, uint64_t up_offset, uint64_t down_offset, uint32_t gate_type, uint32_t down_type, uint64_t gate_expert_bytes, uint64_t gate_row_bytes, uint64_t down_expert_bytes, uint64_t down_row_bytes, uint32_t expert_in_dim, uint32_t expert_mid_dim, uint32_t out_dim, const ds4_gpu_tensor *selected, const ds4_gpu_tensor *weights, uint32_t n_expert, float clamp, const ds4_gpu_tensor *x, uint32_t n_tokens) { + return routed_moe_launch(out, gate, up, mid, down, model_map, model_size, + gate_offset, up_offset, down_offset, + gate_type, down_type, + gate_expert_bytes, gate_row_bytes, + down_expert_bytes, down_row_bytes, + expert_in_dim, expert_mid_dim, out_dim, + selected, weights, n_expert, clamp, x, n_tokens); +} +extern "C" int ds4_gpu_hc_split_sinkhorn_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *mix, const void *model_map, uint64_t model_size, uint64_t scale_offset, uint64_t base_offset, uint32_t n_hc, uint32_t sinkhorn_iters, float eps) { + if (!out || !mix || !model_map || n_hc != 4) return 0; + const uint64_t mix_bytes = 24ull * sizeof(float); + if (scale_offset > model_size || model_size - scale_offset < 3ull * sizeof(float) || + base_offset > model_size || model_size - base_offset < mix_bytes || + mix->bytes < mix_bytes || out->bytes < mix_bytes) return 0; + const float *scale = (const float *)hip_model_range_ptr(model_map, scale_offset, 3ull * sizeof(float), "hc_scale"); + const float *base = (const float *)hip_model_range_ptr(model_map, base_offset, mix_bytes, "hc_base"); + if (!scale || !base) return 0; + uint32_t n_rows = (uint32_t)(mix->bytes / mix_bytes); + if (out->bytes / mix_bytes < n_rows) n_rows = (uint32_t)(out->bytes / mix_bytes); + hc_split_sinkhorn_kernel<<<(n_rows + 255) / 256, 256>>>( + (float *)out->ptr, (const float *)mix->ptr, + scale, + base, + n_rows, sinkhorn_iters, eps); + return hip_ok(hipGetLastError(), "hc_split_sinkhorn launch"); +} +extern "C" int ds4_gpu_hc_weighted_sum_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *residual_hc, const ds4_gpu_tensor *weights, uint32_t n_embd, uint32_t n_hc) { + if (!out || !residual_hc || !weights || n_embd == 0 || n_hc == 0) return 0; + uint32_t n_tokens = (uint32_t)(out->bytes / ((uint64_t)n_embd * sizeof(float))); + hc_weighted_sum_kernel<<<((uint64_t)n_embd * n_tokens + 255) / 256, 256>>>( + (float *)out->ptr, (const float *)residual_hc->ptr, (const float *)weights->ptr, + n_embd, n_hc, n_tokens, n_hc); + return hip_ok(hipGetLastError(), "hc_weighted_sum launch"); +} +extern "C" int ds4_gpu_hc_weighted_sum_split_tensor(ds4_gpu_tensor *out, const ds4_gpu_tensor *residual_hc, const ds4_gpu_tensor *split, uint32_t n_embd, uint32_t n_hc) { + if (!out || !residual_hc || !split || n_embd == 0 || n_hc == 0) return 0; + uint32_t n_tokens = (uint32_t)(out->bytes / ((uint64_t)n_embd * sizeof(float))); + uint32_t stride = (uint32_t)(2u * n_hc + n_hc * n_hc); + hc_weighted_sum_kernel<<<((uint64_t)n_embd * n_tokens + 255) / 256, 256>>>( + (float *)out->ptr, (const float *)residual_hc->ptr, (const float *)split->ptr, + n_embd, n_hc, n_tokens, stride); + return hip_ok(hipGetLastError(), "hc_weighted_sum_split launch"); +} +extern "C" int ds4_gpu_hc_split_weighted_sum_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *split, + const ds4_gpu_tensor *mix, + const ds4_gpu_tensor *residual_hc, + const void *model_map, + uint64_t model_size, + uint64_t scale_offset, + uint64_t base_offset, + uint32_t n_embd, + uint32_t n_hc, + uint32_t sinkhorn_iters, + float eps) { + if (!out || !split || !mix || !residual_hc || !model_map || + n_embd == 0 || n_hc != 4) { + return 0; + } + const uint64_t mix_hc = 2ull * n_hc + (uint64_t)n_hc * n_hc; + const uint64_t mix_bytes = mix_hc * sizeof(float); + const uint64_t out_row_bytes = (uint64_t)n_embd * sizeof(float); + const uint64_t residual_row_bytes = (uint64_t)n_hc * n_embd * sizeof(float); + if (out->bytes < out_row_bytes || out->bytes % out_row_bytes != 0 || + scale_offset > model_size || 3ull * sizeof(float) > model_size - scale_offset || + base_offset > model_size || mix_bytes > model_size - base_offset) { + return 0; + } + uint64_t n_rows = out->bytes / out_row_bytes; + if (mix->bytes < n_rows * mix_bytes || + split->bytes < n_rows * mix_bytes || + residual_hc->bytes < n_rows * residual_row_bytes) { + return 0; + } + const float *scale = (const float *)hip_model_range_ptr(model_map, scale_offset, 3ull * sizeof(float), "hc_scale"); + const float *base = (const float *)hip_model_range_ptr(model_map, base_offset, mix_bytes, "hc_base"); + if (!scale || !base) return 0; + hc_split_weighted_sum_fused_kernel<<<(uint32_t)n_rows, 256>>>( + (float *)out->ptr, + (float *)split->ptr, + (const float *)mix->ptr, + (const float *)residual_hc->ptr, + scale, + base, + n_embd, n_hc, (uint32_t)n_rows, sinkhorn_iters, eps); + return hip_ok(hipGetLastError(), "hc split weighted sum launch"); +} +extern "C" int ds4_gpu_hc_split_weighted_sum_norm_tensor( + ds4_gpu_tensor *out, + ds4_gpu_tensor *norm_out, + ds4_gpu_tensor *split, + const ds4_gpu_tensor *mix, + const ds4_gpu_tensor *residual_hc, + const void *model_map, + uint64_t model_size, + uint64_t scale_offset, + uint64_t base_offset, + uint64_t norm_weight_offset, + uint32_t n_embd, + uint32_t n_hc, + uint32_t sinkhorn_iters, + float eps, + float norm_eps) { + if (getenv("DS4_HIP_DISABLE_HC_SPLIT_NORM_FUSED") == NULL) { + if (!out || !norm_out || !split || !mix || !residual_hc || !model_map || + n_embd == 0 || n_hc != 4) { + return 0; + } + const uint64_t mix_hc = 2ull * n_hc + (uint64_t)n_hc * n_hc; + const uint64_t mix_bytes = mix_hc * sizeof(float); + const uint64_t out_row_bytes = (uint64_t)n_embd * sizeof(float); + const uint64_t residual_row_bytes = (uint64_t)n_hc * n_embd * sizeof(float); + if (out->bytes < out_row_bytes || out->bytes % out_row_bytes != 0 || + norm_out->bytes < out->bytes || + scale_offset > model_size || 3ull * sizeof(float) > model_size - scale_offset || + base_offset > model_size || mix_bytes > model_size - base_offset || + norm_weight_offset > model_size || + (uint64_t)n_embd * sizeof(float) > model_size - norm_weight_offset) { + return 0; + } + uint64_t n_rows = out->bytes / out_row_bytes; + if (n_rows == 1) { + if (mix->bytes < n_rows * mix_bytes || + split->bytes < n_rows * mix_bytes || + residual_hc->bytes < n_rows * residual_row_bytes) { + return 0; + } + const float *scale = (const float *)hip_model_range_ptr(model_map, scale_offset, + 3ull * sizeof(float), "hc_scale"); + const float *base = (const float *)hip_model_range_ptr(model_map, base_offset, + mix_bytes, "hc_base"); + const float *norm_w = (const float *)hip_model_range_ptr(model_map, norm_weight_offset, + (uint64_t)n_embd * sizeof(float), "hc_norm_weight"); + if (!scale || !base || !norm_w) return 0; + hc_split_weighted_sum_norm_fused_kernel<<<(uint32_t)n_rows, 256>>>( + (float *)out->ptr, + (float *)norm_out->ptr, + (float *)split->ptr, + (const float *)mix->ptr, + (const float *)residual_hc->ptr, + scale, + base, + norm_w, + n_embd, n_hc, (uint32_t)n_rows, sinkhorn_iters, eps, norm_eps); + return hip_ok(hipGetLastError(), "hc split weighted sum norm launch"); + } + } + return ds4_gpu_hc_split_weighted_sum_tensor(out, split, mix, residual_hc, + model_map, model_size, + scale_offset, base_offset, + n_embd, n_hc, + sinkhorn_iters, eps) && + ds4_gpu_rms_norm_weight_tensor(norm_out, out, model_map, model_size, + norm_weight_offset, n_embd, norm_eps); +} +extern "C" int ds4_gpu_output_hc_weights_tensor( + ds4_gpu_tensor *out, + const ds4_gpu_tensor *pre, + const void *model_map, + uint64_t model_size, + uint64_t scale_offset, + uint64_t base_offset, + uint32_t n_hc, + float eps) { + if (!out || !pre || !model_map || n_hc == 0) return 0; + const uint64_t row_bytes = (uint64_t)n_hc * sizeof(float); + if (row_bytes == 0 || out->bytes < row_bytes || out->bytes % row_bytes != 0 || + pre->bytes < out->bytes || + scale_offset > model_size || sizeof(float) > model_size - scale_offset || + base_offset > model_size || row_bytes > model_size - base_offset) { + return 0; + } + const uint64_t n_tokens = out->bytes / row_bytes; + const float *scale = (const float *)hip_model_range_ptr(model_map, scale_offset, sizeof(float), "output_hc_scale"); + const float *base = (const float *)hip_model_range_ptr(model_map, base_offset, row_bytes, "output_hc_base"); + if (!scale || !base) return 0; + uint64_t n = n_tokens * n_hc; + output_hc_weights_kernel<<<(n + 255) / 256, 256>>>( + (float *)out->ptr, + (const float *)pre->ptr, + scale, + base, + n_hc, + (uint32_t)n_tokens, + eps); + return hip_ok(hipGetLastError(), "output hc weights launch"); +} +extern "C" int ds4_gpu_hc_expand_tensor(ds4_gpu_tensor *out_hc, const ds4_gpu_tensor *block_out, const ds4_gpu_tensor *residual_hc, const ds4_gpu_tensor *post, const ds4_gpu_tensor *comb, uint32_t n_embd, uint32_t n_hc) { + if (!out_hc || !block_out || !residual_hc || !post || !comb || n_embd == 0 || n_hc == 0) return 0; + uint32_t n_tokens = (uint32_t)(out_hc->bytes / ((uint64_t)n_hc * n_embd * sizeof(float))); + uint64_t n_elem = (uint64_t)n_tokens * n_hc * n_embd; + hc_expand_kernel<<<(n_elem + 255) / 256, 256>>>((float *)out_hc->ptr, + (const float *)block_out->ptr, + (const float *)block_out->ptr, + (const float *)residual_hc->ptr, + (const float *)post->ptr, + (const float *)comb->ptr, + n_embd, n_hc, n_tokens, + n_hc, n_hc * n_hc, 0); + return hip_ok(hipGetLastError(), "hc_expand launch"); +} +extern "C" int ds4_gpu_hc_expand_split_tensor(ds4_gpu_tensor *out_hc, const ds4_gpu_tensor *block_out, const ds4_gpu_tensor *residual_hc, const ds4_gpu_tensor *split, uint32_t n_embd, uint32_t n_hc) { + if (!out_hc || !block_out || !residual_hc || !split || n_embd == 0 || n_hc == 0) return 0; + uint32_t n_tokens = (uint32_t)(out_hc->bytes / ((uint64_t)n_hc * n_embd * sizeof(float))); + uint32_t mix_hc = 2u * n_hc + n_hc * n_hc; + uint64_t n_elem = (uint64_t)n_tokens * n_hc * n_embd; + const float *base = (const float *)split->ptr; + hc_expand_kernel<<<(n_elem + 255) / 256, 256>>>((float *)out_hc->ptr, + (const float *)block_out->ptr, + (const float *)block_out->ptr, + (const float *)residual_hc->ptr, + base + n_hc, + base + 2u * n_hc, + n_embd, n_hc, n_tokens, + mix_hc, mix_hc, 0); + return hip_ok(hipGetLastError(), "hc_expand_split launch"); +} +extern "C" int ds4_gpu_hc_expand_add_split_tensor(ds4_gpu_tensor *out_hc, const ds4_gpu_tensor *block_out, const ds4_gpu_tensor *block_add, const ds4_gpu_tensor *residual_hc, const ds4_gpu_tensor *split, uint32_t n_embd, uint32_t n_hc) { + if (!out_hc || !block_out || !block_add || !residual_hc || !split || n_embd == 0 || n_hc == 0) return 0; + uint32_t n_tokens = (uint32_t)(out_hc->bytes / ((uint64_t)n_hc * n_embd * sizeof(float))); + uint32_t mix_hc = 2u * n_hc + n_hc * n_hc; + uint64_t n_elem = (uint64_t)n_tokens * n_hc * n_embd; + const float *base = (const float *)split->ptr; + hc_expand_kernel<<<(n_elem + 255) / 256, 256>>>((float *)out_hc->ptr, + (const float *)block_out->ptr, + (const float *)block_add->ptr, + (const float *)residual_hc->ptr, + base + n_hc, + base + 2u * n_hc, + n_embd, n_hc, n_tokens, + mix_hc, mix_hc, 1); + return hip_ok(hipGetLastError(), "hc_expand_add_split launch"); +} +extern "C" int ds4_gpu_shared_down_hc_expand_q8_0_tensor( + ds4_gpu_tensor *out_hc, + ds4_gpu_tensor *shared_out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *shared_mid, + const ds4_gpu_tensor *routed_out, + const ds4_gpu_tensor *residual_hc, + const ds4_gpu_tensor *split, + uint32_t n_embd, + uint32_t n_hc) { + if (getenv("DS4_HIP_DISABLE_Q8_HC_EXPAND_FUSED") == NULL) { + return hip_matmul_q8_0_hc_expand_tensor_labeled(out_hc, shared_out, + model_map, model_size, + weight_offset, + in_dim, out_dim, + shared_mid, + routed_out, + residual_hc, + split, + n_embd, n_hc, + "shared_down_hc_expand"); + } + return ds4_gpu_matmul_q8_0_tensor(shared_out, model_map, model_size, + weight_offset, in_dim, out_dim, + shared_mid, 1) && + ds4_gpu_hc_expand_add_split_tensor(out_hc, shared_out, routed_out, + residual_hc, split, n_embd, n_hc); +} + +extern "C" int ds4_gpu_matmul_q8_0_hc_expand_tensor( + ds4_gpu_tensor *out_hc, + ds4_gpu_tensor *block_out, + const void *model_map, + uint64_t model_size, + uint64_t weight_offset, + uint64_t in_dim, + uint64_t out_dim, + const ds4_gpu_tensor *x, + const ds4_gpu_tensor *residual_hc, + const ds4_gpu_tensor *split, + uint32_t n_embd, + uint32_t n_hc) { + if (getenv("DS4_HIP_DISABLE_Q8_HC_EXPAND_FUSED") == NULL) { + return hip_matmul_q8_0_hc_expand_tensor_labeled(out_hc, block_out, + model_map, model_size, + weight_offset, + in_dim, out_dim, + x, + NULL, + residual_hc, + split, + n_embd, n_hc, + "q8_hc_expand"); + } + return ds4_gpu_matmul_q8_0_tensor(block_out, model_map, model_size, + weight_offset, in_dim, out_dim, x, 1) && + ds4_gpu_hc_expand_split_tensor(out_hc, block_out, residual_hc, + split, n_embd, n_hc); +} diff --git a/ds4_iq2_tables_hip.inc b/ds4_iq2_tables_hip.inc new file mode 100644 index 00000000..2060140f --- /dev/null +++ b/ds4_iq2_tables_hip.inc @@ -0,0 +1,77 @@ +__device__ __constant__ uint8_t hip_ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; + +__device__ __constant__ uint64_t hip_iq2xxs_grid[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; diff --git a/ds4_server.c b/ds4_server.c index bc8abbbd..d4b384a8 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -7906,8 +7906,8 @@ static void usage(FILE *fp) { " Apply steering after attention outputs. Default: 0\n" " --warm-weights\n" " Touch mapped tensor pages before serving. Slower startup, fewer first-use stalls.\n" - " --metal | --cuda | --cpu | --backend NAME\n" - " Select backend explicitly. Defaults to Metal on macOS and CUDA on CUDA builds.\n" + " --metal | --cuda | --rocm | --cpu | --backend NAME\n" + " Select backend explicitly. Defaults to Metal on macOS and ROCm on ROCm builds.\n" "\n" "HTTP API:\n" " --host HOST\n" @@ -7968,9 +7968,10 @@ static void usage(FILE *fp) { static ds4_backend parse_backend_arg(const char *s, const char *arg) { if (!strcmp(s, "metal")) return DS4_BACKEND_METAL; if (!strcmp(s, "cuda")) return DS4_BACKEND_CUDA; + if (!strcmp(s, "rocm")) return DS4_BACKEND_ROCM; if (!strcmp(s, "cpu")) return DS4_BACKEND_CPU; server_log(DS4_LOG_DEFAULT, "ds4-server: invalid %s value: %s", arg, s); - server_log(DS4_LOG_DEFAULT, "ds4-server: valid server backends are: metal, cuda, cpu"); + server_log(DS4_LOG_DEFAULT, "ds4-server: valid server backends are: metal, cuda, rocm, cpu"); exit(2); } @@ -7979,6 +7980,8 @@ static ds4_backend default_server_backend(void) { return DS4_BACKEND_CPU; #elif defined(__APPLE__) return DS4_BACKEND_METAL; +#elif defined(DS4_HAVE_ROCM) + return DS4_BACKEND_ROCM; #else return DS4_BACKEND_CUDA; #endif @@ -8062,6 +8065,8 @@ static server_config parse_options(int argc, char **argv) { c.engine.backend = DS4_BACKEND_METAL; } else if (!strcmp(arg, "--cuda")) { c.engine.backend = DS4_BACKEND_CUDA; + } else if (!strcmp(arg, "--rocm")) { + c.engine.backend = DS4_BACKEND_ROCM; } else if (!strcmp(arg, "--backend")) { c.engine.backend = parse_backend_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--cpu")) { diff --git a/rocm_start_server.sh b/rocm_start_server.sh new file mode 100755 index 00000000..4bbf1624 --- /dev/null +++ b/rocm_start_server.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Unified DS4 Start Script (ROCm/HIP optimized) +# Fuses cleanup, memory flushing, and server execution. + +set -e + +echo "--- Step 1: Killing Stale Server Instances ---" +pkill -9 -x ds4-server || true +rm -f /tmp/ds4.lock + +echo "--- Step 2: Cleaning System Cache Memory ---" +sudo sync +echo 3 | sudo tee /proc/sys/vm/drop_caches + +echo "--- Step 3: Setting ROCm Environment ---" +# For HSA/Strix Halo, unsetting COPY_MODEL enables optimal Zero-Copy path +unset DS4_HIP_COPY_MODEL +export DS4_HIP_PREFILL_CHUNK=4096 + +echo "--- Step 4: Starting DS4 Server ---" +# We run this in the background if it's called with --bg, otherwise we tail the log. +# For simplicity, we'll always use a log file to track initialization. +LOG_FILE="/tmp/ds4-server.log" + +cd "$(dirname "$0")" +nohup ./ds4-server --rocm --ctx 65536 \ + --warm-weights \ + --kv-disk-dir /tmp/ds4-kv --kv-disk-space-mb 8192 \ + > "$LOG_FILE" 2>&1 & + +echo "--- Step 5: Waiting for Initialization ---" +sleep 2 +tail -f "$LOG_FILE" | sed '/listening on/q' + +echo "--- DONE: Server is running at http://127.0.0.1:8000 ---" diff --git a/start_server.sh b/start_server.sh new file mode 100755 index 00000000..0e65a334 --- /dev/null +++ b/start_server.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# 1. Clean system cache memory (Requires sudo) +sudo sync +echo 3 | sudo tee /proc/sys/vm/drop_caches + +# 2. Ensure no stale processes or locks +pkill -9 -x ds4-server || true +rm -f /tmp/ds4.lock + +# 3. Use Zero-Copy UMA Mode (Direct access to RAM) +# For HIP, this is controlled by DS4_HIP_COPY_MODEL. Unsetting it enables zero-copy if possible. +unset DS4_HIP_COPY_MODEL + +# 4. Set Prefill Chunk Size (Backend specific) +export DS4_HIP_PREFILL_CHUNK=4096 + +# 5. Start the optimized ds4-server with ROCm backend +# --rocm enables the ROCm/HIP graph backend. +# To use MTP, add: --mtp gguf/DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf +exec ./ds4-server --rocm --ctx 65536 \ + --warm-weights \ + --kv-disk-dir /tmp/ds4-kv --kv-disk-space-mb 8192 diff --git a/tests/rocm_long_context_smoke.c b/tests/rocm_long_context_smoke.c new file mode 100644 index 00000000..933a7a3c --- /dev/null +++ b/tests/rocm_long_context_smoke.c @@ -0,0 +1,158 @@ +#include "ds4_gpu.h" + +#include +#include +#include +#include + +static double monotonic_seconds(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (double)ts.tv_sec + (double)ts.tv_nsec / 1000000000.0; +} + +static double getenv_seconds(const char *name, double fallback) { + const char *s = getenv(name); + if (!s || !s[0]) return fallback; + char *end = NULL; + const double v = strtod(s, &end); + return end != s && v > 0.0 ? v : fallback; +} + +static int check_large_topk(void) { + const uint32_t n_comp = 32768; + const uint32_t n_tokens = 32; + const uint32_t top_k = 512; + const uint64_t score_count = (uint64_t)n_comp * n_tokens; + float *scores_host = (float *)malloc((size_t)score_count * sizeof(float)); + uint32_t *selected_host = (uint32_t *)malloc((size_t)n_tokens * top_k * sizeof(uint32_t)); + if (!scores_host || !selected_host) return 1; + + for (uint32_t t = 0; t < n_tokens; t++) { + for (uint32_t i = 0; i < n_comp; i++) { + scores_host[(uint64_t)t * n_comp + i] = (float)i; + } + } + + ds4_gpu_tensor *scores = ds4_gpu_tensor_alloc(score_count * sizeof(float)); + ds4_gpu_tensor *selected = ds4_gpu_tensor_alloc((uint64_t)n_tokens * top_k * sizeof(uint32_t)); + int rc = 1; + double elapsed = 0.0; + if (scores && selected && + ds4_gpu_tensor_write(scores, 0, scores_host, score_count * sizeof(float))) { + const double t0 = monotonic_seconds(); + if (ds4_gpu_indexer_topk_tensor(selected, scores, n_comp, n_tokens, top_k) && + ds4_gpu_synchronize()) { + elapsed = monotonic_seconds() - t0; + rc = ds4_gpu_tensor_read(selected, 0, selected_host, + (uint64_t)n_tokens * top_k * sizeof(uint32_t)) ? 0 : 1; + } + } + if (rc == 0) { + for (uint32_t t = 0; t < n_tokens && rc == 0; t++) { + for (uint32_t i = 0; i < top_k; i++) { + const uint32_t expected = n_comp - 1u - i; + const uint32_t got = selected_host[(uint64_t)t * top_k + i]; + if (got != expected) { + fprintf(stderr, "top-k mismatch token=%u rank=%u got=%u expected=%u\n", + t, i, got, expected); + rc = 1; + break; + } + } + } + } + if (rc == 0) { + const double max_seconds = getenv_seconds("DS4_ROCM_TOPK_REGRESSION_SEC", 2.0); + fprintf(stderr, "rocm-regression: top-k n_comp=%u n_tokens=%u elapsed=%.3fs\n", + n_comp, n_tokens, elapsed); + if (elapsed > max_seconds) { + fprintf(stderr, "top-k regression: %.3fs exceeds %.3fs\n", elapsed, max_seconds); + rc = 1; + } + } + + ds4_gpu_tensor_free(selected); + ds4_gpu_tensor_free(scores); + free(selected_host); + free(scores_host); + return rc; +} + +static int check_decode_attention_overflow_path(void) { + const uint32_t n_head = 8; + const uint32_t head_dim = 512; + const uint32_t n_raw = 128; + const uint32_t n_comp = 8100; + const uint64_t q_count = (uint64_t)n_head * head_dim; + const uint64_t raw_count = (uint64_t)n_raw * head_dim; + const uint64_t comp_count = (uint64_t)n_comp * head_dim; + + float *sinks = (float *)calloc(n_head, sizeof(float)); + float *q_host = (float *)calloc((size_t)q_count, sizeof(float)); + float *raw_host = (float *)calloc((size_t)raw_count, sizeof(float)); + float *comp_host = (float *)calloc((size_t)comp_count, sizeof(float)); + float *heads_host = (float *)calloc((size_t)q_count, sizeof(float)); + if (!sinks || !q_host || !raw_host || !comp_host || !heads_host) return 1; + + for (uint32_t c = 0; c < n_comp; c++) { + comp_host[(uint64_t)c * head_dim] = 1.0f; + } + + ds4_gpu_tensor *heads = ds4_gpu_tensor_alloc(q_count * sizeof(float)); + ds4_gpu_tensor *q = ds4_gpu_tensor_alloc(q_count * sizeof(float)); + ds4_gpu_tensor *raw = ds4_gpu_tensor_alloc(raw_count * sizeof(float)); + ds4_gpu_tensor *comp = ds4_gpu_tensor_alloc(comp_count * sizeof(float)); + int rc = 1; + if (heads && q && raw && comp && + ds4_gpu_tensor_write(q, 0, q_host, q_count * sizeof(float)) && + ds4_gpu_tensor_write(raw, 0, raw_host, raw_count * sizeof(float)) && + ds4_gpu_tensor_write(comp, 0, comp_host, comp_count * sizeof(float)) && + ds4_gpu_attention_decode_heads_tensor(heads, + sinks, + n_head * sizeof(float), + 0, + q, + raw, + n_raw, + n_raw, + 0, + comp, + n_comp, + NULL, + 0, + n_head, + head_dim) && + ds4_gpu_synchronize() && + ds4_gpu_tensor_read(heads, 0, heads_host, q_count * sizeof(float))) { + rc = 0; + for (uint32_t h = 0; h < n_head; h++) { + const float v = heads_host[(uint64_t)h * head_dim]; + if (v < 0.90f) { + fprintf(stderr, "attention fallback ignored compressed rows for head=%u value=%f\n", + h, (double)v); + rc = 1; + } + } + } + + ds4_gpu_tensor_free(comp); + ds4_gpu_tensor_free(raw); + ds4_gpu_tensor_free(q); + ds4_gpu_tensor_free(heads); + free(heads_host); + free(comp_host); + free(raw_host); + free(q_host); + free(sinks); + return rc; +} + +int main(void) { + if (!ds4_gpu_init()) return 1; + int rc = check_large_topk(); + if (check_decode_attention_overflow_path() != 0) rc = 1; + ds4_gpu_cleanup(); + if (rc == 0) puts("rocm long-context regression: OK"); + return rc; +}