Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
9193df5
Merge NripeshN/mlx rocm-support into upstream main
Geramy Mar 25, 2026
9fddf1c
Add RDNA 3.5/4 architectures and parallel HIP compilation
Geramy Mar 25, 2026
3ae44dc
Fix parallel-jobs flag: single dash for hipcc/clang
Geramy Mar 25, 2026
2b8a7d1
Limit HIP parallel-jobs to half of available CPUs
Geramy Mar 25, 2026
c2eb919
Add missing gpu::init() and SliceUpdate::eval_gpu stub for ROCm
Geramy Mar 25, 2026
26e733c
Implement ROCm-optimized SliceUpdate::eval_gpu
Geramy Mar 25, 2026
edd89a1
Fix bfloat16/half JIT compilation for ROCm fused kernels
Geramy Mar 25, 2026
1ab4186
Simplify JIT preamble ops: always promote through float
Geramy Mar 25, 2026
d03fa7c
Fix critical bug: JIT KernelArgs passed CPU pointers instead of GPU
Geramy Mar 25, 2026
76741bc
Remove gfx1150/1151/1152/1200/1201 from rocBLAS supported list
Geramy Mar 25, 2026
9336df8
Add rocBLAS fallback to naive_gemm when Tensile kernel missing
Geramy Mar 25, 2026
f92d2d2
Add missing kernel_utils.hpp include for gpu_ptr in rocblas_gemm
Geramy Mar 25, 2026
8acadb4
Probe rocBLAS bf16 GEMM at device init, fallback to naive_gemm
Geramy Mar 25, 2026
bfab6fb
Always use naive_gemm for bfloat16 GEMM on ROCm
Geramy Mar 25, 2026
c8c9c8e
ROCm bug fixes + optimized quantized GEMV kernel
Geramy Mar 26, 2026
2f47aeb
Promote JIT binary ops through float, restore rocBLAS for gfx1151
Geramy Mar 26, 2026
6520667
GatherQMM: ensure contiguous indices, SDPA: add head_dim=256
Geramy Mar 26, 2026
00d8c2e
SDPA GPU decomposition, naive_gemm for all types, GatherQMM contiguou…
Geramy Mar 26, 2026
4a5bb0f
Metal-compatible QMM accumulation, JIT stderr suppression
Geramy Mar 26, 2026
73470d8
Fix GatherQMM memory corruption, add index bounds clamping
Geramy Mar 26, 2026
1e50c74
Kernel audit: match Metal precision across RMSNorm, sort, softmax, ops
Geramy Mar 26, 2026
1793485
Fix batched matmul: missing bfloat16/float16 in loop-based GQA path
Geramy Mar 27, 2026
840d028
Add head_dim=256 dispatch to SDPA vector kernel
Geramy Mar 27, 2026
d30fe29
Merge upstream NripeshN/mlx rocm-support with ROCm optimizations
Geramy Mar 27, 2026
5ffb863
Enable 4-bit fast gather QMV dispatch for MoE decode
Geramy Mar 27, 2026
b1300b9
Optimize ROCm allocator for integrated GPUs (APU)
Geramy Mar 27, 2026
780b4fe
Prefer shared-memory QMV over noshared variant for decode
Geramy Mar 27, 2026
0ec6b45
Add expert-grouped prefill kernel for GatherQMM (3.4x prompt speedup)
Geramy Mar 27, 2026
c9167d2
Allocator: prefer hipExtMallocWithFlags for APU, fallback to hipMallo…
Geramy Mar 27, 2026
a66e273
Add WMMA-accelerated prefill kernel for GatherQMM on RDNA 3/3.5/4
Geramy Mar 27, 2026
e35d6aa
WMMA prefill kernel: support non-aligned M, sort unsorted indices
Geramy Mar 27, 2026
435afdc
Add GPU-only expert-batched gather QMV kernel for low-expert MoE
Geramy Mar 27, 2026
bc4d62f
Add hipBLASLt GEMM integration for bf16/fp16 matmul on ROCm
Geramy Mar 27, 2026
b8b56b1
hipBLASLt: add to QMM dequant+GEMM path for bf16 (2.6x prompt speedup)
Geramy Mar 27, 2026
7ac6efd
hipBLASLt in QMM dequant path + CommandEncoder graph capture API
Geramy Mar 27, 2026
b913c68
Strided copy kernels for ensure_row_contiguous in QMM
Geramy Mar 27, 2026
da1925b
Allocator: power-of-2 rounding for large allocs (>= 1MB)
Geramy Mar 28, 2026
65958fa
Allocator: use system RAM limit for iGPU, power-of-2 rounding for lar…
Geramy Mar 28, 2026
b010eee
Allocator: revert power-of-2 rounding, keep hipExtMallocWithFlags
Geramy Mar 28, 2026
f26c802
Fix CU count comment: 40 CUs (20 WGPs) on gfx1151
Geramy Mar 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions mlx/backend/rocm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ find_package(rocblas REQUIRED CONFIG)
find_package(rocthrust REQUIRED CONFIG)
find_package(rocprim REQUIRED CONFIG)
find_package(hiprand REQUIRED CONFIG)
find_package(rocwmma REQUIRED CONFIG)

