Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,38 @@ struct ggml_cuda_pool_alloc {
ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};

#ifdef GGML_USE_HIP
// Direct alloc/free, avoids legacy pool retention on HIP without VMM
template<typename T>
struct ggml_cuda_direct_alloc {
T * ptr = nullptr;
cudaStream_t stream;

ggml_cuda_direct_alloc() = default;
explicit ggml_cuda_direct_alloc(cudaStream_t s) : stream(s) {}

~ggml_cuda_direct_alloc() {
if (ptr) {
cudaStreamSynchronize(stream);
cudaFree(ptr);
}
}

T * alloc(size_t size) {
GGML_ASSERT(ptr == nullptr);
CUDA_CHECK(cudaMalloc(&ptr, size * sizeof(T)));
return ptr;
}

T * get() { return ptr; }

ggml_cuda_direct_alloc(const ggml_cuda_direct_alloc &) = delete;
ggml_cuda_direct_alloc(ggml_cuda_direct_alloc &&) = delete;
ggml_cuda_direct_alloc& operator=(const ggml_cuda_direct_alloc &) = delete;
ggml_cuda_direct_alloc& operator=(ggml_cuda_direct_alloc &&) = delete;
};
#endif


// backend interface

Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,13 @@ void launch_fattn(
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;

#ifdef GGML_USE_HIP
ggml_cuda_direct_alloc<half> K_f16(main_stream);
ggml_cuda_direct_alloc<half> V_f16(main_stream);
#else
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
#endif
ggml_cuda_pool_alloc<int> KV_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
Expand Down