diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py
index 9a90fe48cd..a7468991db 100644
--- a/aiter/utility/dtypes.py
+++ b/aiter/utility/dtypes.py
@@ -9,6 +9,7 @@
     "gfx942": {"fp8": torch.float8_e4m3fnuz},
     "gfx950": {"fp8": torch.float8_e4m3fn},
     "gfx1250": {"fp8": torch.float8_e4m3fn},
+    "gfx1201": {"fp8": torch.float8_e4m3fn},
 }
 
 _8bit_fallback = torch.uint8
diff --git a/csrc/include/ck_tile/vec_convert.h b/csrc/include/ck_tile/vec_convert.h
index aaabcd3508..e8c954c799 100644
--- a/csrc/include/ck_tile/vec_convert.h
+++ b/csrc/include/ck_tile/vec_convert.h
@@ -43,6 +43,14 @@ CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b)
     asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
     return c;
 }
+// use scalar math for RDNA4/3 without v_pk_mul_f32
+CK_TILE_DEVICE fp32x2_v amd_scalar_mul_f32(fp32x2_v a, fp32x2_t b)
+{
+    fp32x2_v c;
+    c[0] = a[0] * b[0];
+    c[1] = a[1] * b[1];
+    return c;
+}
 CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b)
 {
     int16x2_t c;
@@ -145,8 +153,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv
     using vec_ti = vector_traits;
     constexpr int vec_size = vec_ti::vector_size;
     constexpr auto interpret = numeric_traits::f8_interpret;
-    fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
-
+    fp32x2_v tmp;
+#if defined(__gfx11__) || defined(__gfx12__)
+    tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
+#else
+    tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
+#endif
     return (interpret == fp8_interpretation::E4M3_FNUZ) ||
                    (interpret == fp8_interpretation::E4M3_OCP) ?
                amd_assembly_cvt_pk_fp8_f32(tmp[0], tmp[1])
@@ -155,7 +167,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv
 // fp32x2 -> int8x2
 CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale)
 {
-    fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
+    fp32x2_v tmp;
+#if defined(__gfx11__) || defined(__gfx12__)
+    tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
+#else
+    tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
+#endif
     int8x2_v out;
 
     out[0] = static_cast<int8_t>(tmp[0]);
diff --git a/csrc/include/hip_reduce.h b/csrc/include/hip_reduce.h
index 79a88aec3e..cab6ee69db 100644
--- a/csrc/include/hip_reduce.h
+++ b/csrc/include/hip_reduce.h
@@ -112,14 +112,28 @@ __device__ constexpr T wave_reduce(T local, F reduce_op)
 
     if constexpr(WarpSize > 16)
     {
+// DPP broadcasts (0x142, 0x143) are not supported on GFX10+ (gfx12 included)
+// Use ds_bpermute instead for cross-lane communication
+#if defined(__gfx12__) || defined(__gfx11__)
+        // Use shuffle for gfx12 instead of DPP broadcast
+        T v_remote = rocprim::warp_shuffle(local, 15, WarpSize);
+        local = reduce_op(v_remote, local);
+#else
         // row_bcast:15
         local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(local), local);
+#endif
     }
 
     if constexpr(WarpSize > 32)
     {
+#if defined(__gfx12__) || defined(__gfx11__)
+        // Use shuffle for gfx12 instead of DPP broadcast
+        T v_remote = rocprim::warp_shuffle(local, 31, WarpSize);
+        local = reduce_op(v_remote, local);
+#else
         // row_bcast:31
         local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x143>(local), local);
+#endif
     }
 
     if constexpr(threadBroadcast && WarpSize > 4)
@@ -166,7 +180,12 @@
         data = reduce_op(rocprim::detail::warp_move_dpp(data), data);
         data = reduce_op(rocprim::detail::warp_move_dpp(data), data);
         data = reduce_op(rocprim::detail::warp_move_dpp(data), data);
+#if defined(__gfx12__) || defined(__gfx11__)
+        // DPP broadcast 0x142 not supported on gfx12, use shuffle
+        data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data);
+#else
         data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(data), data);
+#endif
         if constexpr(threadBroadcast)
         {
             data = rocprim::warp_shuffle(data, thread_num - 1, thread_num);
@@ -179,8 +198,14 @@
         data = reduce_op(rocprim::detail::warp_move_dpp(data), data);
         data = reduce_op(rocprim::detail::warp_move_dpp(data), data);
         data = reduce_op(rocprim::detail::warp_move_dpp(data), data);
+#if defined(__gfx12__) || defined(__gfx11__)
+        // DPP broadcasts not supported on gfx12, use shuffle
+        data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data);
+        data = reduce_op(rocprim::warp_shuffle(data, 31, WarpSize), data);
+#else
         data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(data), data);
         data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x143>(data), data);
+#endif
         if constexpr(threadBroadcast)
         {
             data = rocprim::warp_shuffle(data, thread_num - 1, thread_num);
diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu
index 5c28742138..b0b58a1412 100644
--- a/csrc/kernels/quant_kernels.cu
+++ b/csrc/kernels/quant_kernels.cu
@@ -500,12 +500,21 @@ __global__ void smooth_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out,
 #pragma unroll
         for(int i = 0; i < async_load_num; i++)
         {
-            // buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size);
-            const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
-            uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
-            asm volatile( "s_mov_b32 m0 %0\n\t"
+            #if defined(__gfx12__)
+            int idx = threadIdx.x + i * block_size;
+            if(idx < smooth_scale_map_hash_size)
+            {
+                // RDNA4 doesn't support buffer_load_* with LDS modifier
+                // Use standard global load to VGPR then write to LDS
+                smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx];
+            }
+            #else
+            const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
+            uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
+            asm volatile( "s_mov_b32 m0 %0\n\t"
                          "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
                          ::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
+            #endif
         }
     }
 
@@ -1210,12 +1219,21 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v1(DTYPE_O* __restrict_
 #pragma unroll
         for(int i = 0; i < async_load_num; i++)
        {
-            // buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size);
-            const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
-            uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
-            asm volatile( "s_mov_b32 m0 %0\n\t"
+            #if defined(__gfx12__)
+            int idx = threadIdx.x + i * block_size;
+            if(idx < smooth_scale_map_hash_size)
+            {
+                // RDNA4 doesn't support buffer_load_* with LDS modifier
+                // Use standard global load to VGPR then write to LDS
+                smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx];
+            }
+            #else
+            const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
+            uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
+            asm volatile( "s_mov_b32 m0 %0\n\t"
                          "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
                          ::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
+            #endif
         }
     }
     int smscale_map_idx_list = 0;