From 3120fd17863f6a178acaec3d7d976d9e3082e072 Mon Sep 17 00:00:00 2001 From: Jianning Wang Date: Fri, 13 Feb 2026 15:15:22 +0800 Subject: [PATCH] feat: int4 & inner_product support one2many calc --- src/ailego/math_batch/cosine_distance_batch.h | 27 +++++ src/ailego/math_batch/distance_batch.h | 6 ++ .../math_batch/inner_product_distance_batch.h | 18 ++++ .../inner_product_distance_batch_impl_int4.h | 98 +++++++++++++++++++ src/core/metric/inner_product_metric.cc | 14 +-- src/core/metric/quantized_integer_metric.cc | 6 +- .../metric/quantized_integer_metric_batch.h | 96 +++++++++++++++--- 7 files changed, 241 insertions(+), 24 deletions(-) create mode 100644 src/ailego/math_batch/inner_product_distance_batch_impl_int4.h diff --git a/src/ailego/math_batch/cosine_distance_batch.h b/src/ailego/math_batch/cosine_distance_batch.h index b8a8309a..5435587c 100644 --- a/src/ailego/math_batch/cosine_distance_batch.h +++ b/src/ailego/math_batch/cosine_distance_batch.h @@ -27,6 +27,10 @@ namespace zvec::ailego::DistanceBatch { template struct CosineDistanceBatch; +template +struct MinusInnerProductDistanceBatch; + + template struct CosineDistanceBatch { using ValueType = typename std::remove_cv::type; @@ -54,5 +58,28 @@ struct CosineDistanceBatch { } }; +template +struct MinusInnerProductDistanceBatch { + using ValueType = typename std::remove_cv::type; + + static inline void ComputeBatch(const ValueType **vecs, + const ValueType *query, size_t num_vecs, + size_t dim, float *results) { + InnerProductDistanceBatch::ComputeBatch( + vecs, query, num_vecs, dim, results); + + for (size_t i = 0; i < num_vecs; ++i) { + results[i] = -results[i]; + } + } + + using IPImplType = + InnerProductDistanceBatch; + + static void QueryPreprocess(void *query, size_t dim) { + return IPImplType::QueryPreprocess(query, dim); + } +}; + } // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/ailego/math_batch/distance_batch.h b/src/ailego/math_batch/distance_batch.h index c762a258..a02c4adb 100644 --- a/src/ailego/math_batch/distance_batch.h +++ b/src/ailego/math_batch/distance_batch.h @@ -43,6 +43,12 @@ struct BaseDistance { ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, out); } + if constexpr (std::is_same_v, + MinusInnerProductMatrix>) { + return DistanceBatch::MinusInnerProductDistanceBatch< + ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, + out); + } _ComputeBatch(m, q, num, dim, out); } diff --git a/src/ailego/math_batch/inner_product_distance_batch.h b/src/ailego/math_batch/inner_product_distance_batch.h index f5799497..9eca8dda 100644 --- a/src/ailego/math_batch/inner_product_distance_batch.h +++ b/src/ailego/math_batch/inner_product_distance_batch.h @@ -22,6 +22,7 @@ #include #include "inner_product_distance_batch_impl.h" #include "inner_product_distance_batch_impl_fp16.h" +#include "inner_product_distance_batch_impl_int4.h" #include "inner_product_distance_batch_impl_int8.h" namespace zvec::ailego::DistanceBatch { @@ -130,6 +131,23 @@ struct InnerProductDistanceBatchImpl { } }; +template +struct InnerProductDistanceBatchImpl { + using ValueType = uint8_t; + static void compute_one_to_many( + const uint8_t *query, const uint8_t **ptrs, + std::array &prefetch_ptrs, size_t dim, + float *sums) { +#if defined(__AVX2__) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + return compute_one_to_many_avx2_int4(query, ptrs, + prefetch_ptrs, dim, sums); + } +#endif + return compute_one_to_many_fallback(query, ptrs, prefetch_ptrs, dim, sums); + } +}; + template struct InnerProductDistanceBatch { using ValueType = typename std::remove_cv::type; diff --git a/src/ailego/math_batch/inner_product_distance_batch_impl_int4.h b/src/ailego/math_batch/inner_product_distance_batch_impl_int4.h new file mode 100644 index 00000000..d1bf0ec2 --- /dev/null +++ b/src/ailego/math_batch/inner_product_distance_batch_impl_int4.h @@ -0,0 +1,98 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace zvec::ailego::DistanceBatch { + +#if defined(__AVX2__) + +static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0x0f0f0f0f); +static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1, + 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1}; +static const __m256i INT4_LOOKUP_AVX = + _mm256_load_si256((const __m256i *)Int4ConvertTable); +static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001); + +template +static void compute_one_to_many_avx2_int4( + const uint8_t *query, const uint8_t **ptrs, + std::array &prefetch_ptrs, size_t dimensionality, + float *results) { + dimensionality >>= 1; + __m256i accs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + accs[i] = _mm256_setzero_si256(); + } + size_t dim = 0; + for (; dim + 32 <= dimensionality; dim += 32) { + __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim)); + __m256i q0 = _mm256_shuffle_epi8(INT4_LOOKUP_AVX, + _mm256_and_si256(q, MASK_INT4_AVX)); + __m256i q1 = _mm256_shuffle_epi8( + INT4_LOOKUP_AVX, + _mm256_and_si256(_mm256_srli_epi16(q, 4), MASK_INT4_AVX)); + __m256i q0_abs = _mm256_abs_epi8(q0); + __m256i q1_abs = _mm256_abs_epi8(q1); + __m256i data_regs[dp_batch]; + for (size_t i = 0; i < dp_batch; ++i) { + data_regs[i] = _mm256_loadu_si256((const __m256i *)(ptrs[i] + dim)); + } + if (prefetch_ptrs[0]) { + for (size_t i = 0; i < dp_batch; ++i) { + ailego_prefetch(prefetch_ptrs[i] + dim); + } + } + for (size_t i = 0; i < dp_batch; ++i) { + __m256i data0 = _mm256_shuffle_epi8( + INT4_LOOKUP_AVX, _mm256_and_si256(data_regs[i], MASK_INT4_AVX)); + __m256i data1 = _mm256_shuffle_epi8( + INT4_LOOKUP_AVX, + _mm256_and_si256(_mm256_srli_epi16(data_regs[i], 4), MASK_INT4_AVX)); + data0 = _mm256_sign_epi8(data0, q0); + data1 = _mm256_sign_epi8(data1, q1); + data0 = _mm256_madd_epi16(_mm256_maddubs_epi16(q0_abs, data0), + ONES_INT16_AVX); + data1 = _mm256_madd_epi16(_mm256_maddubs_epi16(q1_abs, data1), + ONES_INT16_AVX); + accs[i] = _mm256_add_epi32(_mm256_add_epi32(data0, data1), accs[i]); + } + } + std::array temp_results; + for (size_t i = 0; i < dp_batch; ++i) { + __m128i lo = _mm256_castsi256_si128(accs[i]); + __m128i hi = _mm256_extracti128_si256(accs[i], 1); + __m128i sum128 = _mm_add_epi32(lo, hi); + sum128 = _mm_hadd_epi32(sum128, sum128); + sum128 = _mm_hadd_epi32(sum128, sum128); + temp_results[i] = _mm_cvtsi128_si32(sum128); + } + for (; dim < dimensionality; ++dim) { + uint8_t q = query[dim]; + for (size_t i = 0; i < dp_batch; ++i) { + uint8_t m = ptrs[i][dim]; + temp_results[i] += + Int4MulTable[(((m) << 4) & 0xf0) | (((q) >> 0) & 0xf)] + + Int4MulTable[(((m) >> 0) & 0xf0) | (((q) >> 4) & 0xf)]; + } + } + for (size_t i = 0; i < dp_batch; ++i) { + results[i] = static_cast(temp_results[i]); + } +} + +#endif + +} // namespace zvec::ailego::DistanceBatch \ No newline at end of file diff --git a/src/core/metric/inner_product_metric.cc b/src/core/metric/inner_product_metric.cc index 8ef0a11b..e4a609c8 100644 --- a/src/core/metric/inner_product_metric.cc +++ b/src/core/metric/inner_product_metric.cc @@ -354,20 +354,20 @@ class InnerProductMetric : public IndexMetric { switch (data_type_) { case IndexMeta::DataType::DT_FP32: return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); + ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); + ailego::Float16, 12, 2>::ComputeBatch); case IndexMeta::DataType::DT_INT8: return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); + ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT4: return reinterpret_cast( - ailego::BaseDistance::ComputeBatch); + ailego::BaseDistance::ComputeBatch); default: return nullptr; } diff --git a/src/core/metric/quantized_integer_metric.cc b/src/core/metric/quantized_integer_metric.cc index 56e95634..9c5cce32 100644 --- a/src/core/metric/quantized_integer_metric.cc +++ b/src/core/metric/quantized_integer_metric.cc @@ -264,7 +264,11 @@ class QuantizedIntegerMetric : public IndexMetric { const override { if (origin_metric_type_ == MetricType::kCosine && meta_.data_type() == IndexMeta::DataType::DT_INT8) { - return CosineMinusInnerProductDistanceBatchWithScoreUnquantized< + return CosineDistanceBatchWithScoreUnquantized< + int8_t, 1, 1>::GetQueryPreprocessFunc(); + } else if (origin_metric_type_ == MetricType::kInnerProduct && + meta_.data_type() == IndexMeta::DataType::DT_INT8) { + return MinusInnerProductDistanceBatchWithScoreUnquantized< int8_t, 1, 1>::GetQueryPreprocessFunc(); } diff --git a/src/core/metric/quantized_integer_metric_batch.h b/src/core/metric/quantized_integer_metric_batch.h index e9e63cef..e87dc091 100644 --- a/src/core/metric/quantized_integer_metric_batch.h +++ b/src/core/metric/quantized_integer_metric_batch.h @@ -24,7 +24,7 @@ template struct MinusInnerProductDistanceBatchWithScoreUnquantized; template -struct CosineMinusInnerProductDistanceBatchWithScoreUnquantized; +struct CosineDistanceBatchWithScoreUnquantized; template struct SquaredEuclideanDistanceBatchWithScoreUnquantized; @@ -32,6 +32,9 @@ struct SquaredEuclideanDistanceBatchWithScoreUnquantized; template struct MipsSquaredEuclideanDistanceBatchWithScoreUnquantized; +template +struct InternalMinusInnerProductDistanceBatchWithScoreUnquantized; + template