diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp index 5983ca1ea6a..96047c232c4 100644 --- a/src/cpu/cpu_inner_product_list.cpp +++ b/src/cpu/cpu_inner_product_list.cpp @@ -134,13 +134,19 @@ const std::map> &impl_list_map() CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, }}, + // VNNI instance handles the BF16 source dynamic-quantization path: + // BF16 src is quantized to s8 and dispatched to the VNNI s8 GEMM. + // Plain BF16 (no src dyn-quant) is rejected by init_conf and falls + // through to avx512_core_bf16. {{forward, bf16, u8, f32}, { CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, }}, {{forward, bf16, u8, bf16}, { CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, }}, @@ -184,13 +190,18 @@ const std::map> &impl_list_map() CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, }}, + // VNNI instance here serves the BF16 source dynamic-quantization path + // (BF16 src -> s8 -> VNNI GEMM). See the {bf16, u8, *} entries above + // for the dispatch details. {{forward, bf16, u4, f32}, { CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, }}, {{forward, bf16, u4, bf16}, { CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, }}, diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 8a7cf236362..508d57821fb 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -97,15 +97,23 @@ status_t brgemm_inner_product_fwd_t::execute_forward( const void *dst_scales = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); - DEFINE_ZERO_POINTS_BUFFER_ATTR_U8(pd()->attr(), wei_zero_points, DNNL_ARG_WEIGHTS); - auto wei_scales = reinterpret_cast(wei_scales_f); - - const auto wei_scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); - const auto wei_zero_points_d = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); + DEFINE_ZERO_POINTS_BUFFER_ATTR_U8( + pd()->attr(), wei_zero_points, DNNL_ARG_WEIGHTS); + auto wei_scales = reinterpret_cast(wei_scales_f); + + const auto wei_scales_d + = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); + const auto wei_zero_points_d + = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); int wei_scales_oc_stride = wei_scales_d.dims()[0] > 1 ? 1 : 0; int wei_zero_points_oc_stride = wei_zero_points_d.dims()[0] > 1 ? 1 : 0; - size_t wei_scales_dt_size = jbgp.wei_decomp_scales_dt == data_type::undef ? 0 : types::data_type_size(jbgp.wei_decomp_scales_dt); - size_t wei_zero_points_dt_size = jbgp.wei_decomp_zero_points_dt == data_type::undef ? 0 : types::data_type_size(jbgp.wei_decomp_zero_points_dt); + size_t wei_scales_dt_size = jbgp.wei_decomp_scales_dt == data_type::undef + ? 0 + : types::data_type_size(jbgp.wei_decomp_scales_dt); + size_t wei_zero_points_dt_size + = jbgp.wei_decomp_zero_points_dt == data_type::undef + ? 0 + : types::data_type_size(jbgp.wei_decomp_zero_points_dt); if (jbgp.weights_decompression) { // weights decompression algorithm requires weights scales to be // applied before matmul to avoid huge numerical errors @@ -113,17 +121,27 @@ status_t brgemm_inner_product_fwd_t::execute_forward( // decompression algorithm assumes scales/zero_points buffers are aligned on oc_block size if (jbgp.oc % jbgp.simd_w != 0) { - if (!pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) { - auto dims = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).get_dims(); - auto decomp_scales_buf = scratchpad.template get(key_decompression_scales); - std::memcpy(decomp_scales_buf, wei_scales, dims[0] * dims[1] * wei_scales_dt_size); + if (!pd()->attr() + ->scales_.get(DNNL_ARG_WEIGHTS) + .has_default_values()) { + auto dims = pd()->attr() + ->scales_.get(DNNL_ARG_WEIGHTS) + .get_dims(); + auto decomp_scales_buf = scratchpad.template get( + key_decompression_scales); + std::memcpy(decomp_scales_buf, wei_scales, + dims[0] * dims[1] * wei_scales_dt_size); wei_scales = decomp_scales_buf; } - if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)) { - auto decomp_zp_buf = scratchpad.template get(key_decompression_zero_points); - auto dims = pd()->attr()->zero_points_.get_dims(DNNL_ARG_WEIGHTS); - std::memcpy(decomp_zp_buf, wei_zero_points, dims[0] * dims[1] * wei_zero_points_dt_size); + if (!pd()->attr()->zero_points_.has_default_values( + DNNL_ARG_WEIGHTS)) { + auto decomp_zp_buf = scratchpad.template get( + key_decompression_zero_points); + auto dims + = pd()->attr()->zero_points_.get_dims(DNNL_ARG_WEIGHTS); + std::memcpy(decomp_zp_buf, wei_zero_points, + dims[0] * dims[1] * wei_zero_points_dt_size); wei_zero_points = decomp_zp_buf; } } @@ -135,50 +153,81 @@ status_t brgemm_inner_product_fwd_t::execute_forward( const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt); const size_t dst_dt_size = types::data_type_size(jbgp.dst_dt); - int8_t* qsrc = nullptr; - float* src_dscales = nullptr; - int32_t* src_grouped_sum = nullptr; + int8_t *qsrc = nullptr; + float *src_dscales = nullptr; + int32_t *src_grouped_sum = nullptr; if (jbgp.with_src_dynamic_quant) { qsrc = scratchpad.template get(key_src_quantized); - src_dscales = scratchpad.template get(key_src_dequantized_scales); + src_dscales + = scratchpad.template get(key_src_dequantized_scales); src_grouped_sum = scratchpad.template get(key_src_grouped_sum); int ic_groups = div_up(jbgp.ic, jbgp.src_quant_group_size); int ic_sum_groups = div_up(jbgp.ic, jbgp.src_sum_group_size); - auto src_ptr = reinterpret_cast(src); + const size_t orig_src_dt_size = types::data_type_size(jbgp.orig_src_dt); + auto src_ptr_f32 = reinterpret_cast(src); + auto src_ptr_bf16 = reinterpret_cast(src); auto qsrc_ptr = qsrc; auto src_dscales_ptr = src_dscales; auto src_grouped_sum_ptr = src_grouped_sum; int vec_loop_end = rnd_dn(jbgp.ic, jbgp.src_quant_group_size); + // orig_src_dt is loop-invariant across all mb iterations; resolve it once + // here so the tail loops below stay branch-free. + const bool is_bf16_src = (jbgp.orig_src_dt == data_type::bf16); + parallel_nd(jbgp.mb, [&](int mb) { src_quantization_runtime_params_t rt_params = {}; - rt_params.src_ptr = src_ptr + mb * jbgp.ic; + rt_params.src_ptr = src + mb * jbgp.ic * orig_src_dt_size; rt_params.qsrc_ptr = qsrc_ptr + mb * jbgp.ic; rt_params.src_scales_ptr = src_dscales_ptr + mb * ic_groups; - rt_params.src_grouped_sum_ptr = src_grouped_sum_ptr + mb * ic_sum_groups; + rt_params.src_grouped_sum_ptr + = src_grouped_sum_ptr + mb * ic_sum_groups; rt_params.ic_size = vec_loop_end; (*brg_src_quant_kernel_)(&rt_params); + // Scalar tail for ic elements not covered by the JIT kernel if (vec_loop_end != jbgp.ic) { float amax = 0; - for (int ic = vec_loop_end; ic < jbgp.ic; ic++) { - amax = std::max(amax, std::abs(src_ptr[mb * jbgp.ic + ic])); + if (is_bf16_src) { + for (int ic = vec_loop_end; ic < jbgp.ic; ic++) { + amax = std::max(amax, + std::abs(static_cast( + src_ptr_bf16[mb * jbgp.ic + ic]))); + } + } else { + for (int ic = vec_loop_end; ic < jbgp.ic; ic++) { + amax = std::max( + amax, std::abs(src_ptr_f32[mb * jbgp.ic + ic])); + } } const float dscale = amax / 127; - const float qscale = (dscale != 0) ? (1.0f / dscale) : 0; + const float qscale = (dscale != 0) ? (1.0f / dscale) : 0; src_dscales_ptr[mb * ic_groups + ic_groups - 1] = dscale; - for (int ic = vec_loop_end; ic < jbgp.ic; ic++) { - qsrc_ptr[mb * jbgp.ic + ic] = std::round(src_ptr[mb * jbgp.ic + ic] * qscale); + if (is_bf16_src) { + for (int ic = vec_loop_end; ic < jbgp.ic; ic++) { + qsrc_ptr[mb * jbgp.ic + ic] = std::round( + static_cast( + src_ptr_bf16[mb * jbgp.ic + ic]) + * qscale); + } + } else { + for (int ic = vec_loop_end; ic < jbgp.ic; ic++) { + qsrc_ptr[mb * jbgp.ic + ic] = std::round( + src_ptr_f32[mb * jbgp.ic + ic] * qscale); + } } } if (jbgp.wei_decomp_zero_points_dt) { - for (int icb = vec_loop_end / jbgp.src_quant_group_size; icb < ic_sum_groups; icb++) { + for (int icb = vec_loop_end / jbgp.src_quant_group_size; + icb < ic_sum_groups; icb++) { int ic_begin = icb * jbgp.src_sum_group_size; - int ic_end = nstl::min(static_cast((icb + 1) * jbgp.src_sum_group_size), jbgp.ic); + int ic_end = nstl::min(static_cast((icb + 1) + * jbgp.src_sum_group_size), + jbgp.ic); int sum = 0; for (int ic = ic_begin; ic < ic_end; ic++) { sum += qsrc_ptr[mb * jbgp.ic + ic]; @@ -202,10 +251,11 @@ status_t brgemm_inner_product_fwd_t::execute_forward( const bool is_amx = jbgp.is_amx; auto wsp_tile_base = is_amx ? ctx.get_scratchpad_grantor().template get( - key_conv_amx_tile_buffer) + key_conv_amx_tile_buffer) : nullptr; - auto decomp_buf_global = (jbgp.weights_compressed || jbgp.weights_decompression) + auto decomp_buf_global + = (jbgp.weights_compressed || jbgp.weights_decompression) ? scratchpad.template get(key_brgemm_primitive_decomp_buf) : nullptr; @@ -335,7 +385,8 @@ status_t brgemm_inner_product_fwd_t::execute_forward( const int ic_blocks_per_batch = div_up(jbgp.K, jbgp.ic_block); const dim_t wei_cur_ocb = blk_off(weights_d, cur_ocb, 0, kd, kh, kw) - / types::data_type_size(weights_d.data_type()) * types::data_type_size(jbgp.wei_dt); + / types::data_type_size(weights_d.data_type()) + * types::data_type_size(jbgp.wei_dt); // weights_d & jbgp.wei_dt has different data size // printf("kd %ld, kh %ld, kw %ld weights_d %ld true_size %ld wei_cur_ocb %ld\n", kd, kh, kw, // weights_d.data_type(), jbgp.wei_dt, wei_cur_ocb); @@ -369,7 +420,8 @@ status_t brgemm_inner_product_fwd_t::execute_forward( if (jbgp.weights_compressed) { using comp_tile_len_type = int; const comp_tile_len_type *compressed_tile_lengths_ptr - = reinterpret_cast(weights); + = reinterpret_cast( + weights); int compressed_weights_offset = wei_offset / 4096; auto dcomp_params = brgemm_decomp_kernel_params_t(); @@ -377,63 +429,108 @@ status_t brgemm_inner_product_fwd_t::execute_forward( + compressed_tile_lengths_ptr [compressed_weights_offset] * 64; - dcomp_params.bitmask_ptr - = weights + jbgp.weight_comp_bitmask_off + wei_offset / 8; + dcomp_params.bitmask_ptr = weights + + jbgp.weight_comp_bitmask_off + wei_offset / 8; const size_t decomp_buf_per_thr = (size_t)jbgp.ic * 64; - auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr; + auto decomp_buf + = decomp_buf_global + ithr * decomp_buf_per_thr; dcomp_params.scratch_buf = decomp_buf; (*brg_decomp_kernel_)(&dcomp_params); addr_batch[b].ptr.B = decomp_buf; - } else if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) { + } else if (jbgp.weights_decompression + && jbgp.wei_decomp_algo + == weights_decomp_kind_t::prepack) { int typesize_scale = [&] { if (jbgp.orig_wei_dt == data_type::u2) { return 4; - } else if (one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) { + } else if (one_of(jbgp.orig_wei_dt, data_type::nf4, + data_type::s4, data_type::u4, + data_type::f4_e2m1)) { return 2; } else { return 1; } - } (); - auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale; - auto weights_ptr = reinterpret_cast(&weights[w_off]); - - const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt); - auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr + wei_ic_stride * b * ic_blocks_per_batch; + }(); + auto w_off = wei_offset + * types::data_type_size(jbgp.orig_wei_dt) + / types::data_type_size(jbgp.wei_dt) + / typesize_scale; + auto weights_ptr = reinterpret_cast( + &weights[w_off]); + + const size_t decomp_buf_per_thr = jbgp.ic_block + * jbgp.nb_ic_blocking * jbgp.oc_block + * types::data_type_size(jbgp.wei_dt); + auto decomp_buf = decomp_buf_global + + ithr * decomp_buf_per_thr + + wei_ic_stride * b * ic_blocks_per_batch; const int ic_internal_block = [&] { if (pd()->jbgp_.orig_wei_dt == data_type::u2) { return 4; } else if (pd()->jbgp_.wei_dt == data_type::bf16) { return 2; - } else if (one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) { + } else if (one_of(pd()->jbgp_.orig_wei_dt, + data_type::nf4, data_type::s4, + data_type::u4, data_type::f4_e2m1)) { return 2; } else { return 1; } - } (); - auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; - auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc * wei_scales_dt_size; + }(); + auto wei_zero_points_ptr = wei_zero_points + + wei_zero_points_oc_stride * oc + * wei_zero_points_dt_size; + auto wei_scales_ptr = wei_scales + + wei_scales_oc_stride * oc * wei_scales_dt_size; if (jbgp.with_grouped_weights_decompression) { weights_decompression_runtime_params_t rt_params = {}; - auto ic_size = jbgp.ic_block * ic_blocks_per_batch / ic_internal_block; - auto wei_scales_ic_group_size_local = jbgp.wei_scales_ic_group_size / ic_internal_block; - auto wei_zero_points_ic_group_size_local = jbgp.wei_zero_points_ic_group_size / ic_internal_block; - auto group_size = nstl::min(wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local); + auto ic_size = jbgp.ic_block * ic_blocks_per_batch + / ic_internal_block; + auto wei_scales_ic_group_size_local + = jbgp.wei_scales_ic_group_size + / ic_internal_block; + auto wei_zero_points_ic_group_size_local + = jbgp.wei_zero_points_ic_group_size + / ic_internal_block; + auto group_size + = nstl::min(wei_scales_ic_group_size_local, + wei_zero_points_ic_group_size_local); auto group_ic_blocks = div_up(ic_size, group_size); - auto start_group_scales = ic / jbgp.wei_scales_ic_group_size; - auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size; - for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) { + auto start_group_scales + = ic / jbgp.wei_scales_ic_group_size; + auto start_group_zero_points + = ic / jbgp.wei_zero_points_ic_group_size; + for (int icb_idx = 0; icb_idx < group_ic_blocks; + icb_idx++) { auto ic_idx = icb_idx * group_size; - auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales; - auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points; - - rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale; - rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block *jbgp.oc_block * types::data_type_size(jbgp.wei_dt); - rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0] * wei_scales_dt_size; - rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0] * wei_zero_points_dt_size; - rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size); + auto scales_idx + = ic_idx / wei_scales_ic_group_size_local + + start_group_scales; + auto zero_points_idx = ic_idx + / wei_zero_points_ic_group_size_local + + start_group_zero_points; + + rt_params.weights_ptr = weights_ptr + + ic_idx * ic_internal_block * jbgp.oc_block + * types::data_type_size( + jbgp.orig_wei_dt) + / typesize_scale; + rt_params.decomp_buffer_ptr = decomp_buf + + ic_idx * ic_internal_block * jbgp.oc_block + * types::data_type_size( + jbgp.wei_dt); + rt_params.scales_ptr = wei_scales_ptr + + scales_idx * wei_scales_d.dims()[0] + * wei_scales_dt_size; + rt_params.zero_points_ptr = wei_zero_points_ptr + + zero_points_idx + * wei_zero_points_d.dims()[0] + * wei_zero_points_dt_size; + rt_params.ic_size = nstl::min( + group_size, ic_size - icb_idx * group_size); (*brg_weights_decomp_kernel_)(&rt_params); } } else { @@ -442,7 +539,8 @@ status_t brgemm_inner_product_fwd_t::execute_forward( rt_params.decomp_buffer_ptr = decomp_buf; rt_params.scales_ptr = wei_scales_ptr; rt_params.zero_points_ptr = wei_zero_points_ptr; - rt_params.ic_size = jbgp.ic_block * ic_blocks_per_batch / ic_internal_block; + rt_params.ic_size = jbgp.ic_block * ic_blocks_per_batch + / ic_internal_block; (*brg_weights_decomp_kernel_)(&rt_params); } @@ -451,12 +549,14 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int typesize_scale = [&] { if (jbgp.wei_dt == data_type::u2) { return 4; - } else if (one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) { + } else if (one_of(jbgp.wei_dt, data_type::nf4, + data_type::s4, data_type::u4, + data_type::f4_e2m1)) { return 2; } else { return 1; } - } (); + }(); addr_batch[b].ptr.B = weights + wei_offset / typesize_scale; } } @@ -466,10 +566,14 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int src_scales_offset = 0; int src_grouped_sum_offset = 0; if (jbgp.weights_decompression) { - wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; - wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; - src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); - src_grouped_sum_offset = n * div_up(jbgp.ic, jbgp.src_sum_group_size); + wei_scales_offset + = wei_scales_oc_stride * oc * wei_scales_dt_size; + wei_zero_points_offset = wei_zero_points_oc_stride * oc + * wei_zero_points_dt_size; + src_scales_offset + = n * div_up(jbgp.ic, jbgp.src_quant_group_size); + src_grouped_sum_offset + = n * div_up(jbgp.ic, jbgp.src_sum_group_size); } auto ptr_D = dst + dst_off; @@ -479,9 +583,11 @@ status_t brgemm_inner_product_fwd_t::execute_forward( && is_last_ic_chunk && !is_ic_tail && last_spatial_slice) { void *scratch = is_amx ? static_cast(wsp_tile) - : (jbgp.req_s8s8_compensation ? static_cast( - const_cast(&compensation[oc])) - : nullptr); + : (jbgp.req_s8s8_compensation + ? static_cast( + const_cast( + &compensation[oc])) + : nullptr); auto ptr_bias = jbgp.with_bias ? bias + bia_dt_size * oc : nullptr; const brgemm_post_ops_data_t post_ops_data { @@ -489,19 +595,25 @@ status_t brgemm_inner_product_fwd_t::execute_forward( post_ops_binary_rhs_arg_vec.data(), static_cast(oc), 0, dst, 0, nullptr, nullptr, nullptr, false, 1, false, false, src_scales, - wei_scales_f ? reinterpret_cast(wei_scales_f) + wei_scales_f + ? reinterpret_cast(wei_scales_f) + jbgp.is_oc_scale * oc * sizeof(float) - : nullptr, + : nullptr, dst_scales_ptr}; brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data, - scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, - src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); + scratch, nullptr, wei_scales + wei_scales_offset, + wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, + src_grouped_sum + src_grouped_sum_offset, ic); } else { brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, - src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); + (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, + nullptr, wei_scales + wei_scales_offset, + wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, + src_grouped_sum + src_grouped_sum_offset, ic); } } @@ -524,20 +636,28 @@ status_t brgemm_inner_product_fwd_t::execute_forward( const dim_t wei_offset = (wei_cur_ocb + wei_ic_stride * (icb + ic_block)); - if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) { + if (jbgp.weights_decompression + && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) { int typesize_scale = [&] { if (jbgp.orig_wei_dt == data_type::u2) { return 4; - } else if (one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) { + } else if (one_of(jbgp.orig_wei_dt, data_type::nf4, + data_type::s4, data_type::u4, + data_type::f4_e2m1)) { return 2; } else { return 1; } - } (); - auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale; - auto weights_ptr = reinterpret_cast(&weights[w_off]); - - const size_t decomp_buf_per_thr = jbgp.ic_block * jbgp.nb_ic_blocking * jbgp.oc_block * types::data_type_size(jbgp.wei_dt); + }(); + auto w_off = wei_offset + * types::data_type_size(jbgp.orig_wei_dt) + / types::data_type_size(jbgp.wei_dt) / typesize_scale; + auto weights_ptr + = reinterpret_cast(&weights[w_off]); + + const size_t decomp_buf_per_thr = jbgp.ic_block + * jbgp.nb_ic_blocking * jbgp.oc_block + * types::data_type_size(jbgp.wei_dt); auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr; const int ic_internal_block = [&] { @@ -545,34 +665,62 @@ status_t brgemm_inner_product_fwd_t::execute_forward( return 4; } else if (pd()->jbgp_.wei_dt == data_type::bf16) { return 2; - } else if (one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) { + } else if (one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, + data_type::s4, data_type::u4, + data_type::f4_e2m1)) { return 2; } else { return 1; } - } (); - auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; - auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc * wei_scales_dt_size; + }(); + auto wei_zero_points_ptr = wei_zero_points + + wei_zero_points_oc_stride * oc + * wei_zero_points_dt_size; + auto wei_scales_ptr = wei_scales + + wei_scales_oc_stride * oc * wei_scales_dt_size; if (jbgp.with_grouped_weights_decompression) { weights_decompression_runtime_params_t rt_params = {}; - auto ic_size = (jbgp.ic - (ic + ic_block * jbgp.ic_block)) / ic_internal_block; - auto wei_scales_ic_group_size_local = jbgp.wei_scales_ic_group_size / ic_internal_block; - auto wei_zero_points_ic_group_size_local = jbgp.wei_zero_points_ic_group_size / ic_internal_block; - auto group_size = nstl::min(wei_scales_ic_group_size_local, wei_zero_points_ic_group_size_local); + auto ic_size = (jbgp.ic - (ic + ic_block * jbgp.ic_block)) + / ic_internal_block; + auto wei_scales_ic_group_size_local + = jbgp.wei_scales_ic_group_size / ic_internal_block; + auto wei_zero_points_ic_group_size_local + = jbgp.wei_zero_points_ic_group_size + / ic_internal_block; + auto group_size = nstl::min(wei_scales_ic_group_size_local, + wei_zero_points_ic_group_size_local); auto group_ic_blocks = div_up(ic_size, group_size); - auto start_group_scales = ic / jbgp.wei_scales_ic_group_size; - auto start_group_zero_points = ic / jbgp.wei_zero_points_ic_group_size; - for (int icb_idx = 0; icb_idx < group_ic_blocks; icb_idx++) { + auto start_group_scales + = ic / jbgp.wei_scales_ic_group_size; + auto start_group_zero_points + = ic / jbgp.wei_zero_points_ic_group_size; + for (int icb_idx = 0; icb_idx < group_ic_blocks; + icb_idx++) { auto ic_idx = icb_idx * group_size; - auto scales_idx = ic_idx / wei_scales_ic_group_size_local + start_group_scales; - auto zero_points_idx = ic_idx / wei_zero_points_ic_group_size_local + start_group_zero_points; - - rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale; - rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.wei_dt); - rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0] * wei_scales_dt_size; - rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0] * wei_zero_points_dt_size; - rt_params.ic_size = nstl::min((int64_t)group_size, (int64_t)(ic_size - icb_idx * group_size)); + auto scales_idx + = ic_idx / wei_scales_ic_group_size_local + + start_group_scales; + auto zero_points_idx + = ic_idx / wei_zero_points_ic_group_size_local + + start_group_zero_points; + + rt_params.weights_ptr = weights_ptr + + ic_idx * ic_internal_block * jbgp.oc_block + * types::data_type_size( + jbgp.orig_wei_dt) + / typesize_scale; + rt_params.decomp_buffer_ptr = decomp_buf + + ic_idx * ic_internal_block * jbgp.oc_block + * types::data_type_size(jbgp.wei_dt); + rt_params.scales_ptr = wei_scales_ptr + + scales_idx * wei_scales_d.dims()[0] + * wei_scales_dt_size; + rt_params.zero_points_ptr = wei_zero_points_ptr + + zero_points_idx * wei_zero_points_d.dims()[0] + * wei_zero_points_dt_size; + rt_params.ic_size = nstl::min((int64_t)group_size, + (int64_t)(ic_size - icb_idx * group_size)); (*brg_weights_decomp_kernel_)(&rt_params); } } else { @@ -581,7 +729,9 @@ status_t brgemm_inner_product_fwd_t::execute_forward( rt_params.decomp_buffer_ptr = decomp_buf; rt_params.scales_ptr = wei_scales_ptr; rt_params.zero_points_ptr = wei_zero_points_ptr; - rt_params.ic_size = (jbgp.ic - (ic + ic_block * jbgp.ic_block)) / ic_internal_block; + rt_params.ic_size + = (jbgp.ic - (ic + ic_block * jbgp.ic_block)) + / ic_internal_block; (*brg_weights_decomp_kernel_)(&rt_params); } @@ -590,12 +740,14 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int typesize_scale = [&] { if (jbgp.wei_dt == data_type::u2) { return 4; - } else if (one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) { + } else if (one_of(jbgp.wei_dt, data_type::nf4, + data_type::s4, data_type::u4, + data_type::f4_e2m1)) { return 2; } else { return 1; } - } (); + }(); addr_batch[0].ptr.B = weights + wei_offset / typesize_scale; } @@ -604,10 +756,14 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int src_scales_offset = 0; int src_grouped_sum_offset = 0; if (jbgp.weights_decompression) { - wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; - wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; - src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); - src_grouped_sum_offset = n * div_up(jbgp.ic, jbgp.src_sum_group_size); + wei_scales_offset + = wei_scales_oc_stride * oc * wei_scales_dt_size; + wei_zero_points_offset = wei_zero_points_oc_stride * oc + * wei_zero_points_dt_size; + src_scales_offset + = n * div_up(jbgp.ic, jbgp.src_quant_group_size); + src_grouped_sum_offset + = n * div_up(jbgp.ic, jbgp.src_sum_group_size); } auto brg_kernel_ic_tail = brg_kernels_[brg_ker_ic_tail_idx].get(); @@ -617,9 +773,11 @@ status_t brgemm_inner_product_fwd_t::execute_forward( && last_spatial_slice) { void *scratch = is_amx ? static_cast(wsp_tile) - : (jbgp.req_s8s8_compensation ? static_cast( - const_cast(&compensation[oc])) - : nullptr); + : (jbgp.req_s8s8_compensation + ? static_cast( + const_cast( + &compensation[oc])) + : nullptr); auto ptr_bias = jbgp.with_bias ? bias + bia_dt_size * oc : nullptr; const brgemm_post_ops_data_t post_ops_data { @@ -627,18 +785,25 @@ status_t brgemm_inner_product_fwd_t::execute_forward( post_ops_binary_rhs_arg_vec.data(), static_cast(oc), 0, dst, 0, nullptr, nullptr, nullptr, false, 1, false, false, src_scales, - wei_scales_f ? reinterpret_cast(wei_scales_f) + wei_scales_f + ? reinterpret_cast(wei_scales_f) + jbgp.is_oc_scale * oc * sizeof(float) - : nullptr, + : nullptr, dst_scales_ptr}; brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, - src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); + (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch, + nullptr, wei_scales + wei_scales_offset, + wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, + src_grouped_sum + src_grouped_sum_offset, ic); } else { brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch, - (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, nullptr, wei_scales + wei_scales_offset, wei_zero_points + wei_zero_points_offset, - src_dscales + src_scales_offset, src_grouped_sum + src_grouped_sum_offset, ic); + (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr, + nullptr, wei_scales + wei_scales_offset, + wei_zero_points + wei_zero_points_offset, + src_dscales + src_scales_offset, + src_grouped_sum + src_grouped_sum_offset, ic); } } }; @@ -907,10 +1072,13 @@ status_t brgemm_inner_product_fwd_t::execute_forward( void *scratch = is_amx ? static_cast(wsp_tile) - : (jbgp.req_s8s8_compensation ? static_cast< - void *>(const_cast( - &compensation[oc])) - : nullptr); + : (jbgp.req_s8s8_compensation + ? static_cast< + void *>(const_cast< + int *>( + &compensation + [oc])) + : nullptr); const brgemm_post_ops_data_t post_ops_data { static_cast(ptr_bias), @@ -995,7 +1163,7 @@ void brgemm_inner_product_bwd_data_t::execute_backward_data( : nullptr; auto wsp_tile_base = is_amx ? ctx.get_scratchpad_grantor().template get( - key_conv_amx_tile_buffer) + key_conv_amx_tile_buffer) : nullptr; const dim_t acc_dt_sz = types::data_type_size(jbgp.acc_dt); @@ -1450,7 +1618,7 @@ struct brgemm_inner_product_bwd_weights_t::thread_info_t { wsp_tile_base = is_amx ? ctx.get_scratchpad_grantor().template get( - key_conv_amx_tile_buffer) + key_conv_amx_tile_buffer) : nullptr; nthr = jbgp.nthr; diff --git a/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp b/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp index b6aff7896b0..0f36a150b1e 100644 --- a/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp +++ b/src/cpu/x64/jit_brgemm_src_quantization_kernel.cpp @@ -33,19 +33,28 @@ using namespace Xbyak; using namespace std::placeholders; template -void jit_brgemm_src_quantization_kernel_t::load_src(Vmm vmm_load, const Xbyak::Address& addr) { +void jit_brgemm_src_quantization_kernel_t::load_src( + Vmm vmm_load, const Xbyak::Address &addr) { switch (jcp_.src_dt) { case data_type::f32: { uni_vmovups(vmm_load, addr); break; } + case data_type::bf16: { + // Upconvert bf16 payload to f32 by placing bf16 bits in high 16 bits of each dword. + uni_vpmovzxwd(vmm_load, addr); + uni_vpslld(vmm_load, vmm_load, 16); + break; + } default: assert(!"unsupported data type"); } } template -void jit_brgemm_src_quantization_kernel_t::horiz_op(Vmm vmm_src, Vmm vmm_aux, op_type type) { - auto uni_op = [&](const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) { +void jit_brgemm_src_quantization_kernel_t::horiz_op( + Vmm vmm_src, Vmm vmm_aux, op_type type) { + auto uni_op = [&](const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { if (type == op_type::max) { uni_vmaxps(x1, x2, op); } else if (type == op_type::sum) { @@ -94,20 +103,15 @@ void jit_brgemm_src_quantization_kernel_t::generate() { size_t src_scales_dt_size = types::data_type_size(data_type::f32); size_t src_grouped_sum_dt_size = types::data_type_size(data_type::s32); - static const float negative_zero[16] = { - -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, - -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f - }; + static const float negative_zero[16] = {-0.f, -0.f, -0.f, -0.f, -0.f, -0.f, + -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f, -0.f}; - static const float positive_one[16] = { - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f - }; + static const float positive_one[16] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; - static const float int8_max[16] = { - 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, - 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f - }; + static const float int8_max[16] + = {127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, + 127.f, 127.f, 127.f, 127.f, 127.f, 127.f, 127.f}; mov(reg_tmp, (size_t)negative_zero); uni_vmovups(vmm_sign_bit_mask(), ptr[reg_tmp]); @@ -145,7 +149,8 @@ void jit_brgemm_src_quantization_kernel_t::generate() { uni_vmovss(ptr[reg_src_scales], Xmm(vmm_dscale.getIdx())); if (jcp_.with_src_grouped_sum) { - uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src_sum_accum()); + uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), + vmm_src_sum_accum()); } for (int icb = 0; icb < ic_blocks; icb++) { load_src(vmm_src(), ptr[reg_src + icb * vec_size * src_dt_size]); @@ -157,19 +162,23 @@ void jit_brgemm_src_quantization_kernel_t::generate() { if (((icb + 1) * vec_size) % jcp_.src_sum_group_size == 0) { horiz_op(vmm_src_sum_accum(), vmm_aux(), op_type::sum); - uni_vmovss(ptr[reg_src_grouped_sum], Xmm(vmm_src_sum_accum().getIdx())); - uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), vmm_src_sum_accum()); + uni_vmovss(ptr[reg_src_grouped_sum], + Xmm(vmm_src_sum_accum().getIdx())); + uni_vxorps(vmm_src_sum_accum(), vmm_src_sum_accum(), + vmm_src_sum_accum()); add(reg_src_grouped_sum, src_grouped_sum_dt_size); } } if (isa == avx512_core) { - vpmovsdb(ptr[reg_qsrc + icb * vec_size * qsrc_dt_size], vmm_src()); + vpmovsdb(ptr[reg_qsrc + icb * vec_size * qsrc_dt_size], + vmm_src()); } else { uni_vpackssdw(vmm_src(), vmm_src(), vmm_src()); vpermq(Ymm(vmm_src().getIdx()), Ymm(vmm_src().getIdx()), 0x08); uni_vpacksswb(vmm_src(), vmm_src(), vmm_src()); - vmovq(ptr[reg_qsrc + icb * vec_size * qsrc_dt_size], Xmm(vmm_src().getIdx())); + vmovq(ptr[reg_qsrc + icb * vec_size * qsrc_dt_size], + Xmm(vmm_src().getIdx())); } }