diff --git a/npsr/lut-inl.h b/npsr/lut-inl.h index eb7611c..4285d8c 100644 --- a/npsr/lut-inl.h +++ b/npsr/lut-inl.h @@ -14,24 +14,26 @@ HWY_BEFORE_NAMESPACE(); namespace npsr::HWY_NAMESPACE { /** - * @brief SIMD-optimized lookup table implementation + * @brief Optimized Lookup Table. * - * This class provides an efficient lookup table. - * It stores data in both row-major and column-major - * formats to optimize different access patterns. + * A fixed-size lookup table (kRows x kCols) optimized for SIMD loads. + * It automatically selects the fastest loading strategy (TableLookup, + * Interleaved Loads, or Gather) based on the vector architecture and table dimensions. * - * @tparam T Element type (must match SIMD vector element type) - * @tparam kRows Number of rows in the lookup table - * @tparam kCols Number of columns in the lookup table - * - * Example usage: + * @par Example Usage * @code - * // Create a 2x4 lookup table - * constexpr Lut lut{{1.0f, 2.0f, 3.0f, 4.0f}, {5.0f, 6.0f, 7.0f, 8.0f}}; - * // Load values using SIMD indices - * auto indices = Set(d, 2); // SIMD vector of indices - * Vec out0, out1; - * lut.Load(indices, out0, out1); + * // 1. Create a 2x4 table (2 rows, 4 columns) + * auto lut = MakeLut( + * {1.0f, 2.0f, 3.0f, 4.0f}, + * {5.0f, 6.0f, 7.0f, 8.0f} + * ); + * + * // 2. Prepare SIMD indices (here every lane selects column 2) + * auto indices = Set(d, 2); + * + * // 3. 
Load values: 'r0' gets values from row 0, 'r1' from row 1 + * Vec r0, r1; + * lut.Load(indices, r0, r1); * @endcode */ template @@ -39,92 +41,231 @@ class Lut { public: static constexpr size_t kLength = kRows * kCols; - /** - * @brief Construct a lookup table from row arrays - * - * @tparam ColSizes Size of each row array (deduced) - * @param rows Variable number of arrays, each representing a row + // Implementation details for transposition optimization + static constexpr size_t kTransposeBy = HWY_LANES(T); + static constexpr size_t kTransposeTail = kRows % kTransposeBy; + static constexpr size_t kTransposeLength = (kRows - kTransposeTail) * kCols; + + // Determine at compile-time if transposition optimization is viable + static constexpr bool kInitTranspose = !HWY_HAVE_SCALABLE && ( + kRows / kTransposeBy > 0 && kCols % kTransposeBy == 0 && + (kTransposeBy == 2 || kTransposeBy == 4) // Currently supports 2x or 4x unrolling + ); + + /** + * @brief Constructs the table from row arrays. * - * @note All rows must have exactly kCols elements - * @note The constructor is constexpr for compile-time initialization + * @param rows Variable number of C-arrays, one for each row. + * Must match kRows count and kCols size. 
*/ template - constexpr Lut(const T (&...rows)[ColSizes]) : row_{} { - // Check that we have the right number of rows - static_assert(sizeof...(rows) == kRows, - "Number of rows doesn't match template parameter"); - // Check that all rows have the same number of columns - static_assert(((ColSizes == kCols) && ...), - "All rows must have the same number of columns"); - - // Copy data using recursive template approach - ToRowMajor_<0>(rows...); + constexpr Lut(const T (&...rows)[ColSizes]) : row_{}, trans_{} { + static_assert(sizeof...(rows) == kRows, "Count of input arrays must match kRows."); + static_assert(((ColSizes == kCols) && ...), "All input arrays must have kCols elements."); + + // Recursively copy data to internal storage + const auto &t_rows = std::forward_as_tuple(rows...); + if constexpr (kInitTranspose) { + InitTranspose_(t_rows); + } + InitRow_(t_rows); } /** - * @brief Load values from the LUT using SIMD indices - * - * This method performs efficient SIMD lookups by selecting the optimal - * implementation based on the vector size and LUT dimensions. + * @brief Loads values from the table using SIMD indices. * - * @tparam VU SIMD vector type for indices - * @tparam OutV Output SIMD vector types (must match number of rows) - * @param idx SIMD vector of column indices - * @param out Output vectors (one per row) + * Retrieves values from every row in the table corresponding to the column `idx`. * - * @note The number of output vectors must exactly match kRows - * @note Index values must be in range [0, kCols) + * @param idx SIMD vector containing column indices (0 to kCols-1). + * @param[out] out Reference to output vectors. You must provide exactly one + * output vector per row (total kRows). 
*/ template HWY_INLINE void Load(VU idx, OutV &...out) const { - static_assert(sizeof...(OutV) == kRows, - "Number of output vectors must match number of rows in LUT"); + static_assert(sizeof...(OutV) == kRows, "Must provide one output vector per table row."); using namespace hn; using TU = TFromV; - static_assert(sizeof(TU) == sizeof(T), - "Index type must match LUT element type"); - // Row-major based optimization - LoadRow_(idx, out...); + static_assert(sizeof(TU) == sizeof(T), "Index vector type must match table element type."); + +#if !HWY_HAVE_SCALABLE + // Try optimized transposed load first + if constexpr (kInitTranspose) { + LoadTranspose_(idx, out...); + } +#else + if constexpr (0) {} +#endif + else { + // Fallback to standard row loading + LoadRow_(idx, out...); + } } private: - /// Convert input rows to row-major storage format - template - constexpr void ToRowMajor_(const T (&...rows)[ColSizes]) { + + // Flattens input arrays into the standard linear member `row_` + template + constexpr void InitRow_(const Tuple& rows) { if constexpr (RowIDX < kRows) { - auto row_array = std::get(std::make_tuple(rows...)); + const auto& row_array = std::get(rows); for (size_t col = 0; col < kCols; ++col) { row_[RowIDX * kCols + col] = row_array[col]; } - ToRowMajor_(rows...); + InitRow_(rows); + } + } + +#if !HWY_HAVE_SCALABLE + // Pre-calculates transposed blocks for specific access patterns + template + constexpr void InitTranspose_(const Tuple& rows) { + constexpr size_t kTransposeRows = kRows - kTransposeTail; + if constexpr (RowIDX < kTransposeRows) { + const auto& row_array = std::get(rows); + const size_t block = RowIDX / kTransposeBy; + const size_t in_block = RowIDX % kTransposeBy; + constexpr size_t block_size = kTransposeBy * kCols; + for (size_t col = 0; col < kCols; ++col) { + trans_[ + block * block_size + + col * kTransposeBy + + in_block + ] = row_array[col]; + } + InitTranspose_(rows); } } - /// Dispatch to optimal row-load implementation based on 
vector/LUT size + // --- Transposed Load Implementation --- + template - HWY_INLINE void LoadRow_(VU idx, OutV &...out) const { + HWY_INLINE void LoadTranspose_(const VU &idx, OutV &...out) const { using namespace hn; using DU = DFromV; + using TU = TFromD; const DU du; + + constexpr size_t kLanes = Lanes(du); + + // Only use transposed load if vector lanes match the transpose blocking factor + if constexpr (kLanes == kTransposeBy) { + HWY_ALIGN TU s_idx[kLanes]; + Store(ShiftLeft(idx), du, s_idx); + if constexpr (kTransposeBy == 2) { + LoadTransposeX2_(s_idx, idx, out...); + } + else { + static_assert(kTransposeBy == 4, "It's already guarded by kInitTranspose"); + LoadTransposeX4_(s_idx, idx, out...); + } + } + else { + LoadRow_(idx, out...); + } + } + + // 2-wide transposed load optimization + template + HWY_INLINE void LoadTransposeX2_(const TU *trans_idx, const VU &idx, + OutV0& v0, OutV0& v1, OutV &...out) const { + using namespace hn; + using D = DFromV; + const D d; + + // Load interleaved data + const OutV0 a0b0 = LoadU(d, trans_ + Off + trans_idx[0]); + const OutV0 a1b1 = LoadU(d, trans_ + Off + trans_idx[1]); + + // De-interleave into separate vectors + v0 = ConcatLowerLower(d, a1b1, a0b0); + v1 = ConcatUpperUpper(d, a1b1, a0b0); + + // Recurse for remaining rows + if constexpr (sizeof...(OutV) == 1) { + LoadRow_(idx, out...); + } + else if constexpr (sizeof...(OutV) > 0) { + LoadTransposeX2_(trans_idx, idx, out...); + } + } + + // 4-wide transposed load optimization + template + HWY_INLINE void LoadTransposeX4_(const TU *trans_idx, const VU &idx, + OutV0& v0, OutV0& v1, OutV0& v2, OutV0& v3, OutV &...out) const { + using namespace hn; + using D = DFromV; + const D d; + + const OutV0 abcd0 = LoadU(d, trans_ + Off + trans_idx[0]); + const OutV0 abcd1 = LoadU(d, trans_ + Off + trans_idx[1]); + const OutV0 abcd2 = LoadU(d, trans_ + Off + trans_idx[2]); + const OutV0 abcd3 = LoadU(d, trans_ + Off + trans_idx[3]); + + const OutV0 ab01 = InterleaveLower(d, 
abcd0, abcd1); + const OutV0 cd01 = InterleaveUpper(d, abcd0, abcd1); + const OutV0 ab23 = InterleaveLower(d, abcd2, abcd3); + const OutV0 cd23 = InterleaveUpper(d, abcd2, abcd3); + + v0 = ConcatLowerLower(d, ab23, ab01); + v1 = ConcatUpperUpper(d, ab23, ab01); + v2 = ConcatLowerLower(d, cd23, cd01); + v3 = ConcatUpperUpper(d, cd23, cd01); + + if constexpr (sizeof...(OutV) <= 3 && sizeof...(OutV) > 0) { + LoadRow_(idx, out...); + } + else if constexpr (sizeof...(OutV) > 0) { + LoadTransposeX4_(trans_idx, idx, out...); + } + } +#endif + + // --- Row-Major Load Implementation --- + + // Selects the best row loading strategy based on vector size vs table width + template + HWY_INLINE void LoadRow_(const VU& idx, OutV& ...out) const { + using namespace hn; + using DU = DFromV; + const DU du; + using DI = RebindToSigned; + using TI = TFromD; + const DI di; using D = Rebind; + using M = MFromD; const D d; - HWY_LANES_CONSTEXPR size_t kLanes = Lanes(du); - if HWY_LANES_CONSTEXPR (kLanes == kCols) { - // Vector size matches table width - use single table lookup +#if !HWY_HAVE_SCALABLE + constexpr size_t kLanes = Lanes(du); + // Strategy 1: Vector size equals table width (Single Table Lookup) + if constexpr (kLanes == kCols) { const auto ind = IndicesFromVec(d, idx); LoadX1_(ind, out...); - } else if HWY_LANES_CONSTEXPR (kLanes * 2 == kCols) { - // Vector size is half table width - use two table lookup + } + // Strategy 2: Vector size is half table width (Two Table Lookups) + else if constexpr (kLanes * 2 == kCols) { const auto ind = IndicesFromVec(d, idx); LoadX2_(ind, out...); - } else { - // Fallback to gather for other configurations + } + // Strategy 3: Vector size is quarter table width (Four Table Lookups) + else if constexpr (kLanes * 4 == kCols && kLanes < std::numeric_limits::max()) { + const VU lut_lim = Set(du, kLanes * 2 - 1); + const auto ind = IndicesFromVec(d, And(idx, lut_lim)); + // Note: Rebind to signed because native unsigned 'Greater Than' might be 
missing + const M hi_mask = RebindMask(d, Gt(BitCast(di, idx), BitCast(di, lut_lim))); + LoadX4_(ind, hi_mask, out...); + } +#else + if constexpr (0) {} +#endif + else { + // Fallback: Use Gather instructions LoadGather_(idx, out...); } } - // Load using single table lookup (vector size == table width) + // Implementation: Single Table Lookup template HWY_INLINE void LoadX1_(const VInd &ind, OutV0 &out0, OutV &...out) const { using namespace hn; @@ -139,7 +280,7 @@ class Lut { } } - // Load using two table lookups (vector size == table width / 2) + // Implementation: Two Table Lookups template HWY_INLINE void LoadX2_(const VInd &ind, OutV0 &out0, OutV &...out) const { using namespace hn; @@ -156,7 +297,29 @@ class Lut { } } - // General fallback using gather instructions + // Implementation: Four Table Lookups + template + HWY_INLINE void LoadX4_(const VInd &ind, const HiMask &hi_mask, OutV0 &out0, OutV &...out) const { + using namespace hn; + using D = DFromV; + const D d; + + constexpr size_t kLanes = kCols / 4; + const OutV0 lut0 = LoadU(d, row_ + Off); + const OutV0 lut1 = LoadU(d, row_ + Off + kLanes); + const OutV0 lut2 = LoadU(d, row_ + Off + kLanes * 2); + const OutV0 lut3 = LoadU(d, row_ + Off + kLanes * 3); + + OutV0 lo = TwoTablesLookupLanes(d, lut0, lut1, ind); + OutV0 hi = TwoTablesLookupLanes(d, lut2, lut3, ind); + out0 = IfThenElse(hi_mask, hi, lo); + + if constexpr (sizeof...(OutV) > 0) { + LoadX4_(ind, hi_mask, out...); + } + } + + // Implementation: Gather fallback template HWY_INLINE void LoadGather_(const VU &idx, OutV0 &out0, OutV &...out) const { using namespace hn; @@ -168,40 +331,27 @@ class Lut { } } - // Row-major HWY_ALIGN T row_[kLength]; + HWY_ALIGN T trans_[kTransposeLength]; }; /** - * @brief Deduction guide for automatic dimension detection - * - * Allows constructing a Lut without explicitly specifying dimensions: - * @code - * Lut lut{row0, row1, row2}; // Dimensions deduced from arrays - * @endcode + * @brief Deduction Guide 
(C++17/20). + * Allows `Lut x{row1, row2};` without specifying template arguments. */ template Lut(const T (&first)[First], const T (&...rest)[Rest]) -> Lut; /** - * @brief Factory function that requires explicit type specification - * - * This approach forces users to specify the type T explicitly while - * automatically deducing the dimensions from the array arguments. - * - * Note: We use MakeLut since partial deduction guides (e.g., Lut{...}) - * require C++20, but this codebase targets C++17. + * @brief Factory function for explicit type specification. * - * @tparam T Element type (must be explicitly specified) - * @param first First row array - * @param rest Additional row arrays - * @return Lut with deduced dimensions + * Useful in C++17, where class template arguments cannot be partially deduced (T explicit, dimensions deduced). + * @tparam T Explicit element type (e.g., float). + * @return Lut with deduced dimensions. * - * Usage: * @code - * auto lut = MakeLut(row0, row1, row2); // T explicit, dimensions - * deduced + * auto lut = MakeLut(row1, row2); * @endcode */ template