From 11b3fc44de8572c8a9fbb1522e92a2361b198ac9 Mon Sep 17 00:00:00 2001
From: Denis Kotov
Date: Tue, 6 Sep 2022 21:35:02 +0300
Subject: [PATCH 1/2] Added push(...) and pop(...) for SIMD registers

---
 src/cpu/x64/jit_generator.hpp | 86 +++++++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)

diff --git a/src/cpu/x64/jit_generator.hpp b/src/cpu/x64/jit_generator.hpp
index d3d0bb0efa8..48f1ee87331 100644
--- a/src/cpu/x64/jit_generator.hpp
+++ b/src/cpu/x64/jit_generator.hpp
@@ -147,6 +147,8 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
 private:
     const size_t xmm_len = 16;
+    const size_t ymm_len = 32; // bytes in a YMM register
+    const size_t zmm_len = 64; // bytes in a ZMM register
 #ifdef _WIN32
     const size_t xmm_to_preserve_start = 6;
     const size_t xmm_to_preserve = 10;
@@ -182,6 +184,90 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
 
     inline size_t get_size_of_abi_save_regs() { return size_of_abi_save_regs; }
 
+    using Xbyak::CodeGenerator::push;
+    using Xbyak::CodeGenerator::pop;
+
+    // SIMD counterparts of push/pop. A push lowers rsp by the register width
+    // and stores the register with an unaligned move; a pop loads it back and
+    // raises rsp again. The vector overloads move rsp once for the whole
+    // group and keep element i at [rsp + i * width], so pop(v) has to be
+    // called with the same registers, in the same order, as push(v).
+
+    inline void push(const Xbyak::Xmm &xmm) {
+        sub(rsp, xmm_len);
+        uni_vmovdqu(ptr[rsp], xmm);
+    }
+
+    inline void push(const std::vector<Xbyak::Xmm> &xmms) {
+        sub(rsp, xmms.size() * xmm_len);
+        for (size_t i = 0; i < xmms.size(); ++i) {
+            uni_vmovdqu(ptr[rsp + i * xmm_len], xmms[i]);
+        }
+    }
+
+    inline void push(const Xbyak::Ymm &ymm) {
+        sub(rsp, ymm_len);
+        uni_vmovdqu(ptr[rsp], ymm);
+    }
+
+    inline void push(const std::vector<Xbyak::Ymm> &ymms) {
+        sub(rsp, ymms.size() * ymm_len);
+        for (size_t i = 0; i < ymms.size(); ++i) {
+            uni_vmovdqu(ptr[rsp + i * ymm_len], ymms[i]);
+        }
+    }
+
+    inline void push(const Xbyak::Zmm &zmm) {
+        sub(rsp, zmm_len);
+        uni_vmovdqu(ptr[rsp], zmm);
+    }
+
+    inline void push(const std::vector<Xbyak::Zmm> &zmms) {
+        sub(rsp, zmms.size() * zmm_len);
+        for (size_t i = 0; i < zmms.size(); ++i) {
+            uni_vmovdqu(ptr[rsp + i * zmm_len], zmms[i]);
+        }
+    }
+
+    inline void pop(const Xbyak::Xmm &xmm) {
+        uni_vmovdqu(xmm, ptr[rsp]);
+        add(rsp, xmm_len);
+    }
+
+    inline void pop(const std::vector<Xbyak::Xmm> &xmms) {
+        for (size_t i = 0; i < xmms.size(); ++i) {
+            uni_vmovdqu(xmms[i], ptr[rsp + i * xmm_len]);
+        }
+        // Release the slots: push() lowered rsp, so pop() has to raise it.
+        add(rsp, xmms.size() * xmm_len);
+    }
+
+    inline void pop(const Xbyak::Ymm &ymm) {
+        uni_vmovdqu(ymm, ptr[rsp]);
+        add(rsp, ymm_len);
+    }
+
+    inline void pop(const std::vector<Xbyak::Ymm> &ymms) {
+        for (size_t i = 0; i < ymms.size(); ++i) {
+            uni_vmovdqu(ymms[i], ptr[rsp + i * ymm_len]);
+        }
+        // Release the slots: push() lowered rsp, so pop() has to raise it.
+        add(rsp, ymms.size() * ymm_len);
+    }
+
+    inline void pop(const Xbyak::Zmm &zmm) {
+        uni_vmovdqu(zmm, ptr[rsp]);
+        add(rsp, zmm_len);
+    }
+
+    inline void pop(const std::vector<Xbyak::Zmm> &zmms) {
+        for (size_t i = 0; i < zmms.size(); ++i) {
+            uni_vmovdqu(zmms[i], ptr[rsp + i * zmm_len]);
+        }
+        // Release the slots: push() lowered rsp, so pop() has to raise it.
+        add(rsp, zmms.size() * zmm_len);
+    }
+
     void preamble() {
         if (xmm_to_preserve) {
             sub(rsp, xmm_to_preserve * xmm_len);
From fea517f6bf92885035b8d302f32db94d56fde5d8 Mon Sep 17 00:00:00 2001
From: Denis Kotov
Date: Thu, 8 Sep 2022 00:41:08 +0300
Subject: [PATCH 2/2] Added general uni_vgatherdps and uni_vscatterdps

---
 src/cpu/x64/jit_generator.hpp | 212 ++++++++++++++++++++++++++++++++++
 1 file changed, 212 insertions(+)

diff --git a/src/cpu/x64/jit_generator.hpp b/src/cpu/x64/jit_generator.hpp
index 48f1ee87331..c92de9c70d9 100644
--- a/src/cpu/x64/jit_generator.hpp
+++ b/src/cpu/x64/jit_generator.hpp
@@ -164,6 +164,31 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
             = num_abi_save_gpr_regs * rax.getBit() / 8
             + xmm_to_preserve * xmm_len;
 
+    // Picks the first index in reg_idxs that no register in not_available
+    // uses, appends the chosen register to not_available and returns it.
+    // reg_idxs is pruned in place.
+    template <typename TOutReg, typename TInReg>
+    inline TOutReg get_free_reg(std::vector<int>& reg_idxs,
+                                std::vector<TInReg>& not_available) {
+        std::vector<int> not_available_idx(not_available.size());
+        std::transform(not_available.begin(), not_available.end(), not_available_idx.begin(),
+                       [](const Xbyak::Reg& reg) {
+                           return reg.getIdx();
+                       });
+        auto removed = std::remove_if(reg_idxs.begin(), reg_idxs.end(),
+                                      [&not_available_idx](const int& reg_idx) {
+                                          return not_available_idx.end() != std::find(not_available_idx.begin(),
+                                                                                      not_available_idx.end(),
+                                                                                      reg_idx);
+                                      });
+        reg_idxs.erase(removed, reg_idxs.end());
+        // front() on an empty vector is undefined behaviour.
+        assert(!reg_idxs.empty() && "no free register left");
+        TOutReg alloc_reg{reg_idxs.front()};
+        not_available.push_back(alloc_reg);
+        return alloc_reg;
+    }
+
 public:
     enum {
         _cmp_eq_oq = 0u,
@@ -425,6 +450,193 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
         vpxord(x1, x2, op);
     }
 
+    // Returns a general-purpose register whose index is not used by any
+    // register in not_available and appends it to that list.
+    template <typename TReg>
+    inline TReg get_free_reg(std::vector<Xbyak::Reg>& not_available) {
+        static_assert(std::is_base_of<Xbyak::Reg, TReg>::value,
+                      "Xbyak::Reg should be base of TReg");
+        const size_t regsNumber = 16;
+        std::vector<int> reg_idxs;
+        reg_idxs.reserve(regsNumber);
+        for (int i = 0; i < static_cast<int>(regsNumber); ++i) {
+            // NOTE: rsp must never be handed out, otherwise the generated
+            // code would write through a clobbered stack pointer and crash
+            if (rsp.getIdx() != i) {
+                reg_idxs.push_back(i);
+            }
+        }
+        return get_free_reg<TReg>(reg_idxs, not_available);
+    }
+
+    // Returns a vector register whose index is not used by any register in
+    // not_available and appends it to that list.
+    template <typename TVmm>
+    inline TVmm get_free_reg(std::vector<Xbyak::Xmm>& not_available) {
+        static_assert(std::is_base_of<Xbyak::Xmm, TVmm>::value,
+                      "Xbyak::Xmm should be base of TVmm");
+#ifdef XBYAK64
+        // 64-bit mode has 32 vector registers with AVX-512 and 16 without.
+        // (vlen is a byte width, so it cannot serve as a register count.)
+        const int simdNumber = is_valid_isa(cpu_isa_t::avx512_core) ? 32 : 16;
+#else
+        const int simdNumber = 8;
+#endif
+        // Start empty: a sized constructor would leave spurious zero entries
+        // in front of the real indices.
+        std::vector<int> xmm_idxs;
+        xmm_idxs.reserve(simdNumber);
+        for (int i = 0; i < simdNumber; ++i) {
+            xmm_idxs.push_back(i);
+        }
+        return get_free_reg<TVmm>(xmm_idxs, not_available);
+    }
+
+    // Gathers 32-bit elements from [reg_addr + index[i] * scale + disp] into
+    // the lanes of xmm_val that reg_mask selects.
+    //  - AVX-512: reg_mask must be an opmask register.
+    //  - AVX2:    reg_mask must be a YMM register.
+    //  - otherwise: reg_mask must be an XMM register. The gather is emulated
+    //    lane by lane; a lane is loaded only when its mask dword equals
+    //    0xFFFFFFFF. Two scratch GPRs are saved and restored on the stack.
+    inline void uni_vgatherdps(const Xbyak::Xmm &xmm_val,
+                               const Xbyak::Reg64 &reg_addr,
+                               const Xbyak::Xmm &xmm_index,
+                               const int &scale,
+                               const int &disp,
+                               const Xbyak::Reg &reg_mask) {
+        const size_t kDataTypeSize = sizeof(float);
+        if (is_valid_isa(cpu_isa_t::avx512_core)) {
+            assert(reg_mask.isOPMASK());
+            // The EVEX encoding requires the write mask on the destination.
+            const Xbyak::Opmask k_mask{reg_mask.getIdx()};
+            vgatherdps(xmm_val | k_mask, ptr[reg_addr + xmm_index * scale + disp]);
+        } else if (is_valid_isa(cpu_isa_t::avx2)) {
+            assert(reg_mask.isYMM());
+            Xbyak::Ymm ymm_mask{reg_mask.getIdx()};
+            vgatherdps(xmm_val, ptr[reg_addr + xmm_index * scale + disp], ymm_mask);
+        } else {
+            const size_t kSimdWidth = x64::cpu_isa_traits<cpu_isa_t::sse41>::vlen / kDataTypeSize;
+            assert(reg_mask.isXMM());
+            Xbyak::Xmm xmm_mask{reg_mask.getIdx()};
+            assert(xmm_val.getKind() == xmm_index.getKind());
+            assert(xmm_index.getKind() == xmm_mask.getKind());
+
+            std::vector<Xbyak::Reg> not_available_reg{reg_addr};
+            const Xbyak::Reg64 idx = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
+            const Xbyak::Reg64 mask = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
+
+            push(idx);
+            push(mask);
+            xor_(idx, idx);
+            xor_(mask, mask);
+
+            for (int i = 0; i < static_cast<int>(kSimdWidth); i++) {
+                Xbyak::Label gather_end;
+                uni_vpextrd(mask.cvt32(), xmm_mask, i);
+                cmp(mask.cvt32(), 0xFFFFFFFF);
+                jne(gather_end, T_NEAR);
+                uni_vpextrd(idx.cvt32(), xmm_index, i);
+                // Hardware gathers treat dword indices as signed.
+                movsxd(idx, idx.cvt32());
+                Xbyak::Address addr = ptr[reg_addr + idx * scale + disp];
+                uni_vpinsrd(xmm_val, xmm_val, addr, i);
+                L(gather_end);
+            }
+            pop(mask);
+            pop(idx);
+        }
+    }
+
+    // Scatters the 32-bit lanes of xmm_val that reg_mask selects to
+    // [reg_addr + index[i] * scale + disp].
+    //  - AVX-512: reg_mask must be an opmask register.
+    //  - otherwise: reg_mask must be an XMM or YMM register. The scatter is
+    //    emulated lane by lane; a lane is stored only when its mask dword
+    //    equals 0xFFFFFFFF. With AVX2 both 128-bit halves are processed.
+    //    Scratch registers are saved and restored on the stack.
+    inline void uni_vscatterdps(const Xbyak::Reg64& reg_addr,
+                                const Xbyak::Xmm& xmm_index,
+                                const int scale,
+                                const int disp,
+                                const Xbyak::Xmm& xmm_val,
+                                const Xbyak::Reg& reg_mask) {
+        const size_t kDataTypeSize = sizeof(float);
+        if (is_valid_isa(cpu_isa_t::avx512_core)) {
+            assert(reg_mask.isOPMASK());
+            // The EVEX encoding requires the write mask on the memory operand.
+            const Xbyak::Opmask k_mask{reg_mask.getIdx()};
+            vscatterdps(ptr[reg_addr + xmm_index * scale + disp] | k_mask, xmm_val);
+        } else {
+            assert(reg_mask.isXMM() || reg_mask.isYMM());
+            const size_t kXmmSimdWidth = x64::cpu_isa_traits<cpu_isa_t::sse41>::vlen / kDataTypeSize;
+            const size_t kYmmSimdWidth = x64::cpu_isa_traits<cpu_isa_t::avx2>::vlen / kDataTypeSize;
+            Xbyak::Xmm xmm_mask{reg_mask.getIdx(), reg_mask.getKind(), static_cast<int>(reg_mask.getBit())};
+            assert(xmm_val.getKind() == xmm_index.getKind());
+            assert(xmm_index.getKind() == xmm_mask.getKind());
+
+            std::vector<Xbyak::Reg> not_available_reg{reg_addr};
+            std::vector<Xbyak::Xmm> not_available_xmm{xmm_index, xmm_val, xmm_mask};
+            const Xbyak::Reg64 idx = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
+            const Xbyak::Reg64 mask = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
+            const Xbyak::Reg64 val = this->get_free_reg<Xbyak::Reg64>(not_available_reg);
+            const Xbyak::Xmm xmm_mask_temp = this->get_free_reg<Xbyak::Xmm>(not_available_xmm);
+            const Xbyak::Xmm xmm_index_temp = this->get_free_reg<Xbyak::Xmm>(not_available_xmm);
+            const Xbyak::Xmm xmm_val_temp = this->get_free_reg<Xbyak::Xmm>(not_available_xmm);
+
+            push(idx);
+            push(mask);
+            push(val);
+            if (is_valid_isa(cpu_isa_t::avx2)) {
+                push(Xbyak::Ymm{xmm_mask_temp.getIdx()});
+                push(Xbyak::Ymm{xmm_index_temp.getIdx()});
+                push(Xbyak::Ymm{xmm_val_temp.getIdx()});
+            }
+            xor_(idx, idx);
+            xor_(mask, mask);
+            xor_(val, val);
+
+            auto store_xmm = [&](const Xbyak::Xmm& lane_mask,
+                                 const Xbyak::Xmm& lane_index,
+                                 const Xbyak::Xmm& lane_val) {
+                for (int i = 0; i < static_cast<int>(kXmmSimdWidth); i++) {
+                    Xbyak::Label scatter_end;
+                    uni_vpextrd(mask.cvt32(), lane_mask, i);
+                    cmp(mask.cvt32(), 0xFFFFFFFF);
+                    jne(scatter_end, T_NEAR);
+                    uni_vpextrd(idx.cvt32(), lane_index, i);
+                    // Hardware scatters treat dword indices as signed.
+                    movsxd(idx, idx.cvt32());
+                    // disp belongs to the effective address here as well.
+                    Xbyak::Address addr = ptr[reg_addr + idx * scale + disp];
+                    uni_vpextrd(val.cvt32(), lane_val, i);
+                    mov(addr, val.cvt32());
+                    L(scatter_end);
+                }
+            };
+
+            if (is_valid_isa(cpu_isa_t::avx2)) {
+                for (int i = 0; i < static_cast<int>(kYmmSimdWidth / kXmmSimdWidth); i++) {
+                    vextracti128(xmm_mask_temp, Xbyak::Ymm{xmm_mask.getIdx()}, i);
+                    vextracti128(xmm_index_temp, Xbyak::Ymm{xmm_index.getIdx()}, i);
+                    vextracti128(xmm_val_temp, Xbyak::Ymm{xmm_val.getIdx()}, i);
+                    store_xmm(xmm_mask_temp, xmm_index_temp, xmm_val_temp);
+                }
+            } else {
+                store_xmm(xmm_mask, xmm_index, xmm_val);
+            }
+
+            if (is_valid_isa(cpu_isa_t::avx2)) {
+                pop(Xbyak::Ymm{xmm_val_temp.getIdx()});
+                pop(Xbyak::Ymm{xmm_index_temp.getIdx()});
+                pop(Xbyak::Ymm{xmm_mask_temp.getIdx()});
+            }
+            pop(val);
+            pop(mask);
+            pop(idx);
+        }
+    }
+
     void uni_vmovss(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
         if (is_valid_isa(avx))
             vmovss(addr, x);