# Ensure HIP architectures are set - respect user-provided value from the
# command line. The user can set this via -DCMAKE_HIP_ARCHITECTURES=gfx1011
#
# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: CDNA:
# gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) CDNA4: gfx950 (MI400 series)
# RDNA2: gfx1030 (RX 6000 series) RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600)
# RDNA4: gfx1200, gfx1201 (RX 8000 series)
# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix:
# CDNA: gfx908 (MI100), gfx90a (MI200), gfx942 (MI300)
# RDNA2: gfx1030 (RX 6000 series)
# RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7700/7800), gfx1102 (RX 7600)
# RDNA3.5: gfx1150, gfx1151, gfx1152 (Ryzen AI / Radeon 8060S)
# RDNA4: gfx1200, gfx1201 (RX 9000 series)
if(NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES
"gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102"
"gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1152;gfx1200;gfx1201"
CACHE STRING "HIP architectures" FORCE)
endif()
message(
Expand All @@ -39,6 +42,8 @@ get_target_property(ROCTHRUST_INCLUDES roc::rocthrust
INTERFACE_INCLUDE_DIRECTORIES)
get_target_property(ROCPRIM_INCLUDES roc::rocprim INTERFACE_INCLUDE_DIRECTORIES)
get_target_property(HIPRAND_INCLUDES hip::hiprand INTERFACE_INCLUDE_DIRECTORIES)
get_target_property(ROCWMMA_INCLUDES roc::rocwmma
INTERFACE_INCLUDE_DIRECTORIES)

# Find GCC installation for C++ standard library headers. ROCm's clang needs
# to know where to find libstdc++ headers.
Expand Down Expand Up @@ -101,6 +106,11 @@ foreach(inc ${HIPRAND_INCLUDES})
list(APPEND HIP_INCLUDE_FLAGS "-I${inc}")
endif()
endforeach()
foreach(inc ${ROCWMMA_INCLUDES})
if(inc)
list(APPEND HIP_INCLUDE_FLAGS "-I${inc}")
endif()
endforeach()

message(STATUS "HIP include flags: ${HIP_INCLUDE_FLAGS}")

Expand Down Expand Up @@ -147,6 +157,20 @@ set(HIP_SOURCES
set(HIP_OBJ_DIR "${CMAKE_CURRENT_BINARY_DIR}/hip_objs")
file(MAKE_DIRECTORY ${HIP_OBJ_DIR})

# Use half of the available CPUs for parallel HIP offload compilation within
# each file (Ninja already parallelizes across files, so this avoids
# oversubscription).
include(ProcessorCount)
ProcessorCount(NPROC)
if(NPROC EQUAL 0)
set(NPROC 4)
else()
math(EXPR NPROC "${NPROC} / 2")
if(NPROC LESS 2)
set(NPROC 2)
endif()
endif()

# Compile each HIP file to an object file using custom commands. Use
# -fno-gpu-rdc to avoid needing a device link step.
set(HIP_OBJECTS "")
Expand All @@ -168,6 +192,7 @@ foreach(hip_src ${HIP_SOURCES})
OUTPUT ${hip_obj}
COMMAND ${CMAKE_HIP_COMPILER} -c ${hip_src} -o ${hip_obj} -fPIC
-DMLX_USE_ROCM ${HIP_ARCH_FLAGS} ${HIP_INCLUDE_FLAGS} -std=c++17
-parallel-jobs=${NPROC}
DEPENDS ${hip_src}
COMMENT "Compiling HIP source ${hip_src}"
VERBATIM)
Expand Down Expand Up @@ -211,7 +236,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/hipblaslt_gemm.cpp)

target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)

Expand Down Expand Up @@ -247,16 +273,21 @@ find_library(AMDHIP64_LIB amdhip64 PATHS ${ROCM_PATH}/lib /opt/rocm/lib
find_library(HIPRTC_LIB hiprtc PATHS ${ROCM_PATH}/lib /opt/rocm/lib
/opt/rocm-6.0.0/lib)

# Find hipBLASLt library (optimized GEMM for half-precision)
find_library(HIPBLASLT_LIB hipblaslt PATHS ${ROCM_PATH}/lib /opt/rocm/lib
/opt/rocm-6.0.0/lib)

message(
STATUS
"ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}"
"ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}, hipblaslt=${HIPBLASLT_LIB}"
)

# Link the static library and ROCm libraries to mlx. We link directly to the
# .so files instead of using CMake targets to avoid propagating compile
# options like -x hip.
target_link_libraries(mlx PRIVATE ${HIP_STATIC_LIB} ${AMDHIP64_LIB}
${ROCBLAS_LIB} ${HIPRAND_LIB} ${HIPRTC_LIB})
${ROCBLAS_LIB} ${HIPRAND_LIB} ${HIPRTC_LIB}
${HIPBLASLT_LIB})

