diff --git a/mlx/backend/rocm/CMakeLists.txt b/mlx/backend/rocm/CMakeLists.txt index bb66736959..1be84641bb 100644 --- a/mlx/backend/rocm/CMakeLists.txt +++ b/mlx/backend/rocm/CMakeLists.txt @@ -10,17 +10,20 @@ find_package(rocblas REQUIRED CONFIG) find_package(rocthrust REQUIRED CONFIG) find_package(rocprim REQUIRED CONFIG) find_package(hiprand REQUIRED CONFIG) +find_package(rocwmma REQUIRED CONFIG) # Ensure HIP architectures are set - respect user-provided value from command # line The user can set this via -DCMAKE_HIP_ARCHITECTURES=gfx1011 # -# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: CDNA: -# gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) CDNA4: gfx950 (MI400 series) -# RDNA2: gfx1030 (RX 6000 series) RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600) -# RDNA4: gfx1200, gfx1201 (RX 8000 series) +# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: +# CDNA: gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) +# RDNA2: gfx1030 (RX 6000 series) +# RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600) +# RDNA3.5: gfx1150, gfx1151, gfx1152 (Ryzen AI / Radeon 8060S) +# RDNA4: gfx1200, gfx1201 (RX 9000 series) if(NOT CMAKE_HIP_ARCHITECTURES) set(CMAKE_HIP_ARCHITECTURES - "gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102" + "gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1152;gfx1200;gfx1201" CACHE STRING "HIP architectures" FORCE) endif() message( @@ -39,6 +42,8 @@ get_target_property(ROCTHRUST_INCLUDES roc::rocthrust INTERFACE_INCLUDE_DIRECTORIES) get_target_property(ROCPRIM_INCLUDES roc::rocprim INTERFACE_INCLUDE_DIRECTORIES) get_target_property(HIPRAND_INCLUDES hip::hiprand INTERFACE_INCLUDE_DIRECTORIES) +get_target_property(ROCWMMA_INCLUDES roc::rocwmma + INTERFACE_INCLUDE_DIRECTORIES) # Find GCC installation for C++ standard library headers ROCm's clang needs to # know where to find libstdc++ headers @@ -101,6 +106,11 @@ 
foreach(inc ${HIPRAND_INCLUDES}) list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") endif() endforeach() +foreach(inc ${ROCWMMA_INCLUDES}) + if(inc) + list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") + endif() +endforeach() message(STATUS "HIP include flags: ${HIP_INCLUDE_FLAGS}") @@ -147,6 +157,20 @@ set(HIP_SOURCES set(HIP_OBJ_DIR "${CMAKE_CURRENT_BINARY_DIR}/hip_objs") file(MAKE_DIRECTORY ${HIP_OBJ_DIR}) +# Detect CPU count for parallel HIP offload compilation +# Use half of available CPUs for parallel HIP offload compilation per file +# (Ninja already parallelizes across files, so this avoids oversubscription) +include(ProcessorCount) +ProcessorCount(NPROC) +if(NPROC EQUAL 0) + set(NPROC 4) +else() + math(EXPR NPROC "${NPROC} / 2") + if(NPROC LESS 2) + set(NPROC 2) + endif() +endif() + # Compile each HIP file to object file using custom commands Use -fno-gpu-rdc to # avoid needing device link step set(HIP_OBJECTS "") @@ -168,6 +192,7 @@ foreach(hip_src ${HIP_SOURCES}) OUTPUT ${hip_obj} COMMAND ${CMAKE_HIP_COMPILER} -c ${hip_src} -o ${hip_obj} -fPIC -DMLX_USE_ROCM ${HIP_ARCH_FLAGS} ${HIP_INCLUDE_FLAGS} -std=c++17 + -parallel-jobs=${NPROC} DEPENDS ${hip_src} COMMENT "Compiling HIP source ${hip_src}" VERBATIM) @@ -211,7 +236,8 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/hipblaslt_gemm.cpp) target_compile_definitions(mlx PRIVATE MLX_USE_ROCM) @@ -247,16 +273,21 @@ find_library(AMDHIP64_LIB amdhip64 PATHS ${ROCM_PATH}/lib /opt/rocm/lib find_library(HIPRTC_LIB hiprtc PATHS ${ROCM_PATH}/lib /opt/rocm/lib /opt/rocm-6.0.0/lib) +# Find hipBLASLt library (optimized GEMM for half-precision) +find_library(HIPBLASLT_LIB hipblaslt PATHS ${ROCM_PATH}/lib /opt/rocm/lib + /opt/rocm-6.0.0/lib) + message( STATUS - "ROCm libraries: rocblas=${ROCBLAS_LIB}, 
hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}" + "ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}, hipblaslt=${HIPBLASLT_LIB}" ) # Link the static library and ROCm libraries to mlx We link directly to the .so # files instead of using CMake targets to avoid propagating compile options like # -x hip target_link_libraries(mlx PRIVATE ${HIP_STATIC_LIB} ${AMDHIP64_LIB} - ${ROCBLAS_LIB} ${HIPRAND_LIB} ${HIPRTC_LIB}) + ${ROCBLAS_LIB} ${HIPRAND_LIB} ${HIPRTC_LIB} + ${HIPBLASLT_LIB}) # Include ROCm headers for mlx C++ files Get the HIP include directory from the # hip package diff --git a/mlx/backend/rocm/allocator.cpp b/mlx/backend/rocm/allocator.cpp index cd6bb68683..5393faa609 100644 --- a/mlx/backend/rocm/allocator.cpp +++ b/mlx/backend/rocm/allocator.cpp @@ -35,13 +35,26 @@ static bool rocm_available() { return available == 1; } -// Check if managed memory is supported on this device +// Check if managed memory (HMM) is supported on this device. +// On integrated GPUs (Strix Halo), HMM is actually fast since there's no +// discrete VRAM — managed memory avoids the overhead of hipExtMallocWithFlags. static bool managed_memory_supported() { - // Always return false to force the use of hipHostMalloc (GTT RAM). - // hipMallocManaged uses HMM, which causes implicit page migrations and - // significant memory copying between host and device on access. - // Using hipHostMalloc maps pinned host memory directly to the GPU's address space. 
- return false; + static int supported = -1; + if (supported < 0) { + if (!rocm_available()) { + supported = 0; + } else { + void* test_ptr = nullptr; + hipError_t err = hipMallocManaged(&test_ptr, 64); + if (err == hipSuccess) { + (void)hipFree(test_ptr); + supported = 1; + } else { + supported = 0; + } + } + } + return supported == 1; } static bool is_integrated() { @@ -64,18 +77,18 @@ inline void* rocm_unified_malloc(size_t size, bool& is_managed) { void* data = nullptr; hipError_t err; if (is_integrated()) { + // Unified memory device (iGPU/APU): CPU and GPU share system RAM. + // Try hipExtMallocWithFlags first (fine-grained coherent, best GPU + // bandwidth). Falls back to hipMallocManaged for large allocations + // that exceed the small device-local VRAM (~2GB). err = hipExtMallocWithFlags(&data, size, hipDeviceMallocFinegrained); - is_managed = true; // Use is_managed=true to signify hipFree should be used + if (err != hipSuccess) { + err = hipMallocManaged(&data, size); + } + is_managed = true; } else if (managed_memory_supported()) { err = hipMallocManaged(&data, size); is_managed = true; - if (err == hipSuccess) { - int device_count = 0; - (void)hipGetDeviceCount(&device_count); - for (int i = 0; i < device_count; ++i) { - (void)hipMemAdvise(data, size, hipMemAdviseSetAccessedBy, i); - } - } } else { err = hipHostMalloc(&data, size, hipHostMallocDefault); is_managed = false; @@ -193,6 +206,14 @@ Buffer RocmAllocator::malloc(size_t size) { } // Find available buffer from cache. + // Use aggressive size rounding to maximize cache hit rate: + // - Small (<=8B): scalar pool + // - Medium (<16KB): power-of-2 + // - Large (<1MB): 16KB page aligned + // - Very large (>=1MB): power-of-2 (coarser buckets = more cache hits) + // The power-of-2 rounding for large allocations is critical for decode — + // without it, slightly different sizes (e.g., 1.01MB vs 1.02MB) miss the + // cache and trigger hipExtMallocWithFlags at ~7ms each. 
auto orig_size = size; std::unique_lock lock(mutex_); if (size <= small_block_size) { @@ -219,14 +240,11 @@ Buffer RocmAllocator::malloc(size_t size) { lock.unlock(); if (!buf) { if (is_integrated()) { - buf = new RocmBuffer{nullptr, size, false, -1}; - hipError_t err = hipExtMallocWithFlags(&buf->data, size, hipDeviceMallocFinegrained); - if (err != hipSuccess) { - delete buf; - std::ostringstream oss; - oss << "hipExtMallocWithFlags failed: " << hipGetErrorString(err) << "."; - throw std::runtime_error(oss.str()); - } + // Integrated GPU: allocate unified memory (CPU+GPU accessible). + // device=-1 signals unified memory — no move_to_unified_memory needed. + bool is_managed = false; + void* data = rocm_unified_malloc(size, is_managed); + buf = new RocmBuffer{data, size, is_managed, -1}; } else { int device = 0; hipGetDevice(&device); @@ -373,12 +391,18 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - // Synchronize all streams before accessing memory from CPU - // This ensures all GPU operations have completed - (void)hipDeviceSynchronize(); - auto& cbuf = *static_cast(ptr_); - rocm::allocator().move_to_unified_memory(cbuf); + + if (cbuf.device == -1) { + // Unified memory (integrated GPU or hipMallocManaged): CPU-accessible. + // hipStreamSynchronize(nullptr) waits for the default stream — lighter + // than hipDeviceSynchronize which waits for ALL streams. + (void)hipStreamSynchronize(nullptr); + } else { + // Discrete GPU VRAM: full sync + migrate to host-accessible memory. 
+ (void)hipDeviceSynchronize(); + rocm::allocator().move_to_unified_memory(cbuf); + } return cbuf.data; } diff --git a/mlx/backend/rocm/compiled.cpp b/mlx/backend/rocm/compiled.cpp index da9c28b2be..db0b67560e 100644 --- a/mlx/backend/rocm/compiled.cpp +++ b/mlx/backend/rocm/compiled.cpp @@ -442,25 +442,34 @@ __device__ inline __half atan2f(__half y, __half x) { // Include device operations namespace mlx::core::rocm { -// Binary ops +// Binary ops — promote half/bfloat16 through float to avoid precision loss +// that compounds across 28-36 transformer layers in LLM inference. struct Add { template - __device__ T operator()(T x, T y) { return x + y; } + __device__ T operator()(T x, T y) { + return T(static_cast(x) + static_cast(y)); + } }; struct Subtract { template - __device__ T operator()(T x, T y) { return x - y; } + __device__ T operator()(T x, T y) { + return T(static_cast(x) - static_cast(y)); + } }; struct Multiply { template - __device__ T operator()(T x, T y) { return x * y; } + __device__ T operator()(T x, T y) { + return T(static_cast(x) * static_cast(y)); + } }; struct Divide { template - __device__ T operator()(T x, T y) { return x / y; } + __device__ T operator()(T x, T y) { + return T(static_cast(x) / static_cast(y)); + } }; struct Maximum { @@ -475,7 +484,9 @@ struct Minimum { struct Power { template - __device__ T operator()(T base, T exp) { return powf(base, exp); } + __device__ T operator()(T base, T exp) { + return T(powf(static_cast(base), static_cast(exp))); + } }; struct Equal { @@ -520,17 +531,23 @@ struct LogicalOr { struct ArcTan2 { template - __device__ T operator()(T y, T x) { return atan2f(y, x); } + __device__ T operator()(T y, T x) { + return T(atan2f(static_cast(y), static_cast(x))); + } }; struct Remainder { template - __device__ T operator()(T x, T y) { return fmodf(x, y); } + __device__ T operator()(T x, T y) { + return T(fmodf(static_cast(x), static_cast(y))); + } }; struct FloorDivide { template - __device__ T operator()(T x, T 
y) { return truncf(x / y); } + __device__ T operator()(T x, T y) { + return T(truncf(static_cast(x) / static_cast(y))); + } }; struct LogAddExp { @@ -552,9 +569,11 @@ struct LogAddExp { template __device__ T operator()(T x, T y) { - T maxval = x > y ? x : y; - T minval = x > y ? y : x; - return static_cast(maxval + log1pf(expf(minval - maxval))); + float fx = static_cast(x); + float fy = static_cast(y); + float maxval = fx > fy ? fx : fy; + float minval = fx > fy ? fy : fx; + return T(maxval + log1pf(expf(minval - maxval))); } }; @@ -583,26 +602,21 @@ struct RightShift { __device__ T operator()(T x, T y) { return x >> y; } }; -// Unary ops -struct Abs { - template - __device__ T operator()(T x) { return abs(x); } -}; - -struct Exp { - template - __device__ T operator()(T x) { return exp(x); } -}; - -struct Log { - template - __device__ T operator()(T x) { return log(x); } +// All unary math ops promote through float to support half/bfloat16. +// For float inputs the static_cast is a no-op. +#define UNARY_FLOAT_OP(name, op) \ +struct name { \ + template \ + __device__ T operator()(T x) { \ + return T(op(static_cast(x))); \ + } \ }; -struct Sqrt { - template - __device__ T operator()(T x) { return sqrt(x); } -}; +// Unary ops +UNARY_FLOAT_OP(Abs, fabsf) +UNARY_FLOAT_OP(Exp, expf) +UNARY_FLOAT_OP(Log, logf) +UNARY_FLOAT_OP(Sqrt, sqrtf) struct Negative { template @@ -611,7 +625,10 @@ struct Negative { struct Square { template - __device__ T operator()(T x) { return x * x; } + __device__ T operator()(T x) { + float fx = static_cast(x); + return T(fx * fx); + } }; struct Sigmoid { @@ -629,125 +646,43 @@ struct Sigmoid { template __device__ T operator()(T x) { - T y = T(1) / (T(1) + exp(-abs(x))); - return (x < T(0)) ? (T(1) - y) : y; + float fx = static_cast(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return T((fx < 0.0f) ? 
1.0f - y : y); } }; -struct Tanh { - template - __device__ T operator()(T x) { return tanh(x); } -}; - -struct Sin { - template - __device__ T operator()(T x) { return sin(x); } -}; - -struct Cos { - template - __device__ T operator()(T x) { return cos(x); } -}; - -struct Tan { - template - __device__ T operator()(T x) { return tan(x); } -}; - -struct Sinh { - template - __device__ T operator()(T x) { return sinh(x); } -}; - -struct Cosh { - template - __device__ T operator()(T x) { return cosh(x); } -}; - -struct Erf { - template - __device__ T operator()(T x) { return erff(x); } -}; - -struct ErfInv { - template - __device__ T operator()(T x) { return erfinvf(x); } -}; - -struct Expm1 { - template - __device__ T operator()(T x) { return expm1f(x); } -}; - -struct Log1p { - template - __device__ T operator()(T x) { return log1pf(x); } -}; - -struct Log2 { - template - __device__ T operator()(T x) { return log2(x); } -}; - -struct Log10 { - template - __device__ T operator()(T x) { return log10(x); } -}; - -struct Ceil { - template - __device__ T operator()(T x) { return ceil(x); } -}; - -struct Floor { - template - __device__ T operator()(T x) { return floor(x); } -}; - -struct Round { - template - __device__ T operator()(T x) { return rint(x); } -}; - -struct Rsqrt { - template - __device__ T operator()(T x) { return rsqrt(x); } -}; +UNARY_FLOAT_OP(Tanh, tanhf) +UNARY_FLOAT_OP(Sin, sinf) +UNARY_FLOAT_OP(Cos, cosf) +UNARY_FLOAT_OP(Tan, tanf) +UNARY_FLOAT_OP(Sinh, sinhf) +UNARY_FLOAT_OP(Cosh, coshf) +UNARY_FLOAT_OP(Erf, erff) +UNARY_FLOAT_OP(ErfInv, erfinvf) +UNARY_FLOAT_OP(Expm1, expm1f) +UNARY_FLOAT_OP(Log1p, log1pf) +UNARY_FLOAT_OP(Log2, log2f) +UNARY_FLOAT_OP(Log10, log10f) +UNARY_FLOAT_OP(Ceil, ceilf) +UNARY_FLOAT_OP(Floor, floorf) +UNARY_FLOAT_OP(Round, rintf) +UNARY_FLOAT_OP(Rsqrt, rsqrtf) struct Sign { template - __device__ T operator()(T x) { return T((x > T(0)) - (x < T(0))); } -}; - -struct Asin { - template - __device__ T operator()(T x) { return 
asin(x); } -}; - -struct Acos { - template - __device__ T operator()(T x) { return acos(x); } -}; - -struct Atan { - template - __device__ T operator()(T x) { return atan(x); } -}; - -struct Asinh { - template - __device__ T operator()(T x) { return asinh(x); } -}; - -struct Acosh { - template - __device__ T operator()(T x) { return acosh(x); } + __device__ T operator()(T x) { + float fx = static_cast(x); + return T((fx > 0.0f) - (fx < 0.0f)); + } }; -struct Atanh { - template - __device__ T operator()(T x) { return atanh(x); } -}; +UNARY_FLOAT_OP(Asin, asinf) +UNARY_FLOAT_OP(Acos, acosf) +UNARY_FLOAT_OP(Atan, atanf) +UNARY_FLOAT_OP(Asinh, asinhf) +UNARY_FLOAT_OP(Acosh, acoshf) +UNARY_FLOAT_OP(Atanh, atanhf) struct LogicalNot { template @@ -759,9 +694,11 @@ struct BitwiseNot { __device__ T operator()(T x) { return ~x; } }; +#undef UNARY_FLOAT_OP + struct Reciprocal { template - __device__ T operator()(T x) { return T(1) / x; } + __device__ T operator()(T x) { return T(1.0f / static_cast(x)); } }; // Ternary ops diff --git a/mlx/backend/rocm/device.cpp b/mlx/backend/rocm/device.cpp index 9254b6ba18..de9f1c89a9 100644 --- a/mlx/backend/rocm/device.cpp +++ b/mlx/backend/rocm/device.cpp @@ -42,8 +42,7 @@ rocblas_handle Device::get_rocblas_handle() { std::string arch_name = props.gcnArchName; // List of architectures supported by rocBLAS (based on TensileLibrary - // files) These are the architectures that have TensileLibrary_lazy_*.dat - // files + // files). These are the architectures that have TensileLibrary_lazy_*.dat. static const std::vector supported_archs = { "gfx908", "gfx90a", @@ -55,6 +54,7 @@ rocblas_handle Device::get_rocblas_handle() { "gfx1102", "gfx1150", "gfx1151", + "gfx1152", "gfx1200", "gfx1201"}; @@ -105,16 +105,90 @@ rocblas_handle Device::get_rocblas_handle() { bool Device::is_rocblas_available() { if (!rocblas_initialized_) { - // Trigger initialization to check availability try { get_rocblas_handle(); } catch (...) 
{ - // Ignore exception, rocblas_available_ is already set } } return rocblas_available_; } +bool Device::is_rocblas_bf16_available() { + if (!rocblas_bf16_probed_) { + rocblas_bf16_probed_ = true; + rocblas_bf16_available_ = false; + + if (!is_rocblas_available()) { + return false; + } + + // Probe: run a tiny bf16 GEMM and check if the GPU survives. + // rocBLAS may claim support but crash if the Tensile .co files + // are corrupt or missing specific kernel variants. + make_current(); + void* a_ptr = nullptr; + void* b_ptr = nullptr; + void* c_ptr = nullptr; + hipError_t err; + + err = hipMalloc(&a_ptr, 4 * 4 * 2); // 4x4 bf16 + if (err != hipSuccess) return false; + err = hipMalloc(&b_ptr, 4 * 4 * 2); + if (err != hipSuccess) { hipFree(a_ptr); return false; } + err = hipMalloc(&c_ptr, 4 * 4 * 2); + if (err != hipSuccess) { hipFree(a_ptr); hipFree(b_ptr); return false; } + + (void)hipMemset(a_ptr, 0, 4 * 4 * 2); + (void)hipMemset(b_ptr, 0, 4 * 4 * 2); + (void)hipMemset(c_ptr, 0, 4 * 4 * 2); + + float alpha = 1.0f, beta = 0.0f; + rocblas_status status = rocblas_gemm_ex( + rocblas_, + rocblas_operation_none, + rocblas_operation_none, + 4, 4, 4, + &alpha, + a_ptr, rocblas_datatype_bf16_r, 4, + b_ptr, rocblas_datatype_bf16_r, 4, + &beta, + c_ptr, rocblas_datatype_bf16_r, 4, + c_ptr, rocblas_datatype_bf16_r, 4, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, 0); + + // Sync and check if the GPU is still alive + hipError_t sync_err = hipDeviceSynchronize(); + // Clear any lingering error + (void)hipGetLastError(); + + hipFree(a_ptr); + hipFree(b_ptr); + hipFree(c_ptr); + + if (status == rocblas_status_success && sync_err == hipSuccess) { + rocblas_bf16_available_ = true; + } else { + // GPU may be in a bad state — need to reset + (void)hipDeviceReset(); + // Re-initialize device + make_current(); + // Re-create rocBLAS handle + if (rocblas_) { + rocblas_destroy_handle(rocblas_); + rocblas_ = nullptr; + } + rocblas_status rs = 
rocblas_create_handle(&rocblas_); + if (rs != rocblas_status_success) { + rocblas_available_ = false; + } + std::cerr << "Warning: rocBLAS bfloat16 GEMM probe failed on this GPU. " + << "Using fallback kernels for bf16 matmul." << std::endl; + } + } + return rocblas_bf16_available_; +} + void Device::make_current() { // We need to set/get current HIP device very frequently, cache it to reduce // actual calls of HIP APIs. This function assumes single-thread in host. @@ -193,6 +267,59 @@ void CommandEncoder::synchronize() { f.wait(); } +void CommandEncoder::begin_capture() { + if (capturing_) return; + device_.make_current(); + // hipStreamBeginCapture records all subsequent operations on this stream + // into a graph instead of executing them. + hipError_t err = hipStreamBeginCapture(stream_, hipStreamCaptureModeGlobal); + if (err == hipSuccess) { + capturing_ = true; + } +} + +bool CommandEncoder::end_capture() { + if (!capturing_) return false; + capturing_ = false; + + hipGraph_t new_graph = nullptr; + hipError_t err = hipStreamEndCapture(stream_, &new_graph); + if (err != hipSuccess || new_graph == nullptr) { + return false; + } + + // Destroy previous graph if any + reset_graph(); + + graph_ = new_graph; + err = hipGraphInstantiate(&graph_exec_, graph_, nullptr, nullptr, 0); + if (err != hipSuccess) { + hipGraphDestroy(graph_); + graph_ = nullptr; + graph_exec_ = nullptr; + return false; + } + return true; +} + +bool CommandEncoder::replay() { + if (!graph_exec_) return false; + device_.make_current(); + hipError_t err = hipGraphLaunch(graph_exec_, stream_); + return err == hipSuccess; +} + +void CommandEncoder::reset_graph() { + if (graph_exec_) { + hipGraphExecDestroy(graph_exec_); + graph_exec_ = nullptr; + } + if (graph_) { + hipGraphDestroy(graph_); + graph_ = nullptr; + } +} + Device& device(mlx::core::Device device) { static std::unordered_map devices; static bool flags_set = false; diff --git a/mlx/backend/rocm/device.h b/mlx/backend/rocm/device.h index 
1e75eeb963..de40f793a6 100644 --- a/mlx/backend/rocm/device.h +++ b/mlx/backend/rocm/device.h @@ -58,6 +58,25 @@ class CommandEncoder { // Wait until kernels and completion handlers are finished void synchronize(); + // --- Graph capture API --- + // Begin recording all kernel launches into a HIP graph. + // While capturing, launch_kernel dispatches are recorded (not executed). + void begin_capture(); + + // End recording and instantiate the captured graph. + // Returns true if capture succeeded (graph is ready to replay). + bool end_capture(); + + // Replay the previously captured graph. All recorded kernels execute + // in a single GPU dispatch. Returns false if no graph is available. + bool replay(); + + // Returns true if a captured graph is ready to replay. + bool has_graph() const { return graph_exec_ != nullptr; } + + // Discard the captured graph. + void reset_graph(); + private: Device& device_; HipStream stream_; @@ -65,6 +84,9 @@ class CommandEncoder { int node_count_{0}; std::vector> temporaries_; std::unordered_set temporary_ptrs_; + bool capturing_{false}; + hipGraph_t graph_{nullptr}; + hipGraphExec_t graph_exec_{nullptr}; }; class Device { @@ -90,12 +112,17 @@ class Device { // Check if rocBLAS is available for the current GPU architecture bool is_rocblas_available(); + // Check if rocBLAS bf16 GEMM works on this device (probed at init) + bool is_rocblas_bf16_available(); + private: int device_; rocblas_handle rocblas_{nullptr}; hipStream_t rocblas_stream_{nullptr}; bool rocblas_initialized_{false}; bool rocblas_available_{true}; + bool rocblas_bf16_probed_{false}; + bool rocblas_bf16_available_{false}; std::unordered_map> encoders_; }; @@ -114,6 +141,9 @@ inline auto thrust_policy(hipStream_t stream) { template void CommandEncoder::launch_kernel(F&& func) { device_.make_current(); + // When capturing, kernel launches are recorded into the HIP graph + // automatically via hipStreamBeginCapture. 
No special handling needed — + // hipLaunchKernel on a capturing stream records instead of executing. func(static_cast(stream_)); node_count_++; } diff --git a/mlx/backend/rocm/device/binary_ops.hpp b/mlx/backend/rocm/device/binary_ops.hpp index f07f3a7cb4..59dd1c8e69 100644 --- a/mlx/backend/rocm/device/binary_ops.hpp +++ b/mlx/backend/rocm/device/binary_ops.hpp @@ -13,6 +13,10 @@ struct Add { __device__ T operator()(T x, T y) { if constexpr (is_complex_v) { return hipCaddf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) + static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) + __half2float(y)); } else { return x + y; } @@ -40,6 +44,10 @@ struct Divide { __device__ T operator()(T x, T y) { if constexpr (is_complex_v) { return hipCdivf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) / static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) / __half2float(y)); } else { return x / y; } @@ -289,6 +297,10 @@ struct Multiply { __device__ T operator()(T x, T y) { if constexpr (is_complex_v) { return hipCmulf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) * static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) * __half2float(y)); } else { return x * y; } @@ -350,6 +362,10 @@ struct Subtract { __device__ T operator()(T x, T y) { if constexpr (is_complex_v) { return hipCsubf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) - static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) - __half2float(y)); } else { return x - y; } diff --git a/mlx/backend/rocm/device/unary_ops.hpp b/mlx/backend/rocm/device/unary_ops.hpp index 04e677f201..3b31c75303 100644 --- a/mlx/backend/rocm/device/unary_ops.hpp +++ b/mlx/backend/rocm/device/unary_ops.hpp @@ -38,6 +38,8 @@ struct ArcCos { 
return ::acosf(x); } else if constexpr (std::is_same_v) { return ::acos(x); + } else if constexpr (std::is_same_v) { + return __float2half(acosf(__half2float(x))); } else { return acos(x); } @@ -51,6 +53,8 @@ struct ArcCosh { return ::acoshf(x); } else if constexpr (std::is_same_v) { return ::acosh(x); + } else if constexpr (std::is_same_v) { + return __float2half(acoshf(__half2float(x))); } else { return acosh(x); } @@ -64,6 +68,8 @@ struct ArcSin { return ::asinf(x); } else if constexpr (std::is_same_v) { return ::asin(x); + } else if constexpr (std::is_same_v) { + return __float2half(asinf(__half2float(x))); } else { return asin(x); } @@ -77,6 +83,8 @@ struct ArcSinh { return ::asinhf(x); } else if constexpr (std::is_same_v) { return ::asinh(x); + } else if constexpr (std::is_same_v) { + return __float2half(asinhf(__half2float(x))); } else { return asinh(x); } @@ -90,6 +98,8 @@ struct ArcTan { return ::atanf(x); } else if constexpr (std::is_same_v) { return ::atan(x); + } else if constexpr (std::is_same_v) { + return __float2half(atanf(__half2float(x))); } else { return atan(x); } @@ -103,6 +113,8 @@ struct ArcTanh { return ::atanhf(x); } else if constexpr (std::is_same_v) { return ::atanh(x); + } else if constexpr (std::is_same_v) { + return __float2half(atanhf(__half2float(x))); } else { return atanh(x); } @@ -157,6 +169,8 @@ struct Cos { return cosf(x); } else if constexpr (std::is_same_v) { return ::cos(x); + } else if constexpr (std::is_same_v) { + return __float2half(cosf(__half2float(x))); } else { return cos(x); } @@ -170,6 +184,8 @@ struct Cosh { return ::coshf(x); } else if constexpr (std::is_same_v) { return ::cosh(x); + } else if constexpr (std::is_same_v) { + return __float2half(coshf(__half2float(x))); } else { return cosh(x); } @@ -213,6 +229,8 @@ struct Exp { return expf(x); } else if constexpr (std::is_same_v) { return ::exp(x); + } else if constexpr (std::is_same_v) { + return __float2half(expf(__half2float(x))); } else { return exp(x); } @@ 
-270,6 +288,8 @@ struct Log { return logf(x); } else if constexpr (std::is_same_v) { return ::log(x); + } else if constexpr (std::is_same_v) { + return __float2half(logf(__half2float(x))); } else { return log(x); } @@ -287,6 +307,8 @@ struct Log2 { return ::log2f(x); } else if constexpr (std::is_same_v) { return ::log2(x); + } else if constexpr (std::is_same_v) { + return __float2half(log2f(__half2float(x))); } else { return log2(x); } @@ -300,6 +322,8 @@ struct Log10 { return ::log10f(x); } else if constexpr (std::is_same_v) { return ::log10(x); + } else if constexpr (std::is_same_v) { + return __float2half(log10f(__half2float(x))); } else { return log10(x); } @@ -427,6 +451,8 @@ struct Sin { return sinf(x); } else if constexpr (std::is_same_v) { return ::sin(x); + } else if constexpr (std::is_same_v) { + return __float2half(sinf(__half2float(x))); } else { return sin(x); } @@ -440,6 +466,8 @@ struct Sinh { return ::sinhf(x); } else if constexpr (std::is_same_v) { return ::sinh(x); + } else if constexpr (std::is_same_v) { + return __float2half(sinhf(__half2float(x))); } else { return sinh(x); } @@ -451,6 +479,12 @@ struct Square { __device__ T operator()(T x) { if constexpr (is_complex_v) { return hipCmulf(x, x); + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + return hip_bfloat16(fx * fx); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + return __float2half(fx * fx); } else { return x * x; } @@ -464,6 +498,8 @@ struct Sqrt { return ::sqrtf(x); } else if constexpr (std::is_same_v) { return ::sqrt(x); + } else if constexpr (std::is_same_v) { + return __float2half(sqrtf(__half2float(x))); } else { return sqrt(x); } @@ -479,6 +515,8 @@ struct Rsqrt { return ::rsqrtf(x); } else if constexpr (std::is_same_v) { return ::rsqrt(x); + } else if constexpr (std::is_same_v) { + return __float2half(rsqrtf(__half2float(x))); } else { return rsqrt(x); } @@ -492,6 +530,8 @@ struct Tan { return ::tanf(x); } else if constexpr 
(std::is_same_v) { return ::tan(x); + } else if constexpr (std::is_same_v) { + return __float2half(tanf(__half2float(x))); } else { return tan(x); } @@ -505,6 +545,8 @@ struct Tanh { return ::tanhf(x); } else if constexpr (std::is_same_v) { return ::tanh(x); + } else if constexpr (std::is_same_v) { + return __float2half(tanhf(__half2float(x))); } else { return tanh(x); } diff --git a/mlx/backend/rocm/eval.cpp b/mlx/backend/rocm/eval.cpp index 2f526ca9de..825941fa20 100644 --- a/mlx/backend/rocm/eval.cpp +++ b/mlx/backend/rocm/eval.cpp @@ -6,8 +6,15 @@ #include "mlx/backend/rocm/event.h" #include "mlx/primitives.h" +#include + namespace mlx::core::gpu { +void init() { + // Force initialization of ROCm runtime + hipFree(nullptr); +} + void new_stream(Stream s) { // Force initialization of ROCm by creating an event, so the HIP runtime and // our HIP event pool get destroyed last. diff --git a/mlx/backend/rocm/gemms/hipblaslt_gemm.cpp b/mlx/backend/rocm/gemms/hipblaslt_gemm.cpp new file mode 100644 index 0000000000..935128ec60 --- /dev/null +++ b/mlx/backend/rocm/gemms/hipblaslt_gemm.cpp @@ -0,0 +1,548 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/gemms/hipblaslt_gemm.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" + +#include +#include + +#include +#include +#include + +namespace mlx::core::rocm { + +namespace { + +// Maximum workspace size for hipBLASLt algorithms (32 MB). +// hipBLASLt may request scratch memory for certain algorithm choices. +constexpr size_t kMaxWorkspaceBytes = 32u * 1024u * 1024u; + +// Per-device hipBLASLt handle cache. Lazily initialised, thread-safe. +struct HipblasltState { + hipblasLtHandle_t handle{nullptr}; + bool initialized{false}; + bool available{false}; + std::mutex mutex; + + // Persistent workspace allocation (grown as needed, never shrunk). + void* workspace{nullptr}; + size_t workspace_size{0}; +}; + +// One state per device (indexed by HIP device ordinal). 
+// 16 devices should be more than enough for any system. +static constexpr int kMaxDevices = 16; +static HipblasltState g_state[kMaxDevices]; + +HipblasltState& get_state(int device_id) { + if (device_id < 0 || device_id >= kMaxDevices) { + throw std::runtime_error( + "hipBLASLt: device id out of range: " + std::to_string(device_id)); + } + return g_state[device_id]; +} + +// Initialise the hipBLASLt handle for the given device. +// Must be called with state.mutex held. +void init_handle(HipblasltState& state, int device_id) { + if (state.initialized) { + return; + } + state.initialized = true; + + hipblasStatus_t status = hipblasLtCreate(&state.handle); + if (status != HIPBLAS_STATUS_SUCCESS) { + state.available = false; + state.handle = nullptr; + std::cerr << "Warning: hipBLASLt initialization failed (status " + << static_cast(status) << ")." << std::endl; + return; + } + state.available = true; +} + +hipblasLtHandle_t get_handle(int device_id) { + auto& state = get_state(device_id); + if (!state.initialized) { + std::lock_guard lock(state.mutex); + init_handle(state, device_id); + } + if (!state.available) { + throw std::runtime_error("hipBLASLt is not available on this device."); + } + return state.handle; +} + +// Ensure the per-device workspace is at least `required` bytes. +// Returns the workspace pointer and the actual allocated size. +// Must be called from within a launch_kernel callback (i.e., on the +// stream-submission thread for this device), so no extra locking is needed +// beyond the device serialisation that CommandEncoder already provides. +std::pair ensure_workspace(int device_id, size_t required) { + auto& state = get_state(device_id); + if (required <= state.workspace_size && state.workspace != nullptr) { + return {state.workspace, state.workspace_size}; + } + // Free old allocation (hipFree is a no-op on nullptr). 
+ if (state.workspace) { + (void)hipFree(state.workspace); + state.workspace = nullptr; + state.workspace_size = 0; + } + if (required == 0) { + return {nullptr, 0}; + } + hipError_t err = hipMalloc(&state.workspace, required); + if (err != hipSuccess) { + state.workspace = nullptr; + state.workspace_size = 0; + return {nullptr, 0}; + } + state.workspace_size = required; + return {state.workspace, state.workspace_size}; +} + +hipDataType to_hipblaslt_dtype(Dtype dtype) { + switch (dtype) { + case float32: + return HIP_R_32F; + case float16: + return HIP_R_16F; + case bfloat16: + return HIP_R_16BF; + default: + throw std::runtime_error("Unsupported dtype for hipBLASLt GEMM"); + } +} + +hipblasOperation_t to_hipblas_op(bool transpose) { + return transpose ? HIPBLAS_OP_T : HIPBLAS_OP_N; +} + +// RAII wrappers for hipBLASLt descriptors to avoid leaks on error paths. +struct MatmulDescGuard { + hipblasLtMatmulDesc_t desc{nullptr}; + ~MatmulDescGuard() { + if (desc) + hipblasLtMatmulDescDestroy(desc); + } +}; +struct MatrixLayoutGuard { + hipblasLtMatrixLayout_t layout{nullptr}; + ~MatrixLayoutGuard() { + if (layout) + hipblasLtMatrixLayoutDestroy(layout); + } +}; +struct PreferenceGuard { + hipblasLtMatmulPreference_t pref{nullptr}; + ~PreferenceGuard() { + if (pref) + hipblasLtMatmulPreferenceDestroy(pref); + } +}; + +// Core implementation: set up descriptors, find the best algorithm, and +// execute the matmul on the given stream. +void hipblaslt_gemm_impl( + hipblasLtHandle_t handle, + int device_id, + hipblasOperation_t op_a, + hipblasOperation_t op_b, + int M, + int N, + int K, + const float* alpha, + const void* a_ptr, + int lda, + int64_t stride_a, + const void* b_ptr, + int ldb, + int64_t stride_b, + const float* beta, + void* c_ptr, + int ldc, + int64_t stride_c, + int batch_count, + hipDataType data_type, + hipStream_t stream) { + hipblasStatus_t status; + + // Compute type: always fp32 accumulation for half-precision inputs. 
+ hipblasComputeType_t compute_type = HIPBLAS_COMPUTE_32F; + hipDataType scale_type = HIP_R_32F; + + // --- Matmul descriptor --- + MatmulDescGuard matmul_guard; + status = + hipblasLtMatmulDescCreate(&matmul_guard.desc, compute_type, scale_type); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatmulDescCreate failed: " + + std::to_string(static_cast(status))); + } + + // Set transpose attributes. + int32_t trans_a_val = static_cast(op_a); + int32_t trans_b_val = static_cast(op_b); + hipblasLtMatmulDescSetAttribute( + matmul_guard.desc, + HIPBLASLT_MATMUL_DESC_TRANSA, + &trans_a_val, + sizeof(trans_a_val)); + hipblasLtMatmulDescSetAttribute( + matmul_guard.desc, + HIPBLASLT_MATMUL_DESC_TRANSB, + &trans_b_val, + sizeof(trans_b_val)); + + // --- Matrix layouts (column-major, as expected by BLAS) --- + // A is (op_a == N) ? M x K : K x M in column-major + // B is (op_b == N) ? K x N : N x K in column-major + // C is M x N in column-major + uint64_t a_rows = (op_a == HIPBLAS_OP_N) ? M : K; + uint64_t a_cols = (op_a == HIPBLAS_OP_N) ? K : M; + uint64_t b_rows = (op_b == HIPBLAS_OP_N) ? K : N; + uint64_t b_cols = (op_b == HIPBLAS_OP_N) ? 
N : K; + + MatrixLayoutGuard layout_a, layout_b, layout_c, layout_d; + + status = hipblasLtMatrixLayoutCreate( + &layout_a.layout, data_type, a_rows, a_cols, lda); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(A) failed: " + + std::to_string(static_cast(status))); + } + + status = hipblasLtMatrixLayoutCreate( + &layout_b.layout, data_type, b_rows, b_cols, ldb); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(B) failed: " + + std::to_string(static_cast(status))); + } + + status = hipblasLtMatrixLayoutCreate( + &layout_c.layout, data_type, M, N, ldc); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(C) failed: " + + std::to_string(static_cast(status))); + } + + // D has the same layout as C (in-place: D == C). + status = hipblasLtMatrixLayoutCreate( + &layout_d.layout, data_type, M, N, ldc); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(D) failed: " + + std::to_string(static_cast(status))); + } + + // Set batch attributes when doing strided batched GEMM. 
+ if (batch_count > 1) { + int32_t bc = batch_count; + hipblasLtMatrixLayoutSetAttribute( + layout_a.layout, + HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &bc, + sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_a.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_a, + sizeof(stride_a)); + + hipblasLtMatrixLayoutSetAttribute( + layout_b.layout, + HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &bc, + sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_b.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_b, + sizeof(stride_b)); + + hipblasLtMatrixLayoutSetAttribute( + layout_c.layout, + HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &bc, + sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_c.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_c, + sizeof(stride_c)); + + hipblasLtMatrixLayoutSetAttribute( + layout_d.layout, + HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &bc, + sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_d.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_c, + sizeof(stride_c)); + } + + // --- Algorithm selection via heuristic --- + PreferenceGuard pref_guard; + status = hipblasLtMatmulPreferenceCreate(&pref_guard.pref); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatmulPreferenceCreate failed: " + + std::to_string(static_cast(status))); + } + + uint64_t max_ws = kMaxWorkspaceBytes; + hipblasLtMatmulPreferenceSetAttribute( + pref_guard.pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_ws, + sizeof(max_ws)); + + hipblasLtMatmulHeuristicResult_t heuristic; + int returned_algo_count = 0; + + status = hipblasLtMatmulAlgoGetHeuristic( + handle, + matmul_guard.desc, + layout_a.layout, + layout_b.layout, + layout_c.layout, + layout_d.layout, + pref_guard.pref, + 1, // requestedAlgoCount + &heuristic, + &returned_algo_count); + + if (status != HIPBLAS_STATUS_SUCCESS || returned_algo_count == 0) { + throw std::runtime_error( + 
"hipblasLtMatmulAlgoGetHeuristic failed (status=" + + std::to_string(static_cast(status)) + + ", returned=" + std::to_string(returned_algo_count) + ")"); + } + + // --- Workspace allocation --- + size_t ws_needed = heuristic.workspaceSize; + void* ws_ptr = nullptr; + size_t ws_actual = 0; + if (ws_needed > 0) { + auto [p, s] = ensure_workspace(device_id, ws_needed); + ws_ptr = p; + ws_actual = s; + if (ws_ptr == nullptr && ws_needed > 0) { + throw std::runtime_error( + "hipBLASLt: failed to allocate workspace of " + + std::to_string(ws_needed) + " bytes"); + } + } + + // --- Execute the matmul --- + status = hipblasLtMatmul( + handle, + matmul_guard.desc, + alpha, + a_ptr, + layout_a.layout, + b_ptr, + layout_b.layout, + beta, + c_ptr, + layout_c.layout, + c_ptr, // D == C (in-place) + layout_d.layout, + &heuristic.algo, + ws_ptr, + ws_actual, + stream); + + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatmul failed: " + + std::to_string(static_cast(status))); + } +} + +} // namespace + +bool is_hipblaslt_available() { + int device_id = 0; + (void)hipGetDevice(&device_id); + auto& state = get_state(device_id); + if (!state.initialized) { + std::lock_guard lock(state.mutex); + init_handle(state, device_id); + } + return state.available; +} + +void hipblaslt_gemm( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype) { + int device_id = encoder.device().hip_device(); + hipblasLtHandle_t handle = get_handle(device_id); + hipDataType hip_dtype = to_hipblaslt_dtype(dtype); + + // hipBLASLt uses column-major layout. MLX stores row-major, so we swap A + // and B and compute C^T = B^T * A^T, just like the rocBLAS path. 
+ hipblasOperation_t op_a = to_hipblas_op(transpose_b); + hipblasOperation_t op_b = to_hipblas_op(transpose_a); + + static bool dbg = []{ + fprintf(stderr, "[hipBLASLt] first call\n"); + return true; + }(); + (void)dbg; + fprintf(stderr, "[hipBLASLt] M=%d N=%d K=%d ta=%d tb=%d lda=%d ldb=%d ldc=%d\n", + M, N, K, (int)transpose_a, (int)transpose_b, lda, ldb, ldc); + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + encoder.launch_kernel( + [=, &encoder](hipStream_t stream) { + hipblaslt_gemm_impl( + handle, + device_id, + op_a, + op_b, + N, // swap M/N for col-major trick + M, + K, + &alpha, + b_ptr, // swap A/B + ldb, + 0, // stride_a (unused for non-batched) + a_ptr, + lda, + 0, // stride_b (unused for non-batched) + &beta, + c_ptr, + ldc, + 0, // stride_c (unused for non-batched) + 1, // batch_count + hip_dtype, + stream); + }); +} + +void hipblaslt_gemm_batched( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype) { + int device_id = encoder.device().hip_device(); + hipblasLtHandle_t handle = get_handle(device_id); + hipDataType hip_dtype = to_hipblaslt_dtype(dtype); + + // Same column-major swap as above. 
+ hipblasOperation_t op_a = to_hipblas_op(transpose_b); + hipblasOperation_t op_b = to_hipblas_op(transpose_a); + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + encoder.launch_kernel( + [=, &encoder](hipStream_t stream) { + hipblaslt_gemm_impl( + handle, + device_id, + op_a, + op_b, + N, + M, + K, + &alpha, + b_ptr, + ldb, + stride_b, // swapped: was b, now is "A" in col-major + a_ptr, + lda, + stride_a, // swapped: was a, now is "B" in col-major + &beta, + c_ptr, + ldc, + stride_c, + batch_count, + hip_dtype, + stream); + }); +} + +void hipblaslt_gemm_raw( + hipStream_t stream, + int op_a, + int op_b, + int M, int N, int K, + const float* alpha, + const void* a_ptr, int lda, + const void* b_ptr, int ldb, + const float* beta, + void* c_ptr, int ldc, + int data_type_hint, + int /*compute_type_hint*/) { + int device_id = 0; + (void)hipGetDevice(&device_id); + hipblasLtHandle_t handle = get_handle(device_id); + + // Map data_type_hint: 1=fp16, 2=bf16, 3=fp32 + hipDataType hip_dtype; + switch (data_type_hint) { + case 1: hip_dtype = HIP_R_16F; break; + case 2: hip_dtype = HIP_R_16BF; break; + default: hip_dtype = HIP_R_32F; break; + } + + hipblaslt_gemm_impl( + handle, + device_id, + static_cast(op_a), + static_cast(op_b), + M, N, K, + alpha, + a_ptr, lda, 0, + b_ptr, ldb, 0, + beta, + c_ptr, ldc, 0, + 1, // batch_count + hip_dtype, + stream); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/hipblaslt_gemm.h b/mlx/backend/rocm/gemms/hipblaslt_gemm.h new file mode 100644 index 0000000000..c6e980c608 --- /dev/null +++ b/mlx/backend/rocm/gemms/hipblaslt_gemm.h @@ -0,0 +1,71 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/device.h" + +namespace mlx::core::rocm { + +// hipBLASLt GEMM wrapper functions +// hipBLASLt provides optimized GEMM kernels that can outperform rocBLAS +// for half-precision (fp16/bf16) matrix multiplications by using hardware +// matrix cores more efficiently and selecting algorithms via heuristics. + +// Returns true if hipBLASLt is available and usable on the current device. +bool is_hipblaslt_available(); + +void hipblaslt_gemm( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype); + +void hipblaslt_gemm_batched( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype); + +// Raw hipBLASLt GEMM — parameters already in column-major convention +// (A/B swapped, M/N swapped). Call directly from inside kernel lambdas. 
+void hipblaslt_gemm_raw( + hipStream_t stream, + int op_a, // rocblas_operation / hipblasOperation_t value + int op_b, + int M, int N, int K, + const float* alpha, + const void* a_ptr, int lda, + const void* b_ptr, int ldb, + const float* beta, + void* c_ptr, int ldc, + int data_type, // hipDataType value (HIP_R_16BF, HIP_R_16F, HIP_R_32F) + int compute_type); // hipblasComputeType_t value + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/indexing.hip b/mlx/backend/rocm/indexing.hip index 46b0f42dc5..53a12b5d84 100644 --- a/mlx/backend/rocm/indexing.hip +++ b/mlx/backend/rocm/indexing.hip @@ -4,8 +4,11 @@ #include "mlx/backend/rocm/kernel_utils.hpp" #include "mlx/backend/rocm/jit_module.h" #include "mlx/backend/rocm/device/indexing.hpp" +#include "mlx/backend/rocm/device/binary_ops.hpp" #include "mlx/backend/rocm/device/utils.hpp" #include "mlx/backend/gpu/copy.h" +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/common/utils.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -397,6 +400,69 @@ __global__ void scatter_general_kernel( } } +// SliceUpdate kernel: applies Op to combine existing output values with +// update values at computed slice positions. 
+template < + typename T, + typename IdxT, + typename Op, + bool OUT_ROW_CONTIG, + bool UPD_ROW_CONTIG, + bool UPD_SCALAR, + int NWORK> +__global__ void slice_update_op_kernel( + const T* updates, + T* out, + int64_t update_size, + hip_array update_shape, + hip_array update_strides, + int32_t update_ndim, + hip_array output_strides, + int64_t output_offset) { + Op op; + + IdxT idx = (IdxT(blockIdx.x) * IdxT(blockDim.x) + IdxT(threadIdx.x)) * NWORK; + IdxT out_idx; + IdxT update_idx; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else { + out_idx = elem_to_loc( + idx, update_shape.data_, output_strides.data_, update_ndim); + } + + if constexpr (!UPD_SCALAR) { + if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else { + update_idx = elem_to_loc( + idx, update_shape.data_, update_strides.data_, update_ndim); + } + } else { + update_idx = 0; + } + + out += output_offset; + + for (int j = 0; j < NWORK && idx < update_size; j++) { + out[out_idx] = op(out[out_idx], updates[update_idx]); + idx++; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else { + out_idx += output_strides[update_ndim - 1]; + } + + if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else if constexpr (!UPD_SCALAR) { + update_idx += update_strides[update_ndim - 1]; + } + } +} + template __global__ void masked_scatter_offsets_kernel( const bool* mask, @@ -1116,6 +1182,147 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { #undef DISPATCH_IDX_TYPE } +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, stream()); + + // Calculate out strides, initial offset + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy for None reduce type + if (reduce_type_ == SliceUpdate::None) { + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); + return; + } + + // For reduce types (Sum/Prod/Max/Min), launch a kernel + auto [shape, strides] = + collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides}); + int nwork = 1; + if (shape.back() % 4 == 0) { + nwork = 4; + } else if (shape.back() % 2 == 0) { + nwork = 2; + } + + auto [ds, rc, cc] = check_contiguity(shape, strides[1]); + bool upd_contiguous = upd.flags().row_contiguous; + bool upd_scalar = upd.data_size() == 1; + bool out_contiguous = rc; + + int ndim = shape.size(); + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + encoder.set_input_array(upd); + encoder.set_output_array(out); + + auto shape_param = const_param(shape); + auto upd_strides_param = const_param(strides[0]); + auto out_strides_param = const_param(strides[1]); + + int64_t update_size = upd.size(); + int block_size = 256; + int64_t adjusted_size = (update_size + nwork - 1) / nwork; + int num_blocks = static_cast( + std::min((adjusted_size + block_size - 1) / block_size, (int64_t)65535)); + + #define SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, NWORK_VAL) \ + hipLaunchKernelGGL( \ + (rocm::slice_update_op_kernel), \ + dim3(num_blocks), dim3(block_size), 0, stream, \ + gpu_ptr(upd), gpu_ptr(out), update_size, \ + shape_param, upd_strides_param, ndim, \ + out_strides_param, data_offset) + + // Dispatch helper for NWORK + #define 
DISPATCH_NWORK(T, Op, OUT_C, UPD_C, UPD_S) \ + switch (nwork) { \ + case 4: SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, 4); break; \ + case 2: SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, 2); break; \ + default: SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, 1); break; \ + } + + // Dispatch helper for contiguity flags + #define DISPATCH_CONTIG(T, Op) \ + if (upd_scalar) { \ + if (out_contiguous) { \ + DISPATCH_NWORK(T, Op, true, false, true); \ + } else { \ + DISPATCH_NWORK(T, Op, false, false, true); \ + } \ + } else if (upd_contiguous && out_contiguous) { \ + DISPATCH_NWORK(T, Op, true, true, false); \ + } else if (upd_contiguous) { \ + DISPATCH_NWORK(T, Op, false, true, false); \ + } else if (out_contiguous) { \ + DISPATCH_NWORK(T, Op, true, false, false); \ + } else { \ + DISPATCH_NWORK(T, Op, false, false, false); \ + } + + // Dispatch helper for reduce type + #define DISPATCH_SLICE_OP(T) \ + switch (reduce_type_) { \ + case SliceUpdate::Max: DISPATCH_CONTIG(T, rocm::Maximum); break; \ + case SliceUpdate::Min: DISPATCH_CONTIG(T, rocm::Minimum); break; \ + case SliceUpdate::Sum: DISPATCH_CONTIG(T, rocm::Add); break; \ + case SliceUpdate::Prod: DISPATCH_CONTIG(T, rocm::Multiply); break; \ + default: \ + throw std::runtime_error("SliceUpdate: unsupported reduce type"); \ + } + + encoder.launch_kernel([&](hipStream_t stream) { + switch (out.dtype()) { + case float32: DISPATCH_SLICE_OP(float); break; + case float16: DISPATCH_SLICE_OP(__half); break; + case bfloat16: DISPATCH_SLICE_OP(hip_bfloat16); break; + case int32: DISPATCH_SLICE_OP(int32_t); break; + case int64: DISPATCH_SLICE_OP(int64_t); break; + case uint32: DISPATCH_SLICE_OP(uint32_t); break; + case uint64: DISPATCH_SLICE_OP(uint64_t); break; + case int8: DISPATCH_SLICE_OP(int8_t); break; + case int16: DISPATCH_SLICE_OP(int16_t); break; + case uint8: DISPATCH_SLICE_OP(uint8_t); break; + case uint16: DISPATCH_SLICE_OP(uint16_t); break; + case bool_: DISPATCH_SLICE_OP(bool); break; + default: + throw 
std::runtime_error("Unsupported dtype for SliceUpdate"); + } + }); + + #undef DISPATCH_SLICE_OP + #undef DISPATCH_CONTIG + #undef DISPATCH_NWORK + #undef SLICE_UPDATE_LAUNCH +} + void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); diff --git a/mlx/backend/rocm/jit_module.cpp b/mlx/backend/rocm/jit_module.cpp index d7f751da65..f94c03c86e 100644 --- a/mlx/backend/rocm/jit_module.cpp +++ b/mlx/backend/rocm/jit_module.cpp @@ -6,11 +6,14 @@ #include "mlx/version.h" #include +#include #include +#include #include #include #include +#include #include #include @@ -18,6 +21,81 @@ namespace mlx::core::rocm { namespace { +// RAII helper that silences stderr during hipRTC compilation. +// AMD's comgr library (used by hipRTC) unconditionally writes preprocessed +// source and internal diagnostics to fd 2. This floods the terminal with +// thousands of lines of compiler-internal defines every time a new fused +// kernel is JIT-compiled. +struct StderrSuppressor { + StderrSuppressor() { + saved_fd_ = dup(STDERR_FILENO); + if (saved_fd_ >= 0) { + int devnull = open("/dev/null", O_WRONLY); + if (devnull >= 0) { + dup2(devnull, STDERR_FILENO); + close(devnull); + active_ = true; + } else { + // Could not open /dev/null — leave stderr alone. + close(saved_fd_); + saved_fd_ = -1; + } + } + } + ~StderrSuppressor() { restore(); } + void restore() { + if (active_) { + fflush(stderr); + dup2(saved_fd_, STDERR_FILENO); + close(saved_fd_); + saved_fd_ = -1; + active_ = false; + } + } + StderrSuppressor(const StderrSuppressor&) = delete; + StderrSuppressor& operator=(const StderrSuppressor&) = delete; + + private: + int saved_fd_ = -1; + bool active_ = false; +}; + +// Extract the last N lines from a compiler log. AMD comgr prepends the +// entire preprocessed source to the error log, making it enormous. The +// actual compiler errors are always at the end. 
+std::string tail_lines(const std::string& text, size_t n = 60) { + if (text.empty()) { + return text; + } + // Walk backwards to find the start of the last `n` lines. + size_t count = 0; + size_t pos = text.size(); + while (pos > 0 && count < n) { + --pos; + if (text[pos] == '\n') { + ++count; + } + } + if (pos > 0) { + // Skip past the newline we stopped on. + return "... [preprocessed source truncated] ...\n" + text.substr(pos + 1); + } + return text; +} + +// Truncate long kernel names to avoid exceeding filesystem 255-byte limit. +// Names > 200 chars are replaced with a prefix + hash. +std::string safe_filename(const std::string& name) { + constexpr size_t kMaxLen = 200; + if (name.size() <= kMaxLen) { + return name; + } + auto h = std::hash{}(name); + std::ostringstream oss; + oss << name.substr(0, 64) << "_" << std::hex << h; + return oss.str(); +} + #define CHECK_HIPRTC_ERROR(cmd) check_hiprtc_error(#cmd, (cmd)) void check_hiprtc_error(const char* name, hiprtcResult err) { @@ -222,15 +300,24 @@ void compile( args.push_back(arg.c_str()); } + // Suppress stderr during hipRTC compilation. AMD's comgr backend + // unconditionally dumps the entire preprocessed source to fd 2, flooding + // the terminal with thousands of lines of compiler-internal defines. + StderrSuppressor suppressor; hiprtcResult compile_result = hiprtcCompileProgram(prog, args.size(), args.data()); + suppressor.restore(); // restore stderr before any error reporting + if (compile_result != HIPRTC_SUCCESS) { size_t log_size; CHECK_HIPRTC_ERROR(hiprtcGetProgramLogSize(prog, &log_size)); std::vector log(log_size + 1, 0); CHECK_HIPRTC_ERROR(hiprtcGetProgramLog(prog, log.data())); + // The comgr log prepends the entire preprocessed source before the + // actual error messages. Truncate to only the trailing error lines. 
+ std::string truncated = tail_lines(std::string(log.data())); std::ostringstream oss; - oss << "Failed to compile kernel: " << log.data() << "."; + oss << "Failed to compile kernel '" << module_name << "': " << truncated; throw std::runtime_error(oss.str()); } @@ -282,9 +369,12 @@ JitModule::JitModule( std::string hsaco; std::vector> hsaco_kernels; + // Use a safe filename for disk cache to avoid exceeding 255-byte limit + std::string cache_name = safe_filename(module_name); + // Try to load them from the file cache if (!read_cached_hsaco( - hsaco_cache_dir(), module_name, hsaco, hsaco_kernels)) { + hsaco_cache_dir(), cache_name, hsaco, hsaco_kernels)) { auto [precompiled, source_code, kernel_names] = builder(); // Get the HSACO (AMD GPU binary) @@ -301,7 +391,7 @@ JitModule::JitModule( // If requested save them in the file cache for the next launch if (use_disk_cache) { write_cached_hsaco( - hsaco_cache_dir(), module_name, hsaco, hsaco_kernels, source_code); + hsaco_cache_dir(), cache_name, hsaco, hsaco_kernels, source_code); } } diff --git a/mlx/backend/rocm/jit_module.h b/mlx/backend/rocm/jit_module.h index 200e896e97..db2064c425 100644 --- a/mlx/backend/rocm/jit_module.h +++ b/mlx/backend/rocm/jit_module.h @@ -5,6 +5,7 @@ #include "mlx/array.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" #include #include @@ -37,9 +38,7 @@ struct KernelArgs { } void append(const array& a) { - // Use const_cast since HIP APIs expect non-const pointers but we know - // the data won't be modified for input arrays - append(reinterpret_cast(const_cast(a.data()))); + append(reinterpret_cast(gpu_ptr(a))); } template diff --git a/mlx/backend/rocm/layer_norm.hip b/mlx/backend/rocm/layer_norm.hip index 47c8ebfc97..7a2514c76f 100644 --- a/mlx/backend/rocm/layer_norm.hip +++ b/mlx/backend/rocm/layer_norm.hip @@ -111,7 +111,9 @@ __global__ void layer_norm_kernel( shared_sum[0] = var_sum; } __syncthreads(); - float 
normalizer = rsqrtf(shared_sum[0] / axis_size + eps); + // Use 1/sqrt instead of rsqrtf for IEEE-compliant precision + // (matches Metal's metal::precise::rsqrt behavior) + float normalizer = 1.0f / sqrtf(shared_sum[0] / axis_size + eps); // Write output for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { diff --git a/mlx/backend/rocm/matmul.cpp b/mlx/backend/rocm/matmul.cpp index 9d36728183..35d3a97579 100644 --- a/mlx/backend/rocm/matmul.cpp +++ b/mlx/backend/rocm/matmul.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/rocm/device.h" #include "mlx/backend/rocm/gemms/gemv.h" +#include "mlx/backend/rocm/gemms/hipblaslt_gemm.h" #include "mlx/backend/rocm/gemms/naive_gemm.h" #include "mlx/backend/rocm/kernel_utils.hpp" #include "mlx/primitives.h" @@ -132,6 +133,33 @@ void gemm_rocblas( const array& b, float alpha = 1.0f, float beta = 0.0f) { + // Try hipBLASLt for bf16/fp16 GEMMs -- it often picks faster kernels than + // rocBLAS for half-precision on RDNA 3/3.5/4 and CDNA GPUs. + if ((a.dtype() == bfloat16 || a.dtype() == float16) && + rocm::is_hipblaslt_available()) { + try { + rocm::hipblaslt_gemm( + encoder, + a_transposed, + b_transposed, + M, + N, + K, + alpha, + a, + lda, + b, + ldb, + beta, + out, + N, // ldc = N for row-major output + a.dtype()); + return; + } catch (...) { + // hipBLASLt failed (unsupported config, etc.) -- fall through to rocBLAS. + } + } + auto& device = encoder.device(); rocblas_handle handle = device.get_rocblas_handle(); @@ -365,6 +393,36 @@ void gemm_strided_batched_rocblas( const array& b, float alpha = 1.0f, float beta = 0.0f) { + // Try hipBLASLt for bf16/fp16 batched GEMMs. 
+ if ((a.dtype() == bfloat16 || a.dtype() == float16) && + rocm::is_hipblaslt_available()) { + try { + rocm::hipblaslt_gemm_batched( + encoder, + a_transposed, + b_transposed, + M, + N, + K, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + out, + N, // ldc = N for row-major output + stride_c, + batch_count, + a.dtype()); + return; + } catch (...) { + // hipBLASLt failed -- fall through to rocBLAS. + } + } + auto& device = encoder.device(); rocblas_handle handle = device.get_rocblas_handle(); diff --git a/mlx/backend/rocm/quantized/qdequant.hpp b/mlx/backend/rocm/quantized/qdequant.hpp new file mode 100644 index 0000000000..cb67f458bb --- /dev/null +++ b/mlx/backend/rocm/quantized/qdequant.hpp @@ -0,0 +1,111 @@ +// Shared dequantization utilities for optimized QMM kernels. +// Used by qmv_kernel.hip (GEMV) and qmm_kernel.hip (GEMM). + +#pragma once + +#include "mlx/backend/rocm/device/config.h" +#include +#include +#include + +namespace mlx::core::rocm { + +// --- Compile-time constants --- + +// Number of quantized values packed per uint32 word. +// 4-bit: 8 values, 2-bit: 16 values, 8-bit: 4 values. +template +inline constexpr int pack_factor_u32 = 32 / BITS; + +// Number of uint32 words each thread loads per K-iteration. +// Chosen so that values_per_thread = 16 for all bit widths. +template +inline constexpr int packs_per_thread = 16 / pack_factor_u32; +// 4-bit: 16/8=2, 2-bit: 16/16=1, 8-bit: 16/4=4 + +// Number of quantized values each thread processes per K-iteration. +template +inline constexpr int values_per_thread = 16; + +// Number of K-elements consumed per warp per iteration. +// = values_per_thread * WARP_SIZE = 16 * 32 = 512 +inline constexpr int block_size_k = values_per_thread<4> * WARP_SIZE; + +// Number of output rows computed per thread block. 
+inline constexpr int ROWS_PER_BLOCK = 8; + +// --- Warp reduction --- + +__device__ __forceinline__ float warp_reduce_sum(float val) { + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_xor(val, offset); + } + return val; +} + +// --- Dequant-and-dot: integer dot product + x-sum accumulation --- +// +// Metal-compatible accumulation: accumulates raw integer dot product and +// x-sum separately. The caller applies scale and bias ONCE per group: +// result += scale * total_qdot + bias * total_xsum +// +// This matches Metal's qdot() which returns scale * accum + sum * bias, +// where accum and sum span all values_per_thread elements at once. +// +// The naive per-element form `acc += x[i] * (scale * q[i] + bias)` is +// mathematically equivalent but produces different float32 rounding due to +// a different number of scale/bias multiply operations, causing LLM output +// to degenerate into repetitive loops after ~10 tokens. + +template +__device__ __forceinline__ void dequant_and_dot( + uint32_t packed, + const float* __restrict__ x_local, + float& qdot_acc, + float& x_sum) +{ + constexpr int pf = pack_factor_u32; + constexpr uint32_t mask = (1u << BITS) - 1u; + + #pragma unroll + for (int i = 0; i < pf; i++) { + float q = static_cast((packed >> (i * BITS)) & mask); + qdot_acc += x_local[i] * q; + x_sum += x_local[i]; + } +} + +// --- Type conversion helpers --- + +__device__ __forceinline__ float to_float(__half x) { + return __half2float(x); +} + +__device__ __forceinline__ float to_float(hip_bfloat16 x) { + return static_cast(x); +} + +__device__ __forceinline__ float to_float(float x) { + return x; +} + +template +__device__ __forceinline__ T from_float(float x); + +template <> +__device__ __forceinline__ __half from_float<__half>(float x) { + return __float2half(x); +} + +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float x) { + return hip_bfloat16(x); +} + +template <> +__device__ __forceinline__ 
float from_float(float x) { + return x; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/quantized/qmm.hip b/mlx/backend/rocm/quantized/qmm.hip index 3e55264d5c..1b3c5e57a9 100644 --- a/mlx/backend/rocm/quantized/qmm.hip +++ b/mlx/backend/rocm/quantized/qmm.hip @@ -3,6 +3,7 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/rocm/device.h" #include "mlx/backend/rocm/gemms/naive_gemm.h" +#include "mlx/backend/rocm/gemms/hipblaslt_gemm.h" #include "mlx/backend/rocm/gemms/rocblas_gemm.h" #include "mlx/backend/rocm/kernel_utils.hpp" #include "mlx/backend/rocm/quantized/quantized.h" @@ -13,6 +14,23 @@ #include #include #include +#include +// rocWMMA is only supported on CDNA (gfx9xx) and RDNA 3+ (gfx11xx, gfx12xx). +// Guard the include so it doesn't trigger static_assert on RDNA 1/2 (gfx10xx). +// During host compilation __HIP_DEVICE_COMPILE__ is 0 so rocwmma defines +// ROCWMMA_ARCH_HOST and compiles fine. During device compilation for +// unsupported architectures like gfx1030 the header would static_assert. +#if !defined(__HIP_DEVICE_COMPILE__) || !__HIP_DEVICE_COMPILE__ || \ + defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \ + defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) || \ + defined(__gfx1200__) || defined(__gfx1201__) +#define ROCM_HAS_WMMA 1 +#include +#else +#define ROCM_HAS_WMMA 0 +#endif +#include #include #include #include @@ -21,6 +39,111 @@ namespace mlx::core { +namespace rocm { + +// Strided 2D row-copy kernel: copies rows from a source with row_stride != cols +// into a contiguous destination. 
+// src layout: row i starts at src + i * src_row_stride (elements contiguous within row) +// dst layout: row i starts at dst + i * cols (fully contiguous) +// +// When both row strides and cols_bytes are 4-byte aligned, uses uint32_t +// copies (one 4-byte word per thread iteration) for good throughput without +// alignment concerns. Falls back to byte-by-byte for the non-aligned tail. +__global__ void strided_row_copy_kernel( + const char* __restrict__ src, + char* __restrict__ dst, + int64_t num_rows, + int64_t cols_bytes, + int64_t src_row_stride_bytes, + int64_t dst_row_stride_bytes, + bool use_word_copy) { + int64_t tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t grid_stride = static_cast(blockDim.x) * gridDim.x; + + if (use_word_copy) { + // Fast path: 4-byte word copies. All row strides are 4-byte aligned. + constexpr int64_t WORD = 4; + int64_t cols_words = cols_bytes / WORD; + int64_t total_words = num_rows * cols_words; + for (int64_t i = tid; i < total_words; i += grid_stride) { + int64_t row = i / cols_words; + int64_t word_in_row = i % cols_words; + int64_t src_off = row * src_row_stride_bytes + word_in_row * WORD; + int64_t dst_off = row * dst_row_stride_bytes + word_in_row * WORD; + *reinterpret_cast(dst + dst_off) = + *reinterpret_cast(src + src_off); + } + // Handle remainder bytes (cols_bytes % 4) + int64_t remainder_start = cols_words * WORD; + int64_t remainder_bytes = cols_bytes - remainder_start; + if (remainder_bytes > 0) { + for (int64_t i = tid; i < num_rows * remainder_bytes; i += grid_stride) { + int64_t row = i / remainder_bytes; + int64_t byte_in_tail = i % remainder_bytes; + int64_t src_off = row * src_row_stride_bytes + remainder_start + byte_in_tail; + int64_t dst_off = row * dst_row_stride_bytes + remainder_start + byte_in_tail; + dst[dst_off] = src[src_off]; + } + } + } else { + // Slow path: byte-by-byte copy for non-aligned strides. 
+ int64_t total_bytes = num_rows * cols_bytes; + for (int64_t i = tid; i < total_bytes; i += grid_stride) { + int64_t row = i / cols_bytes; + int64_t byte_in_row = i % cols_bytes; + int64_t src_off = row * src_row_stride_bytes + byte_in_row; + int64_t dst_off = row * dst_row_stride_bytes + byte_in_row; + dst[dst_off] = src[src_off]; + } + } +} + +// General strided copy kernel with strides passed as kernel arguments +// (by-value hip_array structs). Avoids device memory allocation + +// hipMemcpyAsync overhead that contiguous_copy_gpu -> copy_general_input +// would incur. Falls back to contiguous_copy_gpu only for ndim > MAX_NDIM. +__global__ void strided_general_copy_kernel( + const char* __restrict__ src, + char* __restrict__ dst, + int64_t total_elems, + int elem_bytes, + int ndim, + hip_array shapes, + hip_array strides_bytes) { + int64_t tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t grid_stride = static_cast(blockDim.x) * gridDim.x; + for (int64_t idx = tid; idx < total_elems; idx += grid_stride) { + // Convert linear index to strided source offset + int64_t src_offset = 0; + int64_t remaining = idx; + for (int d = ndim - 1; d >= 0; --d) { + int64_t coord = remaining % shapes[d]; + remaining /= shapes[d]; + src_offset += coord * strides_bytes[d]; + } + // Copy element bytes -- specialize for common QMM element sizes + int64_t dst_offset = idx * elem_bytes; + if (elem_bytes == 2) { + *reinterpret_cast(dst + dst_offset) = + *reinterpret_cast(src + src_offset); + } else if (elem_bytes == 4) { + *reinterpret_cast(dst + dst_offset) = + *reinterpret_cast(src + src_offset); + } else if (elem_bytes == 1) { + dst[dst_offset] = src[src_offset]; + } else if (elem_bytes == 8) { + *reinterpret_cast(dst + dst_offset) = + *reinterpret_cast(src + src_offset); + } else { + for (int b = 0; b < elem_bytes; ++b) { + dst[dst_offset + b] = src[src_offset + b]; + } + } + } +} + +} // namespace rocm + namespace { template @@ -28,6 +151,32 @@ struct 
local_type_identity { using type = T; }; +// Fast contiguous-copy helper for QMM inputs. +// +// Design goals vs the previous implementation (which called contiguous_copy_gpu +// unconditionally when strides didn't match row-major): +// +// 1. **Already contiguous** -- return immediately (unchanged). +// +// 2. **Inner-contiguous with outer stride gap** -- the most common +// non-contiguous pattern from `take` / `gather_sort`. The inner N-1 +// dimensions are packed (stride-1 on the last dim, products match for +// the rest), but the outermost dimension has a stride larger than the +// product of inner shapes. We handle this with a single +// `strided_row_copy_kernel` launch -- no device memory allocation for +// shapes/strides, no hipMemcpyAsync. One kernel dispatch total. +// +// 3. **General non-contiguous** (rare for QMM inputs) -- uses +// `strided_general_copy_kernel` which takes shapes and strides as +// kernel arguments (up to QMM_COPY_MAX_DIMS dimensions). This avoids +// the 2x allocator::malloc + 2x hipMemcpyAsync that +// `contiguous_copy_gpu -> copy_general_input` would issue. One kernel +// dispatch total. Falls back to `contiguous_copy_gpu` only for arrays +// with more than MAX_NDIM (10) dimensions (extremely unlikely for +// QMM operands). +// +// Net effect: non-contiguous copies go from 5 GPU operations (2 allocs + +// 2 memcpy + 1 kernel) down to 1 kernel launch. inline array ensure_row_contiguous_matrix( const array& x, rocm::CommandEncoder& enc, @@ -36,12 +185,19 @@ inline array ensure_row_contiguous_matrix( return x; } + // --- Fast path 1: already row-major contiguous --- + int ndim = x.ndim(); + const auto& strides = x.strides(); bool row_major_contiguous = true; int64_t expected_stride = 1; - for (int i = x.ndim() - 1; i >= 0; --i) { + // Track the innermost contiguous dimensions while checking. + // If we break at dimension i, dimensions [i+1 .. ndim-1] are packed. 
+ int first_noncontig_dim = -1; + for (int i = ndim - 1; i >= 0; --i) { if (x.shape(i) > 1) { - if (x.strides()[i] != expected_stride) { + if (strides[i] != expected_stride) { row_major_contiguous = false; + first_noncontig_dim = i; break; } expected_stride *= x.shape(i); @@ -52,6 +208,174 @@ inline array ensure_row_contiguous_matrix( return x; } + // Empty arrays don't need copying. + if (x.size() == 0) { + return x; + } + + size_t elem_bytes = x.itemsize(); + + // Helper: allocate a contiguous output array and return src/dst pointers. + // Deferred until we know a copy is actually needed and which path to use. + auto make_output = [&]() -> array { + array out(x.shape(), x.dtype(), nullptr, {}); + out.set_data(allocator::malloc(out.nbytes())); + enc.add_temporary(out); + return out; + }; + + // --- Fast path 2: inner-contiguous, only outermost dim has a stride gap --- + // This covers the common case where x comes from take/gather of a [E, K] + // or [B, M, K] array -- inner dims are packed, outer dim stride > product. + // We also handle the case where the gap is at any single dimension (not + // just dim 0) as long as all dimensions below it are packed. + if (first_noncontig_dim >= 0) { + // Verify that all dimensions below first_noncontig_dim are packed, + // and only first_noncontig_dim itself has a non-standard stride. + // Dimensions above first_noncontig_dim (if any) must also be consistent + // with first_noncontig_dim's layout. + bool is_simple_outer_gap = true; + // Check: first_noncontig_dim's stride must be >= expected_stride + // (i.e. the inner block is correct, just spaced further apart). + if (strides[first_noncontig_dim] < expected_stride) { + is_simple_outer_gap = false; + } + // Check dimensions above first_noncontig_dim: their strides must be + // consistent with first_noncontig_dim's stride * shape products. 
+ if (is_simple_outer_gap) { + int64_t outer_expected = strides[first_noncontig_dim] * x.shape(first_noncontig_dim); + for (int i = first_noncontig_dim - 1; i >= 0; --i) { + if (x.shape(i) <= 1) continue; + if (strides[i] != outer_expected) { + is_simple_outer_gap = false; + break; + } + outer_expected *= x.shape(i); + } + } + + if (is_simple_outer_gap && first_noncontig_dim == 0) { + // Simplest case: only the outermost dim has extra stride. + // inner_size = product of shapes[1..ndim-1] + array x_copy = make_output(); + const char* src = reinterpret_cast(gpu_ptr(x)); + char* dst = reinterpret_cast(gpu_ptr(x_copy)); + + int64_t inner_size = 1; + for (int i = 1; i < ndim; ++i) { + inner_size *= x.shape(i); + } + int64_t num_rows = x.shape(0); + int64_t cols_bytes = inner_size * static_cast(elem_bytes); + int64_t src_row_stride_bytes = strides[0] * static_cast(elem_bytes); + int64_t dst_row_stride_bytes = cols_bytes; + bool word_copy = (cols_bytes % 4 == 0) && + (src_row_stride_bytes % 4 == 0) && + (dst_row_stride_bytes % 4 == 0); + + int block_size = 256; + int64_t work_items = word_copy + ? num_rows * (cols_bytes / 4) + : num_rows * cols_bytes; + int num_blocks = static_cast( + std::min((work_items + block_size - 1) / block_size, 65535)); + + enc.launch_kernel([=](hipStream_t stream) { + hipLaunchKernelGGL( + rocm::strided_row_copy_kernel, + dim3(num_blocks), dim3(block_size), 0, stream, + src, dst, + num_rows, cols_bytes, + src_row_stride_bytes, dst_row_stride_bytes, + word_copy); + }); + return x_copy; + } + + if (is_simple_outer_gap) { + // Gap at an interior dimension. batch_count == 1 is common here. 
+ int64_t batch_count = 1; + for (int i = 0; i < first_noncontig_dim; ++i) { + batch_count *= x.shape(i); + } + if (batch_count == 1) { + array x_copy = make_output(); + const char* src = reinterpret_cast(gpu_ptr(x)); + char* dst = reinterpret_cast(gpu_ptr(x_copy)); + + int64_t inner_size = 1; + for (int i = first_noncontig_dim + 1; i < ndim; ++i) { + inner_size *= x.shape(i); + } + int64_t slab_rows = x.shape(first_noncontig_dim); + int64_t cols_bytes = inner_size * static_cast(elem_bytes); + int64_t src_row_stride_bytes = strides[first_noncontig_dim] * static_cast(elem_bytes); + int64_t dst_row_stride_bytes = cols_bytes; + bool word_copy = (cols_bytes % 4 == 0) && + (src_row_stride_bytes % 4 == 0) && + (dst_row_stride_bytes % 4 == 0); + + int block_size = 256; + int64_t work_items = word_copy + ? slab_rows * (cols_bytes / 4) + : slab_rows * cols_bytes; + int num_blocks = static_cast( + std::min((work_items + block_size - 1) / block_size, 65535)); + + enc.launch_kernel([=](hipStream_t stream) { + hipLaunchKernelGGL( + rocm::strided_row_copy_kernel, + dim3(num_blocks), dim3(block_size), 0, stream, + src, dst, + slab_rows, cols_bytes, + src_row_stride_bytes, dst_row_stride_bytes, + word_copy); + }); + return x_copy; + } + // batch_count > 1 with interior gap: fall through to general path + } + } + + // --- Fast path 3: general non-contiguous, strides as kernel args --- + // Handles arbitrary stride patterns with up to MAX_NDIM dimensions. + // Shapes and byte-strides are passed as hip_array structs (by value), + // so no device memory allocation or hipMemcpyAsync is needed. + // One kernel launch total. 
+ if (ndim <= MAX_NDIM) { + array x_copy = make_output(); + const char* src = reinterpret_cast(gpu_ptr(x)); + char* dst = reinterpret_cast(gpu_ptr(x_copy)); + + int64_t total_elems = x.size(); + int eb = static_cast(elem_bytes); + + int block_size = 256; + int num_blocks = static_cast( + std::min((total_elems + block_size - 1) / block_size, 65535)); + + // Pack into hip_array structs that can be passed by value to the kernel. + rocm::hip_array shapes_arg = {}; + rocm::hip_array strides_bytes_arg = {}; + for (int i = 0; i < ndim; ++i) { + shapes_arg.data_[i] = x.shape(i); + strides_bytes_arg.data_[i] = strides[i] * static_cast(elem_bytes); + } + + enc.launch_kernel([=](hipStream_t stream) { + hipLaunchKernelGGL( + rocm::strided_general_copy_kernel, + dim3(num_blocks), dim3(block_size), 0, stream, + src, dst, + total_elems, eb, ndim, + shapes_arg, strides_bytes_arg); + }); + return x_copy; + } + + // --- Fallback: ndim > MAX_NDIM (extremely rare for QMM) --- + // Use the generic copy infrastructure which allocates device buffers + // for shape/strides arrays (2 allocs + 2 hipMemcpyAsync + 1 kernel). array x_copy = contiguous_copy_gpu(x, s); enc.add_temporary(x_copy); return x_copy; @@ -205,12 +529,19 @@ inline int select_qmv_cols_per_block(int K, int N, int bits) { } inline int select_qmv_threads_per_col(int K, int N, int bits, int batch_count) { + // On RDNA 3.5 (wave32), 16 threads per column gives better occupancy + // than 32 for most LLM decode shapes. 32 threads only helps for very + // large K where the extra parallelism in the reduction outweighs the + // reduced block count. 
int threads_per_col = 16; if (WARP_SIZE == 32) { bool quant_bits_supported = (bits == 2 || bits == 4 || bits == 5 || bits == 6 || bits == 8); - bool large_decode_like = (batch_count == 1) && (N >= 4096 || K >= 4096); - if (quant_bits_supported && large_decode_like) { + // On RDNA 3.5 (40 CUs / 20 WGPs), 16 threads/col allows 2 columns + // per warp, increasing memory-level parallelism for decode. Only use + // full warp (32) for extreme K where reduction parallelism dominates. + bool extreme = (batch_count == 1) && (K >= 16384); + if (quant_bits_supported && extreme) { threads_per_col = WARP_SIZE; } } @@ -665,6 +996,25 @@ void dequant_rocblas_gemm( case bfloat16: { float alpha_f = alpha; float beta_f = beta; + + // Try hipBLASLt first for bf16 GEMMs — often faster on RDNA 3.5/CDNA + if (rocm::is_hipblaslt_available()) { + try { + // data_type=0 means "use bfloat16", impl maps internally + rocm::hipblaslt_gemm_raw( + stream, + static_cast(op_b), static_cast(op_a), + N, M, K, + &alpha_f, b_ptr, ldb, a_ptr, lda, + &beta_f, c_ptr, ldc, + 2, // 2 = bfloat16 (mapped in impl) + 0); // unused + break; + } catch (...) { + // Fall through to rocBLAS + } + } + int solution_index = qmm_gemm_solution_index_bf16(false); static std::atomic solution_valid{true}; @@ -2562,34 +2912,13 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { const void* biases_ptr = has_bias ? 
gpu_ptr(*biases) : nullptr; void* out_ptr = gpu_ptr(out); - bool use_alignment_qmv = should_use_alignment_qmv_noshared_path( - M, - N, - K, - batch_count, - transpose_, - can_use_batched_qmv, - bits_, - mode_, - x_ptr, - w_ptr, - scales_ptr, - biases_ptr, - has_bias); - bool use_noshared_qmv_variant = use_tiny_k_qmv || use_alignment_qmv; - - if (use_alignment_qmv) { - fast_cols_per_block = std::max(fast_cols_per_block, 64); - while (fast_cols_per_block > max_cols_per_block) { - fast_cols_per_block /= 2; - } - while (fast_cols_per_block > 1 && (N % fast_cols_per_block) != 0 && - fast_cols_per_block > 8) { - fast_cols_per_block /= 2; - } - fast_block = dim3(fast_threads_per_col, fast_cols_per_block); - fast_grid = dim3((N + fast_cols_per_block - 1) / fast_cols_per_block, M); - } + // The noshared variant reads x from global memory redundantly per warp. + // The shared variant caches x in LDS and is ~15x faster for decode shapes. + // Always prefer shared unless K is tiny (where LDS overhead isn't worth it). + bool use_noshared_qmv_variant = use_tiny_k_qmv; + + // The noshared path used to increase cols_per_block for aligned data. + // Since we always use the shared variant now, no special grid adjustment needed. enc.launch_kernel([&, x_ptr, @@ -3068,6 +3397,421 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { } namespace rocm { + +// ====================================================================== +// GPU-only expert-batched gather QMV for sorted indices. +// +// Grid: (M, ceil(N/cols_per_block), max_unique_experts) +// Each block in z-dimension finds its expert by binary-searching the sorted +// rhs_indices array. No CPU-side run computation needed. +// +// The kernel reads the weight column ONCE per expert and iterates over all +// batch elements assigned to that expert, amortizing weight memory traffic. 
+// ====================================================================== +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) gather_qmv_expert_batched_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, // SORTED + T* __restrict__ out, + int B, + int M, + int N, + int K, + int E, + bool has_bias, + bool implicit_lhs, + int64_t implicit_x_batch_stride) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.y * blockDim.y + warp_idx; + const int row = blockIdx.x; + const int expert_slot = blockIdx.z; // which unique expert this block handles + + if (row >= M || col >= N) return; + + // Find this expert's token range using the expert_slot as a run index. + // Since rhs_indices is sorted, run boundaries are where values change. + // We use a parallel scan: all threads cooperate to count unique experts + // up to expert_slot, then binary-search for the run boundaries. + // + // Fast path: lane 0 does a boundary skip using binary search. + int run_start = 0, run_end = 0; + uint32_t expert_id = 0; + + if (lane == 0 && warp_idx == 0) { + // Skip to the expert_slot-th unique expert by jumping over run boundaries. + // Each boundary is where rhs_indices[i] != rhs_indices[i-1]. 
+ int pos = 0; + for (int skip = 0; skip < expert_slot && pos < B; ++skip) { + // Binary search for end of current run (first index where value differs) + uint32_t cur_val = rhs_indices[pos]; + int lo = pos + 1, hi = B; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (rhs_indices[mid] == cur_val) lo = mid + 1; + else hi = mid; + } + pos = lo; + } + if (pos < B) { + run_start = pos; + expert_id = rhs_indices[pos]; + // Binary search for end of this expert's run + int lo = pos + 1, hi = B; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (rhs_indices[mid] == expert_id) lo = mid + 1; + else hi = mid; + } + run_end = lo; + } + } + + // Broadcast via shared memory + __shared__ int s_run_start, s_run_end; + __shared__ uint32_t s_expert_id; + if (lane == 0 && warp_idx == 0) { + s_run_start = run_start; + s_run_end = run_end; + s_expert_id = expert_id; + } + __syncthreads(); + run_start = s_run_start; + run_end = s_run_end; + expert_id = s_expert_id; + + if (run_end <= run_start) return; // this block has no work + if (expert_id >= static_cast(E)) return; + + // Weight pointers for this expert (loaded ONCE, reused for all tokens in run) + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + int64_t w_expert_stride = static_cast(N) * row_bytes; + int64_t sb_expert_stride = static_cast(N) * num_groups; + + const uint8_t* w_row = w + static_cast(expert_id) * w_expert_stride + + static_cast(col) * row_bytes; + const ScaleT* scales_row = scales + static_cast(expert_id) * sb_expert_stride + + static_cast(col) * num_groups; + const ScaleT* biases_row = has_bias + ? (biases + static_cast(expert_id) * sb_expert_stride + + static_cast(col) * num_groups) + : nullptr; + + // Process each batch element in the run + int64_t x_batch_stride = static_cast(M) * K; + for (int b = run_start; b < run_end; ++b) { + uint32_t lhs_idx = implicit_lhs ? 0u : lhs_indices[b]; + int64_t x_offset = implicit_lhs + ? 
(static_cast(b) * implicit_x_batch_stride) + : (static_cast(lhs_idx) * x_batch_stride); + const T* x_row = x + x_offset + static_cast(row) * K; + + float acc = 0.0f; + + for (int g = 0; g < num_groups; ++g) { + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + float scale = load_scale_value(scales_row[g]); + float bias_val = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc = 0.0f; + float x_group_sum = 0.0f; + + if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + float x4 = static_cast(x_row[k + 4]); + float x5 = static_cast(x_row[k + 5]); + float x6 = static_cast(x_row[k + 6]); + float x7 = static_cast(x_row[k + 7]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), 
qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + if (has_bias) x_group_sum += x0 + x1 + x2 + x3; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else { + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } + + qx_acc = subgroup_reduce_sum_qmm(qx_acc); + x_group_sum = subgroup_reduce_sum_qmm(x_group_sum); + acc += scale * qx_acc + bias_val * x_group_sum; + } else { + float qx_acc = 0.0f; + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, dequantize_value(quant_val, 1.0f, 0.0f), qx_acc); + } + acc += scale * subgroup_reduce_sum_qmm(qx_acc); + } + } + + if (lane == 0) { + out[static_cast(b) * M * N + 
static_cast(row) * N + col] = static_cast(acc); + } + } +} + +// ====================================================================== +// Prefill-optimized gather QMV: groups batch elements by expert. +// +// For sorted rhs_indices, consecutive batch elements hit the same expert. +// This kernel assigns blockIdx.z to contiguous runs of same-expert batches, +// so all rows for one expert share weight reads from global memory. +// Each block handles one column (via warp cooperation) and iterates over +// all M rows for each batch element in the run. +// +// Grid: (num_runs, ceil(N/cols_per_block), max_rows_per_run) +// Where num_runs = number of contiguous expert runs. +// ====================================================================== +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) gather_qmv_prefill_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + const int* __restrict__ run_starts, // [num_runs]: start batch idx of each run + const int* __restrict__ run_lengths, // [num_runs]: length of each run + const int* __restrict__ out_perm, // [B]: sorted batch idx → original batch idx + T* __restrict__ out, + int B, + int M, + int N, + int K, + int E, + bool has_bias, + int64_t x_batch_stride) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.y * blockDim.y + warp_idx; + const int run_id = blockIdx.z; + const int row = blockIdx.x; + + if (row >= M || col >= N) return; + + int run_start = run_starts[run_id]; + int run_len = run_lengths[run_id]; + + // All batches in this run have the same expert + uint32_t rhs_idx = rhs_indices[run_start]; + if (rhs_idx >= static_cast(E)) return; + + // Weight pointers (same for all batches in run) + 
const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + int64_t w_expert_stride = static_cast(N) * row_bytes; + int64_t sb_expert_stride = static_cast(N) * num_groups; + int64_t col_w_offset = static_cast(col) * row_bytes; + int64_t col_sb_offset = static_cast(col) * num_groups; + + const uint8_t* w_row = w + static_cast(rhs_idx) * w_expert_stride + col_w_offset; + const ScaleT* scales_row = scales + static_cast(rhs_idx) * sb_expert_stride + col_sb_offset; + const ScaleT* biases_row = has_bias + ? (biases + static_cast(rhs_idx) * sb_expert_stride + col_sb_offset) + : nullptr; + + // Process each batch element in the run + for (int r = 0; r < run_len; ++r) { + int batch = run_start + r; + uint32_t lhs_idx = lhs_indices[batch]; + const T* x_row = x + static_cast(lhs_idx) * x_batch_stride + static_cast(row) * K; + + float acc = 0.0f; + + for (int g = 0; g < num_groups; ++g) { + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + float scale = load_scale_value(scales_row[g]); + float bias_val = has_bias ? 
static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc = 0.0f; + float x_group_sum = 0.0f; + + if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + float x4 = static_cast(x_row[k + 4]); + float x5 = static_cast(x_row[k + 5]); + float x6 = static_cast(x_row[k + 6]); + float x7 = static_cast(x_row[k + 7]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) 
& 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + if (has_bias) x_group_sum += x0 + x1 + x2 + x3; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else { + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } + + qx_acc = subgroup_reduce_sum_qmm(qx_acc); + x_group_sum = subgroup_reduce_sum_qmm(x_group_sum); + acc += scale * qx_acc + bias_val * x_group_sum; + } else { + float qx_acc = 0.0f; + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, dequantize_value(quant_val, 1.0f, 0.0f), qx_acc); + } + acc += scale * subgroup_reduce_sum_qmm(qx_acc); + } + } + + if (lane == 0) { + const int orig_batch = out_perm[batch]; + out[static_cast(orig_batch) * M * N + static_cast(row) * N + col] = static_cast(acc); + } + } +} + template < typename T, typename ScaleT, @@ -3615,6 +4359,204 @@ __global__ void gather_qmv_kernel( } out[batch * M * N + row * N + col] = (T)acc; } + +// ====================================================================== +// WMMA-accelerated gather QMV prefill kernel 
using rocwmma 16x16x16 tiles. +// +// Each wavefront (32 lanes on RDNA 3.5 / gfx1151) computes one 16x16 +// output tile. Weights are dequantized from 4-bit packed format into +// bf16 in shared memory, then loaded into rocwmma fragments for the +// matrix multiply-accumulate. Accumulation is in float32; the final +// result is converted back to bf16 on store. +// +// Grid: (ceil(M/16), ceil(N/16), num_runs) +// Block: (32, 1, 1) -- one wave32 per 16x16 output tile +// +// On architectures without WMMA support (RDNA 1/2) the kernel body is +// an empty stub; dispatch checks prevent it from being launched there. +// ====================================================================== +template +__global__ void __launch_bounds__(32) gather_qmv_wmma_prefill_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + const int* __restrict__ run_starts, + const int* __restrict__ run_lengths, + const int* __restrict__ out_perm, // maps sorted batch idx → original batch idx + T* __restrict__ out, + int B, int M, int N, int K, int E, + bool has_bias, int64_t x_batch_stride) { + +#if ROCM_HAS_WMMA + + static_assert(BITS == 4, "WMMA prefill kernel only supports 4-bit quantized weights"); + static_assert(AFFINE, "WMMA prefill kernel only supports affine quantization"); + + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + + // Tile coordinates in the output matrix + const int tile_row = blockIdx.x * WMMA_M; // starting row of this 16x16 tile + const int tile_col = blockIdx.y * WMMA_N; // starting col of this 16x16 tile + const int run_id = blockIdx.z; + + // Bounds check -- the dispatch guarantees M and N are multiples of 16, + // but guard anyway for safety. 
+ if (tile_row >= M || tile_col >= N) return; + + const int lane = threadIdx.x; // 0..31 + + // Run info + const int run_start = run_starts[run_id]; + const int run_len = run_lengths[run_id]; + + const uint32_t rhs_idx = rhs_indices[run_start]; + if (rhs_idx >= static_cast(E)) return; + + // Weight layout constants + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; // bytes per weight row (one output col) + const int64_t w_expert_stride = static_cast(N) * row_bytes; + const int64_t sb_expert_stride = static_cast(N) * num_groups; + + // Base pointers for this expert + const uint8_t* w_expert = w + static_cast(rhs_idx) * w_expert_stride; + const ScaleT* s_expert = scales + static_cast(rhs_idx) * sb_expert_stride; + const ScaleT* b_expert = has_bias + ? (biases + static_cast(rhs_idx) * sb_expert_stride) + : nullptr; + + // Shared memory for dequantized weight tile [WMMA_K x WMMA_N] in row-major + // and for x tile [WMMA_M x WMMA_K] in row-major. 
+ // Total: (16*16 + 16*16) * sizeof(hip_bfloat16) = 1024 bytes + __shared__ hip_bfloat16 smem_w[WMMA_K * WMMA_N]; // [16][16] row-major + __shared__ hip_bfloat16 smem_x[WMMA_M * WMMA_K]; // [16][16] row-major + + // Fragment types for bf16 input, f32 accumulation + using frag_a = rocwmma::fragment; + using frag_b = rocwmma::fragment; + using frag_acc = rocwmma::fragment; + + // Process each batch element in the run + for (int r = 0; r < run_len; ++r) { + const int batch = run_start + r; + const uint32_t lhs_idx = lhs_indices[batch]; + const T* x_base = x + static_cast(lhs_idx) * x_batch_stride + + static_cast(tile_row) * K; + + // Zero the accumulator for this batch element + frag_acc acc; + rocwmma::fill_fragment(acc, 0.0f); + + // Loop over K dimension in chunks of WMMA_K (16) + for (int k_base = 0; k_base < K; k_base += WMMA_K) { + // --- Load x tile [WMMA_M x WMMA_K] into shared memory --- + // 32 lanes load 256 elements (16x16) -> 8 elements per lane + // Pad with zero for rows beyond M (handles non-16-aligned M) + #pragma unroll + for (int i = 0; i < (WMMA_M * WMMA_K + 31) / 32; ++i) { + int idx = lane + i * 32; + if (idx < WMMA_M * WMMA_K) { + int m_local = idx / WMMA_K; + int k_local = idx % WMMA_K; + int m_global = tile_row + m_local; + int k_global = k_base + k_local; + if (m_global < M && k_global < K) { + smem_x[idx] = x_base[m_local * K + k_global]; + } else { + smem_x[idx] = static_cast(0.0f); + } + } + } + + // --- Dequantize weight tile [WMMA_K x WMMA_N] into shared memory --- + // Layout: smem_w[k][n] = dequant(w[expert, tile_col + n, k_base + k]) + // w is stored as [N, row_bytes], each row for one output column. + // We need 16 columns x 16 K values = 256 values, 8 per lane. 
+ #pragma unroll + for (int i = 0; i < (WMMA_K * WMMA_N + 31) / 32; ++i) { + int idx = lane + i * 32; + if (idx < WMMA_K * WMMA_N) { + int k_local = idx / WMMA_N; // row in [K, N] + int n_local = idx % WMMA_N; // col in [K, N] + int k_global = k_base + k_local; + int n_global = tile_col + n_local; + + if (k_global < K) { + // Pointer to weight row for output column n_global + const uint8_t* w_row = w_expert + static_cast(n_global) * row_bytes; + + // Extract 4-bit quantized value + uint8_t packed = w_row[k_global >> 1]; + uint8_t quant_val = (k_global & 1) ? (packed >> 4) : (packed & 0xF); + + // Dequantize: val = scale * quant_val + bias + int group_idx = k_global / GROUP_SIZE; + float scale = static_cast( + s_expert[static_cast(n_global) * num_groups + group_idx]); + float bias_val = has_bias + ? static_cast( + b_expert[static_cast(n_global) * num_groups + group_idx]) + : 0.0f; + float dequant = scale * static_cast(quant_val) + bias_val; + smem_w[idx] = static_cast(dequant); + } else { + smem_w[idx] = static_cast(0.0f); + } + } + } + + __syncthreads(); + + // --- Load fragments from shared memory and perform MMA --- + frag_a a_frag; + frag_b b_frag; + + // Load A from smem_x [WMMA_M x WMMA_K], row-major, ldm = WMMA_K + rocwmma::load_matrix_sync(a_frag, smem_x, WMMA_K); + // Load B from smem_w [WMMA_K x WMMA_N], row-major, ldm = WMMA_N + rocwmma::load_matrix_sync(b_frag, smem_w, WMMA_N); + + // D = A * B + C + rocwmma::mma_sync(acc, a_frag, b_frag, acc); + + __syncthreads(); + } + + // --- Store the 16x16 result tile --- + // Store f32 accumulator to shared memory, then convert to bf16 for output. 
+ __shared__ float smem_out_f32[WMMA_M * WMMA_N]; + + rocwmma::store_matrix_sync(smem_out_f32, acc, WMMA_N, rocwmma::mem_row_major); + __syncthreads(); + + // Convert f32 -> bf16 and write to global output (mask out-of-bounds rows) + // Use out_perm to map sorted batch position back to original output position + const int orig_batch = out_perm[batch]; + T* out_base = out + static_cast(orig_batch) * M * N + + static_cast(tile_row) * N + + tile_col; + #pragma unroll + for (int i = 0; i < (WMMA_M * WMMA_N + 31) / 32; ++i) { + int idx = lane + i * 32; + if (idx < WMMA_M * WMMA_N) { + int m_local = idx / WMMA_N; + int n_local = idx % WMMA_N; + if (tile_row + m_local < M) { + out_base[m_local * N + n_local] = static_cast(smem_out_f32[idx]); + } + } + } + __syncthreads(); + } + +#endif // ROCM_HAS_WMMA +} + } // namespace rocm void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { @@ -3690,16 +4632,202 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { bool use_fast_gather_qmv = transpose_ && bits_supported_by_fast; use_fast_gather_qmv = parse_warp_kernel_env( "MLX_ROCM_GATHER_QMV_USE_WARP", use_fast_gather_qmv); + // ---- Prefill optimization: group by expert for M>1 ---- + // Works with both sorted and unsorted rhs_indices; we sort on CPU. + // NOTE: MLX's MoE expands tokens to B individual M=1 calls, so M>1 is rare. + // The WMMA prefill kernel is used when upstream batching produces M>1. + if (M > 1 && transpose_ && E > 0 && batch_ndim == 1 && + mode_ == QuantizationMode::Affine && x.dtype() == bfloat16 && + group_size_ == 64 && (bits_ == 4 || bits_ == 8)) { + // Sort batch elements by expert to form contiguous runs. + // This allows the kernel to process all tokens for one expert together, + // sharing weight reads. We create a sorted permutation on CPU. 
+ const auto* ri_cpu = rhs_indices.data(); + const auto* li_cpu = lhs_indices.data(); + + // Create sort permutation by expert index + std::vector perm(B); + std::iota(perm.begin(), perm.end(), 0); + std::sort(perm.begin(), perm.end(), [&](int a, int b) { + return ri_cpu[a] < ri_cpu[b]; + }); + + // Build sorted index arrays and compute runs + std::vector sorted_ri(B), sorted_li(B); + for (int i = 0; i < B; ++i) { + sorted_ri[i] = ri_cpu[perm[i]]; + sorted_li[i] = li_cpu[perm[i]]; + } + + std::vector run_starts_vec, run_lengths_vec; + run_starts_vec.reserve(E); + run_lengths_vec.reserve(E); + int run_begin = 0; + for (int b = 1; b <= B; ++b) { + if (b == B || sorted_ri[b] != sorted_ri[run_begin]) { + run_starts_vec.push_back(run_begin); + run_lengths_vec.push_back(b - run_begin); + run_begin = b; + } + } + int num_runs = static_cast(run_starts_vec.size()); + + // Upload sorted indices to GPU + array sorted_ri_arr({B}, uint32, nullptr, {}); + array sorted_li_arr({B}, uint32, nullptr, {}); + sorted_ri_arr.set_data(allocator::malloc(sorted_ri_arr.nbytes())); + sorted_li_arr.set_data(allocator::malloc(sorted_li_arr.nbytes())); + std::memcpy(sorted_ri_arr.data(), sorted_ri.data(), B * sizeof(uint32_t)); + std::memcpy(sorted_li_arr.data(), sorted_li.data(), B * sizeof(uint32_t)); + enc.set_input_array(sorted_ri_arr); + enc.set_input_array(sorted_li_arr); + + // Also need a mapping from sorted position back to original batch index for output + array perm_arr({B}, int32, nullptr, {}); + perm_arr.set_data(allocator::malloc(perm_arr.nbytes())); + std::memcpy(perm_arr.data(), perm.data(), B * sizeof(int)); + enc.set_input_array(perm_arr); + + // Upload run info to GPU + array run_starts_arr({num_runs}, int32, nullptr, {}); + array run_lengths_arr({num_runs}, int32, nullptr, {}); + run_starts_arr.set_data(allocator::malloc(run_starts_arr.nbytes())); + run_lengths_arr.set_data(allocator::malloc(run_lengths_arr.nbytes())); + std::memcpy(run_starts_arr.data(), 
run_starts_vec.data(), num_runs * sizeof(int)); + std::memcpy(run_lengths_arr.data(), run_lengths_vec.data(), num_runs * sizeof(int)); + enc.set_input_array(run_starts_arr); + enc.set_input_array(run_lengths_arr); + + int64_t x_bs = (x_batch_count == 1) ? 0 : static_cast(M) * K; + + // ---- WMMA path: use 16x16x16 wave matrix multiply when tiles align ---- + // WMMA tiles are 16x16; kernel handles non-aligned M with bounds masking. + // N must be 16-aligned (typical for transformer hidden dimensions). + bool use_wmma = (M >= 2) && (N % 16 == 0) && (bits_ == 4); + use_wmma = parse_warp_kernel_env("MLX_ROCM_GATHER_QMV_USE_WMMA", use_wmma); + + if (use_wmma) { + // One wave32 per 16x16 output tile + dim3 wmma_block(32, 1, 1); + dim3 wmma_grid((M + 15) / 16, (N + 15) / 16, num_runs); + // Shared memory: smem_w[16*16] + smem_x[16*16] bf16 + smem_out_f32[16*16] f32 + // = 512 + 512 + 1024 = 2048 bytes + size_t wmma_smem = 0; // static shared memory, declared in-kernel + + enc.launch_kernel([&](hipStream_t stream) { + hipLaunchKernelGGL( + (rocm::gather_qmv_wmma_prefill_kernel), + wmma_grid, wmma_block, wmma_smem, stream, + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + has_bias ? 
gpu_ptr(*biases) : nullptr, + gpu_ptr(sorted_li_arr), + gpu_ptr(sorted_ri_arr), + gpu_ptr(run_starts_arr), + gpu_ptr(run_lengths_arr), + gpu_ptr(perm_arr), + gpu_ptr(out), + B, M, N, K, E, has_bias, x_bs); + }); + return; + } + + // ---- Scalar prefill fallback ---- + int fast_threads_per_col_pf = select_qmv_threads_per_col(K, N, bits_, num_runs); + int fast_cols_per_block_pf = select_qmv_cols_per_block(K, N, bits_); + int max_cpb = rocm::kMaxThreadsPerBlock / fast_threads_per_col_pf; + while (fast_cols_per_block_pf > max_cpb) fast_cols_per_block_pf /= 2; + while (fast_cols_per_block_pf > 1 && (N % fast_cols_per_block_pf) != 0 && fast_cols_per_block_pf > 8) + fast_cols_per_block_pf /= 2; + + dim3 pf_block(fast_threads_per_col_pf, fast_cols_per_block_pf); + dim3 pf_grid(M, (N + fast_cols_per_block_pf - 1) / fast_cols_per_block_pf, num_runs); + + enc.launch_kernel([&](hipStream_t stream) { + auto launch_pf = [&](auto bits_tag) { + constexpr int BITS = decltype(bits_tag)::value; + hipLaunchKernelGGL( + (rocm::gather_qmv_prefill_kernel), + pf_grid, pf_block, 0, stream, + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + has_bias ? gpu_ptr(*biases) : nullptr, + gpu_ptr(sorted_li_arr), + gpu_ptr(sorted_ri_arr), + gpu_ptr(run_starts_arr), + gpu_ptr(run_lengths_arr), + gpu_ptr(perm_arr), + gpu_ptr(out), + B, M, N, K, E, has_bias, x_bs); + }; + if (bits_ == 4) launch_pf(std::integral_constant{}); + else launch_pf(std::integral_constant{}); + }); + return; + } + const void *x_ptr = gpu_ptr(x), *w_ptr = gpu_ptr(w), *scales_ptr = gpu_ptr(scales), *biases_ptr = has_bias ? gpu_ptr(*biases) : nullptr; const uint32_t *li_ptr = gpu_ptr(lhs_indices), *ri_ptr = gpu_ptr(rhs_indices); void* out_ptr = gpu_ptr(out); + + // GPU-only expert-batched kernel: when indices are sorted, each block finds + // its expert's token range on-GPU and processes them together. Weight data + // loaded once per expert column, reused across all tokens for that expert. 
+ // max_unique_experts = min(B, E) is an upper bound on unique experts. + // Expert-batched kernel: beneficial when few experts have many tokens each. + // For high-expert-count models (E=512, top_k=10), most runs have 1-4 tokens, + // so the per-block run-finding overhead outweighs the shared weight benefit. + // Enable only when B/E is high enough (e.g., low expert count with long prompt). + bool use_expert_batched = transpose_ && right_sorted_ && (M == 1) && + (B >= 64) && (E > 0) && (E <= 64) && (B / E >= 4) && + mode_ == QuantizationMode::Affine && + x.dtype() == bfloat16 && group_size_ == 64 && (bits_ == 4 || bits_ == 8); + use_expert_batched = parse_warp_kernel_env( + "MLX_ROCM_GATHER_QMV_EXPERT_BATCHED", use_expert_batched); + + if (use_expert_batched) { + int max_unique_experts = std::min(B, E); + int eb_threads_per_col = select_qmv_threads_per_col(K, N, bits_, max_unique_experts); + int eb_cols_per_block = select_qmv_cols_per_block(K, N, bits_); + int eb_max_cpb = rocm::kMaxThreadsPerBlock / eb_threads_per_col; + while (eb_cols_per_block > eb_max_cpb) eb_cols_per_block /= 2; + while (eb_cols_per_block > 1 && (N % eb_cols_per_block) != 0 && eb_cols_per_block > 8) + eb_cols_per_block /= 2; + + dim3 eb_block(eb_threads_per_col, eb_cols_per_block); + dim3 eb_grid(M, (N + eb_cols_per_block - 1) / eb_cols_per_block, max_unique_experts); + + enc.launch_kernel([&](hipStream_t stream) { + auto launch_eb = [&](auto bits_tag) { + constexpr int BITS = decltype(bits_tag)::value; + hipLaunchKernelGGL( + (rocm::gather_qmv_expert_batched_kernel< + hip_bfloat16, hip_bfloat16, BITS, 64, true, 16>), + eb_grid, eb_block, 0, stream, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, ri_ptr, + (hip_bfloat16*)out_ptr, + B, M, N, K, E, has_bias, + use_sorted_rhs_schedule, implicit_x_batch_stride); + }; + if (bits_ == 4) launch_eb(std::integral_constant{}); + else 
launch_eb(std::integral_constant{}); + }); + return; + } + enc.launch_kernel([&](hipStream_t stream) { if (use_fast_gather_qmv && mode_ == QuantizationMode::Affine && x.dtype() == bfloat16 && group_size_ == 64 && - (bits_ == 6 || bits_ == 8)) { + (bits_ == 4 || bits_ == 6 || bits_ == 8)) { auto launch_fast_kernel = [&](auto bits_tag) { constexpr int BITS = decltype(bits_tag)::value; if (fast_threads_per_col == 16) { @@ -3769,7 +4897,9 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { } }; - if (bits_ == 6) { + if (bits_ == 4) { + launch_fast_kernel(std::integral_constant{}); + } else if (bits_ == 6) { launch_fast_kernel(std::integral_constant{}); } else { launch_fast_kernel(std::integral_constant{}); diff --git a/mlx/backend/rocm/quantized/qmv_kernel.hip b/mlx/backend/rocm/quantized/qmv_kernel.hip new file mode 100644 index 0000000000..c9c625d39a --- /dev/null +++ b/mlx/backend/rocm/quantized/qmv_kernel.hip @@ -0,0 +1,224 @@ +// Optimized quantized matrix-vector multiply (GEMV) kernel for RDNA 3.5. +// +// Each warp (32 threads) cooperatively computes ONE output element by +// iterating along the K dimension with coalesced uint32 loads. +// 8 warps per block → 8 output elements per block. +// +// Key optimizations vs naive kernel: +// 1. Coalesced global memory access (adjacent threads read adjacent words) +// 2. Vectorized uint32 loads (8 values per word for 4-bit) +// 3. Warp shuffle reduction (no shared memory needed for reduction) +// 4. 
LDS for x vector sharing across 8 warps in a block + +#include "mlx/backend/rocm/quantized/qdequant.hpp" +#include "mlx/backend/rocm/device/config.h" + +#include + +namespace mlx::core::rocm { + +// --------------------------------------------------------------------------- +// qmv_fast_kernel: Warp-cooperative quantized GEMV +// --------------------------------------------------------------------------- +// Grid: dim3(M, ceildiv(N, ROWS_PER_BLOCK)) +// Block: dim3(WARP_SIZE, ROWS_PER_BLOCK) = dim3(32, 8) = 256 threads +// +// Each warp (threadIdx.y selects the warp) computes one output element. +// All 32 lanes iterate over K together with coalesced weight loads. + +template +__global__ __launch_bounds__(256) +void qmv_fast_kernel( + const T* __restrict__ x, // [M, K] + const uint32_t* __restrict__ w, // [N, K/pack_factor_u32] as uint32 + const ScaleT* __restrict__ scales, // [N, K/GROUP_SIZE] + const ScaleT* __restrict__ biases, // [N, K/GROUP_SIZE] or nullptr + T* __restrict__ out, // [M, N] + int M, + int N, + int K, + bool has_bias) +{ + constexpr int PF = pack_factor_u32; // values per uint32 (8 for 4-bit) + constexpr int PPT = packs_per_thread; // uint32 loads per thread (2 for 4-bit) + constexpr int VPT = values_per_thread; // values per thread per step (16) + constexpr int BSK = VPT * WARP_SIZE; // K-elements per warp per step (512) + + const int m = blockIdx.x; // output row + const int n = blockIdx.y * ROWS_PER_BLOCK + threadIdx.y; // output column + const int lane = threadIdx.x; // lane within warp + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; // flat thread id + + // NOTE: Do NOT early-return here — all threads must participate in __syncthreads. 
+ const bool valid = (m < M && n < N); + + // --- LDS for x vector (shared across all 8 warps) --- + __shared__ float x_shared[BSK]; + + // Per-warp pointers (safe even if n >= N: we just won't write output) + const int w_stride = K / PF; // number of uint32 per weight row + const int clamped_n = (n < N) ? n : 0; // clamp to avoid OOB on pointer setup + const uint32_t* w_row = w + clamped_n * w_stride; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const ScaleT* s_row = scales + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + clamped_n * num_groups) : nullptr; + const T* x_row = x + m * K; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK) { + // --- Cooperative load of x into LDS --- + // All 256 threads participate (including invalid ones) to avoid barrier mismatch. + __syncthreads(); + #pragma unroll + for (int i = tid; i < BSK; i += ROWS_PER_BLOCK * WARP_SIZE) { + int k = k_base + i; + x_shared[i] = (k < K) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; // Skip compute but still participate in barriers + + // --- Each lane loads its slice of x from LDS --- + float x_local[VPT]; + #pragma unroll + for (int i = 0; i < VPT; i++) { + x_local[i] = x_shared[lane * VPT + i]; + } + + // --- Coalesced weight load + dequant + accumulate --- + // Metal-compatible accumulation: separate integer dot product from scaling. + // We accumulate dot(x, q_int) and sum(x) across ALL packs in the same + // group, then apply: acc += scale * total_qdot + bias * total_xsum. + // This matches Metal's qdot() which computes scale*accum + sum*bias + // over all values_per_thread at once. 
+ int w_offset = k_base / PF + lane * PPT; + + // Accumulate integer dot and x-sum across all packs (same group for all) + float group_qdot = 0.0f; + float group_xsum = 0.0f; + + // All PPT packs share the same group (thread's 16 values are contiguous) + int k_val = k_base + lane * VPT; + int group_idx = k_val / GROUP_SIZE; + + #pragma unroll + for (int p = 0; p < PPT; p++) { + uint32_t packed = (w_offset + p < w_stride) ? w_row[w_offset + p] : 0u; + dequant_and_dot(packed, &x_local[p * PF], group_qdot, group_xsum); + } + + // Apply scale and bias ONCE for the whole group (matches Metal) + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * group_qdot + bias * group_xsum; + } + + if (!valid) return; + + // --- Warp reduction --- + acc = warp_reduce_sum(acc); + + // --- Lane 0 writes output --- + if (lane == 0) { + out[m * N + n] = from_float(acc); + } +} + +// --------------------------------------------------------------------------- +// gather_qmv_fast_kernel: Warp-cooperative gather-based quantized GEMV +// --------------------------------------------------------------------------- +// Same as qmv_fast_kernel but with batch index indirection for MoE models. 
+ +template +__global__ __launch_bounds__(256) +void gather_qmv_fast_kernel( + const T* __restrict__ x, // [LHS_B, M, K] + const uint32_t* __restrict__ w, // [E, N, K/pack_factor] as uint32 + const ScaleT* __restrict__ scales, // [E, N, K/GROUP_SIZE] + const ScaleT* __restrict__ biases, // [E, N, K/GROUP_SIZE] or nullptr + const uint32_t* __restrict__ lhs_indices, // [B] + const uint32_t* __restrict__ rhs_indices, // [B] + T* __restrict__ out, // [B, M, N] + int B, int M, int N, int K, int E, int LHS_B, + bool has_bias) +{ + constexpr int PF = pack_factor_u32; + constexpr int PPT = packs_per_thread; + constexpr int VPT = values_per_thread; + constexpr int BSK = VPT * WARP_SIZE; + + const int batch = blockIdx.z; + const int m = blockIdx.x; + const int n = blockIdx.y * ROWS_PER_BLOCK + threadIdx.y; + const int lane = threadIdx.x; + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; + + const bool valid = (batch < B && m < M && n < N); + + uint32_t lhs_idx = valid ? lhs_indices[batch] : 0; + uint32_t rhs_idx = valid ? rhs_indices[batch] : 0; + + // Clamp indices to valid range to prevent catastrophic OOB on corrupt data. + if (lhs_idx >= static_cast(LHS_B)) lhs_idx = 0; + if (rhs_idx >= static_cast(E)) rhs_idx = 0; + + __shared__ float x_shared[BSK]; + + const int w_stride = K / PF; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int clamped_n = (n < N) ? n : 0; + const uint32_t* w_row = w + rhs_idx * N * w_stride + clamped_n * w_stride; + const ScaleT* s_row = scales + rhs_idx * N * num_groups + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + rhs_idx * N * num_groups + clamped_n * num_groups) : nullptr; + const T* x_row = x + lhs_idx * M * K + m * K; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK) { + __syncthreads(); + #pragma unroll + for (int i = tid; i < BSK; i += ROWS_PER_BLOCK * WARP_SIZE) { + int k = k_base + i; + x_shared[i] = (k < K && valid) ? 
to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; + + float x_local[VPT]; + #pragma unroll + for (int i = 0; i < VPT; i++) { + x_local[i] = x_shared[lane * VPT + i]; + } + + int w_offset = k_base / PF + lane * PPT; + + // Accumulate integer dot and x-sum across all packs (same group) + float group_qdot = 0.0f; + float group_xsum = 0.0f; + + int k_val = k_base + lane * VPT; + int group_idx = k_val / GROUP_SIZE; + + #pragma unroll + for (int p = 0; p < PPT; p++) { + uint32_t packed = (w_offset + p < w_stride) ? w_row[w_offset + p] : 0u; + dequant_and_dot(packed, &x_local[p * PF], group_qdot, group_xsum); + } + + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * group_qdot + bias * group_xsum; + } + + if (!valid) return; + + acc = warp_reduce_sum(acc); + + if (lane == 0) { + out[batch * M * N + m * N + n] = from_float(acc); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/rms_norm.hip b/mlx/backend/rocm/rms_norm.hip index 38aa0b5ba7..c54c882f2f 100644 --- a/mlx/backend/rocm/rms_norm.hip +++ b/mlx/backend/rocm/rms_norm.hip @@ -79,16 +79,20 @@ __global__ void rms_norm_kernel( shared_sum[0] = normalizer; } __syncthreads(); - normalizer = rsqrtf(shared_sum[0] / axis_size + eps); + // Use 1/sqrt instead of rsqrtf for IEEE-compliant precision + // (matches Metal's metal::precise::rsqrt behavior) + normalizer = 1.0f / sqrtf(shared_sum[0] / axis_size + eps); // Write output + // Match Metal's weight application order: w * T(x * normalizer) + // Weight multiply in output type T after truncation, not in float32 for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { #pragma unroll for (int j = 0; j < N_READS && i + j < axis_size; ++j) { int idx = i + j; - float y = static_cast(x[idx]) * normalizer; - float wi = (w_stride == 0) ? 
static_cast(w[0]) : static_cast(w[idx * w_stride]); - out[idx] = static_cast(wi * y); + T normalized = static_cast(static_cast(x[idx]) * normalizer); + T wi = (w_stride == 0) ? w[0] : w[idx * w_stride]; + out[idx] = wi * normalized; } } } @@ -150,7 +154,9 @@ __global__ void rms_norm_vjp_kernel( factors = shared_f2[0]; float meangwx = factors.x / axis_size; - float normalizer = rsqrtf(factors.y / axis_size + eps); + // Use 1/sqrt instead of rsqrtf for IEEE-compliant precision + // (matches Metal's metal::precise::rsqrt behavior) + float normalizer = 1.0f / sqrtf(factors.y / axis_size + eps); float normalizer3 = normalizer * normalizer * normalizer; // Write outputs diff --git a/mlx/backend/rocm/scaled_dot_product_attention.cpp b/mlx/backend/rocm/scaled_dot_product_attention.cpp index be033c148d..b472fc9e48 100644 --- a/mlx/backend/rocm/scaled_dot_product_attention.cpp +++ b/mlx/backend/rocm/scaled_dot_product_attention.cpp @@ -148,11 +148,12 @@ void ScaledDotProductAttention::eval_gpu( sdpa_flash(q, k, v, scale_, out, do_causal_, mask_arr, std::nullopt, s); } } else { - // Fallback: compute attention manually - // This path should rarely be hit due to use_fallback check + // This should not be reached — use_fallback() returns true for unsupported + // configs, causing the framework to decompose SDPA into basic GPU ops + // (matmul + softmax + matmul) before this primitive is created. throw std::runtime_error( - "SDPA configuration not supported by ROCm kernel. " - "Please use CPU fallback or adjust parameters."); + "[ScaledDotProductAttention::eval_gpu] Unsupported configuration reached. 
" + "This is a bug — use_fallback() should have returned true."); } } diff --git a/mlx/backend/rocm/scaled_dot_product_attention.hip b/mlx/backend/rocm/scaled_dot_product_attention.hip index 3a5f202329..5407172f10 100644 --- a/mlx/backend/rocm/scaled_dot_product_attention.hip +++ b/mlx/backend/rocm/scaled_dot_product_attention.hip @@ -5,6 +5,7 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/rocm/device.h" #include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" #include "mlx/backend/rocm/kernel_utils.hpp" #include "mlx/dtype_utils.h" @@ -179,7 +180,6 @@ __global__ void kernel_sdpav_1pass( U new_max = tile_reduce_max_32(max_score); U factor = exp2f(max_score - new_max); sum_exp_score = tile_reduce_sum_32(sum_exp_scores[lane_idx % BN] * factor); - sum_exp_score = sum_exp_score == 0 ? 0 : 1.0f / sum_exp_score; // Aggregate outputs across tiles #pragma unroll @@ -187,7 +187,8 @@ __global__ void kernel_sdpav_1pass( outputs[lane_idx][tile_idx] = o[i]; __syncthreads(); U ot = outputs[tile_idx][lane_idx] * factor; - o[i] = tile_reduce_sum_32(ot) * sum_exp_score; + o[i] = tile_reduce_sum_32(ot); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); __syncthreads(); } diff --git a/mlx/backend/rocm/sort.hip b/mlx/backend/rocm/sort.hip index df85b7e145..2f00ea9a01 100644 --- a/mlx/backend/rocm/sort.hip +++ b/mlx/backend/rocm/sort.hip @@ -7,6 +7,17 @@ #include "mlx/primitives.h" #include + +// Workaround: rocprim headers use placement new in __device__ code, +// which requires __device__ overloads of operator new/delete. 
+#ifdef __HIP_DEVICE_COMPILE__ +__device__ inline void* operator new(size_t, void* p) noexcept { return p; } +__device__ inline void* operator new[](size_t, void* p) noexcept { return p; } +__device__ inline void operator delete(void*, void*) noexcept {} +__device__ inline void operator delete[](void*, void*) noexcept {} +#endif + +#include #include #include @@ -34,11 +45,27 @@ __device__ __forceinline__ _Float16 nan_value<_Float16>() { return static_cast<_Float16>(__builtin_nanf("")); } +// __half may or may not be the same as _Float16 depending on HIP version. +// Provide explicit specialization via __float2half conversion. +template <> +__device__ __forceinline__ __half nan_value<__half>() { + return __float2half(__builtin_nanf("")); +} + template <> __device__ __forceinline__ hip_bfloat16 nan_value() { return hip_bfloat16(__builtin_nanf("")); } +// Helper trait: true for all floating-point types including __half and hip_bfloat16. +// std::is_floating_point_v is false for __half and hip_bfloat16, which would +// cause NaN handling to be skipped and produce incorrect sort results. 
+template +inline constexpr bool is_sort_floating_v = + std::is_floating_point_v || + std::is_same_v || + std::is_same_v; + template struct InitValue { __device__ __forceinline__ static T value() { @@ -47,7 +74,7 @@ struct InitValue { }; template -struct InitValue>> { +struct InitValue>> { __device__ __forceinline__ static T value() { return nan_value(); } @@ -67,7 +94,7 @@ struct LessThan { } __device__ __forceinline__ bool operator()(T a, T b) const { - if constexpr (std::is_floating_point_v) { + if constexpr (is_sort_floating_v) { bool an = isnan(static_cast(a)); bool bn = isnan(static_cast(b)); if (an | bn) { @@ -292,7 +319,8 @@ struct KernelMergeSort { block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis); __syncthreads(); - for (int i = threadIdx.x; i < size_sorted_axis; i += BLOCK_THREADS) { + int out_limit = min(size_sorted_axis, N_PER_BLOCK); + for (int i = threadIdx.x; i < out_limit; i += BLOCK_THREADS) { if constexpr (ARG_SORT) { out[i * out_stride_sorted_axis] = tgp_idxs[i]; } else { @@ -349,6 +377,15 @@ __global__ void block_sort_kernel( } } +// Simple iota kernel: fills output[i] = i for i in [0, n). +// Used to initialize index arrays on-device instead of copying from host. +__global__ void iota_kernel(uint32_t* out, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + out[i] = static_cast(i); + } +} + } // namespace rocm namespace { @@ -386,8 +423,133 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { auto& stream = encoder.stream(); - // Determine block size + // For large arrays that exceed the block sort capacity (512 threads * 8 items = 4096), + // use rocprim radix sort which handles arbitrary sizes correctly. 
constexpr int tn = N_PER_THREAD; + constexpr int max_block_sort_size = 512 * tn; // 4096 + + if (size_sorted_axis > max_block_sort_size) { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + if constexpr (!std::is_same_v) { + using ValT = hip_type_t; + + encoder.launch_kernel([&](hipStream_t hip_stream) { + int N = size_sorted_axis; + + if (argsort) { + // Allocate all temp buffers once, outside the row loop. + uint32_t* indices_in = nullptr; + uint32_t* indices_out = nullptr; + ValT* vals_tmp = nullptr; + ValT* vals_sorted = nullptr; + CHECK_HIP_ERROR(hipMalloc(&indices_in, N * sizeof(uint32_t))); + CHECK_HIP_ERROR(hipMalloc(&indices_out, N * sizeof(uint32_t))); + CHECK_HIP_ERROR(hipMalloc(&vals_tmp, N * sizeof(ValT))); + CHECK_HIP_ERROR(hipMalloc(&vals_sorted, N * sizeof(ValT))); + + // Query temp storage size (same for all rows with same N). + size_t temp_bytes = 0; + rocprim::radix_sort_pairs( + nullptr, temp_bytes, + vals_tmp, vals_sorted, + indices_in, indices_out, + N, 0, sizeof(ValT) * 8, hip_stream); + + void* temp_storage = nullptr; + CHECK_HIP_ERROR(hipMalloc(&temp_storage, temp_bytes)); + + // Initialize iota indices on device (avoids host vector + memcpy). + { + int block = 256; + int grid = (N + block - 1) / block; + hipLaunchKernelGGL( + rocm::iota_kernel, dim3(grid), dim3(block), 0, hip_stream, + indices_in, N); + } + + for (int row = 0; row < n_rows; ++row) { + const ValT* in_row = in.data() + row * N; + + // Copy input values to mutable buffer for rocprim. + CHECK_HIP_ERROR(hipMemcpyAsync(vals_tmp, in_row, + N * sizeof(ValT), hipMemcpyDeviceToDevice, hip_stream)); + + // Re-initialize indices for each row (iota is idempotent so + // we can re-use the same buffer if we reset it). 
+              if (row > 0) {
+                hipLaunchKernelGGL(
+                    rocm::iota_kernel, dim3((N + 255) / 256), dim3(256),
+                    0, hip_stream, indices_in, N);
+              }
+
+              rocprim::radix_sort_pairs(
+                  temp_storage, temp_bytes,
+                  vals_tmp, vals_sorted,
+                  indices_in, indices_out,
+                  N, 0, sizeof(ValT) * 8, hip_stream);
+
+              // Copy result indices to output.
+              uint32_t* out_row = out.data<uint32_t>() + row * N;
+              CHECK_HIP_ERROR(hipMemcpyAsync(out_row, indices_out,
+                  N * sizeof(uint32_t), hipMemcpyDeviceToDevice, hip_stream));
+            }
+
+            CHECK_HIP_ERROR(hipFree(indices_in));
+            CHECK_HIP_ERROR(hipFree(indices_out));
+            CHECK_HIP_ERROR(hipFree(vals_tmp));
+            CHECK_HIP_ERROR(hipFree(vals_sorted));
+            CHECK_HIP_ERROR(hipFree(temp_storage));
+          } else {
+            // Sort values only -- allocate once outside loop.
+            ValT* vals_in = nullptr;
+            ValT* vals_out_buf = nullptr;
+            CHECK_HIP_ERROR(hipMalloc(&vals_in, N * sizeof(ValT)));
+            CHECK_HIP_ERROR(hipMalloc(&vals_out_buf, N * sizeof(ValT)));
+
+            size_t temp_bytes = 0;
+            rocprim::radix_sort_keys(
+                nullptr, temp_bytes,
+                vals_in, vals_out_buf,
+                N, 0, sizeof(ValT) * 8, hip_stream);
+
+            void* temp_storage = nullptr;
+            CHECK_HIP_ERROR(hipMalloc(&temp_storage, temp_bytes));
+
+            for (int row = 0; row < n_rows; ++row) {
+              const ValT* in_row = in.data<ValT>() + row * N;
+
+              CHECK_HIP_ERROR(hipMemcpyAsync(vals_in, in_row,
+                  N * sizeof(ValT), hipMemcpyDeviceToDevice, hip_stream));
+
+              rocprim::radix_sort_keys(
+                  temp_storage, temp_bytes,
+                  vals_in, vals_out_buf,
+                  N, 0, sizeof(ValT) * 8, hip_stream);
+
+              ValT* out_row = out.data<ValT>() + row * N;
+              CHECK_HIP_ERROR(hipMemcpyAsync(out_row, vals_out_buf,
+                  N * sizeof(ValT), hipMemcpyDeviceToDevice, hip_stream));
+            }
+
+            CHECK_HIP_ERROR(hipFree(vals_in));
+            CHECK_HIP_ERROR(hipFree(vals_out_buf));
+            CHECK_HIP_ERROR(hipFree(temp_storage));
+          }
+        });
+      } else {
+        throw std::runtime_error(
+            "ROCm backend does not support sorting complex numbers");
+      }
+    });
+
+    if (!is_segmented_sort) {
+      copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
+    }
+    return;
+  }
+
+  // Determine block size for small-array block sort
   int potential_bn = (size_sorted_axis + tn - 1) / tn;
   int bn;
   if (potential_bn > 256) {