From 17ade469761aaca5ade456447f76f124e0484bc1 Mon Sep 17 00:00:00 2001
From: TheTom
Date: Mon, 20 Apr 2026 15:39:45 -0500
Subject: [PATCH] hip: direct alloc for FA f16 temp buffers

On HIP without VMM, the legacy pool retains these at peak size, causing
quantized KV to OOM before f16. ggml_cuda_direct_alloc uses raw
hipMalloc/hipFree instead. HIP-only, complements #22155.

Fixes #22107 without performance degradation.

Tested: gfx1100, gfx1200, gfx1201.
---
 ggml/src/ggml-cuda/common.cuh       | 32 +++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/fattn-common.cuh |  5 +++++
 2 files changed, 37 insertions(+)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 3aec1742ee1..9ac9ad70919 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -1153,6 +1153,38 @@ struct ggml_cuda_pool_alloc {
     ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
 };
 
+#ifdef GGML_USE_HIP
+// Direct alloc/free, avoids legacy pool retention on HIP without VMM
+template <typename T>
+struct ggml_cuda_direct_alloc {
+    T * ptr = nullptr;
+    cudaStream_t stream;
+
+    ggml_cuda_direct_alloc() = default;
+    explicit ggml_cuda_direct_alloc(cudaStream_t s) : stream(s) {}
+
+    ~ggml_cuda_direct_alloc() {
+        if (ptr) {
+            cudaStreamSynchronize(stream);
+            cudaFree(ptr);
+        }
+    }
+
+    T * alloc(size_t size) {
+        GGML_ASSERT(ptr == nullptr);
+        CUDA_CHECK(cudaMalloc(&ptr, size * sizeof(T)));
+        return ptr;
+    }
+
+    T * get() { return ptr; }
+
+    ggml_cuda_direct_alloc(const ggml_cuda_direct_alloc &) = delete;
+    ggml_cuda_direct_alloc(ggml_cuda_direct_alloc &&) = delete;
+    ggml_cuda_direct_alloc& operator=(const ggml_cuda_direct_alloc &) = delete;
+    ggml_cuda_direct_alloc& operator=(ggml_cuda_direct_alloc &&) = delete;
+};
+#endif
+
 // backend interface
 
 
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index beeb5238946..0e1ea0c04cc 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -946,8 +946,13 @@ void launch_fattn(
     const int cc = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
 
+#ifdef GGML_USE_HIP
+    ggml_cuda_direct_alloc<half> K_f16(main_stream);
+    ggml_cuda_direct_alloc<half> V_f16(main_stream);
+#else
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
+#endif
     ggml_cuda_pool_alloc<int> KV_max(pool);
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);