# Include ROCm headers for mlx C++ files. Get the HIP include directory from
# the hip package.
Expand Down
78 changes: 51 additions & 27 deletions mlx/backend/rocm/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,26 @@ static bool rocm_available() {
return available == 1;
}

// Check if managed memory is supported on this device
// Check if managed memory (HMM) is supported on this device.
// On integrated GPUs (Strix Halo), HMM is actually fast since there's no
// discrete VRAM — managed memory avoids the overhead of hipExtMallocWithFlags.
static bool managed_memory_supported() {
// Always return false to force the use of hipHostMalloc (GTT RAM).
// hipMallocManaged uses HMM, which causes implicit page migrations and
// significant memory copying between host and device on access.
// Using hipHostMalloc maps pinned host memory directly to the GPU's address space.
return false;
static int supported = -1;
if (supported < 0) {
if (!rocm_available()) {
supported = 0;
} else {
void* test_ptr = nullptr;
hipError_t err = hipMallocManaged(&test_ptr, 64);
if (err == hipSuccess) {
(void)hipFree(test_ptr);
supported = 1;
} else {
supported = 0;
}
}
}
return supported == 1;
}

static bool is_integrated() {
Expand All @@ -64,18 +77,18 @@ inline void* rocm_unified_malloc(size_t size, bool& is_managed) {
void* data = nullptr;
hipError_t err;
if (is_integrated()) {
// Unified memory device (iGPU/APU): CPU and GPU share system RAM.
// Try hipExtMallocWithFlags first (fine-grained coherent, best GPU
// bandwidth). Falls back to hipMallocManaged for large allocations
// that exceed the small device-local VRAM (~2GB).
err = hipExtMallocWithFlags(&data, size, hipDeviceMallocFinegrained);
is_managed = true; // Use is_managed=true to signify hipFree should be used
if (err != hipSuccess) {
err = hipMallocManaged(&data, size);
}
is_managed = true;
} else if (managed_memory_supported()) {
err = hipMallocManaged(&data, size);
is_managed = true;
if (err == hipSuccess) {
int device_count = 0;
(void)hipGetDeviceCount(&device_count);
for (int i = 0; i < device_count; ++i) {
(void)hipMemAdvise(data, size, hipMemAdviseSetAccessedBy, i);
}
}
} else {
err = hipHostMalloc(&data, size, hipHostMallocDefault);
is_managed = false;
Expand Down Expand Up @@ -193,6 +206,14 @@ Buffer RocmAllocator::malloc(size_t size) {
}

// Find available buffer from cache.
// Use aggressive size rounding to maximize cache hit rate:
// - Small (<=8B): scalar pool
// - Medium (<16KB): power-of-2
// - Large (<1MB): 16KB page aligned
// - Very large (>=1MB): power-of-2 (coarser buckets = more cache hits)
// The power-of-2 rounding for large allocations is critical for decode —
// without it, slightly different sizes (e.g., 1.01MB vs 1.02MB) miss the
// cache and trigger hipExtMallocWithFlags at ~7ms each.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
Expand All @@ -219,14 +240,11 @@ Buffer RocmAllocator::malloc(size_t size) {
lock.unlock();
if (!buf) {
if (is_integrated()) {
buf = new RocmBuffer{nullptr, size, false, -1};
hipError_t err = hipExtMallocWithFlags(&buf->data, size, hipDeviceMallocFinegrained);
if (err != hipSuccess) {
delete buf;
std::ostringstream oss;
oss << "hipExtMallocWithFlags failed: " << hipGetErrorString(err) << ".";
throw std::runtime_error(oss.str());
}
// Integrated GPU: allocate unified memory (CPU+GPU accessible).
// device=-1 signals unified memory — no move_to_unified_memory needed.
bool is_managed = false;
void* data = rocm_unified_malloc(size, is_managed);
buf = new RocmBuffer{data, size, is_managed, -1};
} else {
int device = 0;
hipGetDevice(&device);
Expand Down Expand Up @@ -373,12 +391,18 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
// Synchronize all streams before accessing memory from CPU
// This ensures all GPU operations have completed
(void)hipDeviceSynchronize();

auto& cbuf = *static_cast<rocm::RocmBuffer*>(ptr_);
rocm::allocator().move_to_unified_memory(cbuf);

if (cbuf.device == -1) {
// Unified memory (integrated GPU or hipMallocManaged): CPU-accessible.
// hipStreamSynchronize(nullptr) waits for the default stream — lighter
// than hipDeviceSynchronize which waits for ALL streams.
(void)hipStreamSynchronize(nullptr);
} else {
// Discrete GPU VRAM: full sync + migrate to host-accessible memory.
(void)hipDeviceSynchronize();
rocm::allocator().move_to_unified_memory(cbuf);
}
return cbuf.data;
}

Expand Down
Loading
Loading