# Every target in this Makefile is a command, not a file, so all of them are
# declared phony. (The original list named a non-existent `test` target and
# omitted most of the real ones.)
.PHONY: build run run-safe build-safe test test-client clean deps docker docker-run debug check-deps help

# Build the Go service with C++ bindings (cgo must be enabled for them)
build:
	@echo "Building Skye-Eigenix HNSW service with C++ bindings..."
	CGO_ENABLED=1 go build -o eigenix ./cmd/eigenix

# Build and run the service (using go run to avoid security issues)
run:
	@echo "Starting Skye-Eigenix HNSW service..."
	CGO_ENABLED=1 go run ./cmd/eigenix

# Build and run from temp directory (alternative method)
run-safe: build-safe
	@echo "Starting Skye-Eigenix service from temp directory..."
	~/temp-build/eigenix

# Build in temp directory to avoid security software
build-safe:
	@echo "Building Skye-Eigenix service in temp directory..."
	@mkdir -p ~/temp-build
	CGO_ENABLED=1 go build -o ~/temp-build/eigenix ./cmd/eigenix

# Run the Go unit tests (the target was declared in .PHONY but never defined)
test:
	@echo "Running Go tests..."
	CGO_ENABLED=1 go test ./...

# Run the example client (requires service to be running)
test-client:
	@echo "Running example client..."
	@echo "Note: Client example not yet implemented"

# Clean build artifacts, including the binary produced by `make debug`
clean:
	@echo "Cleaning build artifacts..."
	rm -f eigenix vector-search-service hnsw-service hnsw-service-debug
	rm -rf indices
	go clean

# Install dependencies
deps:
	@echo "Installing Go dependencies..."
	go mod tidy

# Build Docker image
docker:
	@echo "Building Docker image..."
	docker build -t hnsw-service .

# Run with Docker
docker-run:
	@echo "Running with Docker..."
	docker run -p 8080:8080 hnsw-service

# Development build with debug info (optimizations and inlining disabled).
# Builds ./cmd/eigenix like the other build targets; the original built "."
# instead, which is inconsistent with them.
debug:
	@echo "Building with debug info..."
	CGO_ENABLED=1 go build -gcflags="all=-N -l" -o hnsw-service-debug ./cmd/eigenix
&& exit 1) + @echo "✓ C++ compiler found" + @go version + @echo "✓ Go found" + +# Help +help: + @echo "Available targets:" + @echo " build - Build the HNSW service" + @echo " run - Build and run the service" + @echo " test-client - Run the example client" + @echo " clean - Clean build artifacts" + @echo " deps - Install Go dependencies" + @echo " docker - Build Docker image" + @echo " docker-run - Run with Docker" + @echo " debug - Build with debug info" + @echo " check-deps - Check if dependencies are installed" + @echo " help - Show this help" \ No newline at end of file diff --git a/eigenix/README.md b/eigenix/README.md new file mode 100644 index 00000000..31e510c9 --- /dev/null +++ b/eigenix/README.md @@ -0,0 +1,2 @@ +# skye-eigenix +IVF + HNSW Based Vector Database diff --git a/eigenix/cmd/eigenix/eigenix.env b/eigenix/cmd/eigenix/eigenix.env new file mode 100644 index 00000000..80b3de34 --- /dev/null +++ b/eigenix/cmd/eigenix/eigenix.env @@ -0,0 +1,8 @@ +APP_NAME=eigenix +APP_METRIC_SAMPLING_RATE=0.1 + +IS_HNSW_LIB_ENABLED=false +IS_QDRANT_ENABLED=true +QDRANT_CENTROID_FILE_PATH="" +QDRANT_HOST="localhost" +QDRANT_PORT="6334" \ No newline at end of file diff --git a/eigenix/cmd/eigenix/main.go b/eigenix/cmd/eigenix/main.go new file mode 100644 index 00000000..92596c54 --- /dev/null +++ b/eigenix/cmd/eigenix/main.go @@ -0,0 +1,78 @@ +package main + +import ( + "os" + "os/signal" + "syscall" + + "github.com/Meesho/go-core/config" + "github.com/Meesho/go-core/grpc" + "github.com/Meesho/go-core/metric" + pb "github.com/Meesho/skye-eigenix/internal/client" + "github.com/Meesho/skye-eigenix/internal/hnswlib" + "github.com/Meesho/skye-eigenix/internal/qdrant" + "github.com/Meesho/skye-eigenix/internal/serving" + "github.com/rs/zerolog/log" + "github.com/spf13/viper" +) + +func main() { + // Load environment variables from .env file + config.InitEnv() + metric.Init() + os.Setenv("APP_NAME", "eigenix") + os.Setenv("APP_PORT", "8080") + initHNSWLib() + initQdrant() + 
grpc.Init() + // Register gRPC service + srv := serving.Init() + pb.RegisterEigenixServiceServer(grpc.Instance().GRPCServer, srv) + + // Setup signal handling for graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) + + // Start gRPC server in a goroutine so it doesn't block + // Recover from any panics in this goroutine + go func() { + defer func() { + if r := recover(); r != nil { + log.Error().Interface("panic", r).Msg("gRPC server goroutine panicked") + } + }() + // Errors from Run() during shutdown are expected + _ = grpc.Instance().Run() + }() + log.Info().Msg("Server started successfully") + + // Wait for termination signal + sig := <-sigChan + log.Info().Msgf("Received signal: %v. Shutting down gracefully...", sig) + log.Info().Msg("Initiating shutdown...") +} + +func initHNSWLib() { + log.Info().Msg("Initializing HNSW library...") + db := hnswlib.Init() + + // Load existing indices on startup + log.Info().Msg("Loading existing indices...") + db.LoadAllIndices() + + // Ensure indices are saved on exit, even if there's a panic + defer func() { + if r := recover(); r != nil { + log.Error().Interface("panic", r).Msg("Recovered from panic during shutdown") + } + log.Info().Msg("Saving all indices before exit...") + db.SaveAllIndices() + log.Info().Msg("All indices saved successfully") + }() +} + +func initQdrant() { + log.Info().Msg("Initializing Qdrant...") + db := qdrant.Init() + db.LoadCentroids(viper.GetString("QDRANT_CENTROID_FILE_PATH")) +} diff --git a/eigenix/go.mod b/eigenix/go.mod new file mode 100644 index 00000000..23cf31b8 --- /dev/null +++ b/eigenix/go.mod @@ -0,0 +1,80 @@ +module github.com/Meesho/skye-eigenix + +go 1.24.4 + +toolchain go1.24.7 + +require ( + github.com/Meesho/go-core v1.30.17 + github.com/qdrant/go-client v1.15.2 + github.com/rs/zerolog v1.34.0 + github.com/spf13/viper v1.20.1 + google.golang.org/grpc v1.73.0 + google.golang.org/protobuf 
v1.36.9 +) + +require ( + github.com/DataDog/datadog-go/v5 v5.5.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/bits-and-blooms/bitset v1.22.0 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/failsafe-go/failsafe-go v0.6.9 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.9 // indirect + github.com/gin-contrib/sse v1.1.0 // indirect + github.com/gin-gonic/gin v1.11.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.27.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/go-zookeeper/zk v1.0.4 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/goccy/go-yaml v1.18.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/quic-go/qpack v0.5.1 // indirect + github.com/quic-go/quic-go v0.54.0 // indirect + github.com/sagikazarmark/locafero v0.9.0 // indirect + github.com/soheilhy/cmux v0.1.5 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.14.0 // indirect + github.com/spf13/cast v1.9.2 // indirect + 
github.com/spf13/pflag v1.0.6 // indirect + github.com/stretchr/objx v0.5.2 // indirect + github.com/stretchr/testify v1.11.1 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.3.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.62.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + go.uber.org/mock v0.5.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/arch v0.20.0 // indirect + golang.org/x/crypto v0.40.0 // indirect + golang.org/x/mod v0.25.0 // indirect + golang.org/x/net v0.42.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.27.0 // indirect + golang.org/x/tools v0.34.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/eigenix/go.sum b/eigenix/go.sum new file mode 100644 index 00000000..789da5e1 --- /dev/null +++ b/eigenix/go.sum @@ -0,0 +1,234 @@ +github.com/DataDog/datadog-go/v5 v5.5.0 h1:G5KHeB8pWBNXT4Jtw0zAkhdxEAWSpWH00geHI6LDrKU= +github.com/DataDog/datadog-go/v5 v5.5.0/go.mod h1:K9kcYBlxkcPP8tvvjZZKs/m1edNAUFzBbdpTUKfCsuw= +github.com/Meesho/go-core v1.30.17 h1:igxhW9N/F/sFSG00AMIe73NjbZfYYoqzhdnnNZZMcss= +github.com/Meesho/go-core v1.30.17/go.mod h1:Ftn5QRPrCwy/c/m0Mp8zoX0dWcW4PZvtPwyi/qFL6lc= +github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= 
+github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4= +github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= +github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/failsafe-go/failsafe-go v0.6.9 h1:7HWEzOlFOjNerxgWd8onWA2j/aEuqyAtuX6uWya/364= +github.com/failsafe-go/failsafe-go v0.6.9/go.mod h1:zb7xfp1/DJ7Mn4xJhVSZ9F2qmmMEGvYHxEOHYK5SIm0= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= 
+github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= +github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= +github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= +github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= +github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= +github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod 
h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go-zookeeper/zk v1.0.4 h1:DPzxraQx7OrPyXq2phlGlNSIyWEsAox0RJmjTseMV6I= +github.com/go-zookeeper/zk v1.0.4/go.mod h1:nOB03cncLtlp4t+UAkGSV+9beXP/akpekBwL+UX1Qcw= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= +github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 h1:X5VWvz21y3gzm9Nw/kaUeku/1+uBhcekkmy4IkffJww= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVTJueD4wSS5hT7zTt4Mrutd90= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty 
v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/qdrant/go-client v1.15.2 h1:3NSyxpHrfQTP6JLDAwqNUShz6V9tuRBKz0G7hSOxrac= +github.com/qdrant/go-client v1.15.2/go.mod h1:iO8ts78jL4x6LDHFOViyYWELVtIBDTjOykBmiOTHLnQ= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= +github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= +github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA= +github.com/spf13/afero v1.14.0/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/pflag v1.0.6 
h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= +github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= +github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 
+go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.62.0 h1:fZNpsQuTwFFSGC96aJexNOBrCD7PjD9Tm/HyHtXhmnk= +go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.62.0/go.mod h1:+NFxPSeYg0SoiRUO4k0ceJYMCY9FiRbYFmByUpm7GJY= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0 h1:rbRJ8BBoVMsQShESYZ0FkvcITu8X8QNwJogcLUmDNNw= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0/go.mod h1:ru6KHrNtNHxM4nD/vd6QrLVWgKhxPYgblq4VAtNawTQ= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= +go.opentelemetry.io/contrib/propagators/b3 v1.37.0 h1:0aGKdIuVhy5l4GClAjl72ntkZJhijf2wg1S7b5oLoYA= +go.opentelemetry.io/contrib/propagators/b3 v1.37.0/go.mod h1:nhyrxEJEOQdwR15zXrCKI6+cJK60PXAkJ/jRyfhr2mg= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 h1:Ahq7pZmv87yiyn3jeFz/LekZmPLLdKejuO3NcK9MssM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0/go.mod h1:MJTqhM0im3mRLw1i8uGHnCvUEeS7VwRyxlLC78PA18M= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0 h1:EtFWSnwW9hGObjkIdmlnWSydO+Qs8OwzfzXLUPg4xOc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.37.0/go.mod h1:QjUEoiGCPkvFZ/MjK6ZZfNOS6mfVEVKYE99dFhuN2LI= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0 h1:SNhVp/9q4Go/XHBkQ1/d5u9P/U+L1yaGPoi0x+mStaI= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0/go.mod 
h1:tx8OOlGH6R4kLV67YaYO44GFXloEjGPZuMjEkaaqIp4= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.opentelemetry.io/proto/otlp v1.7.0 h1:jX1VolD6nHuFzOYso2E73H85i92Mv8JQYk0K9vz09os= +go.opentelemetry.io/proto/otlp v1.7.0/go.mod h1:fSKjH6YJ7HDlwzltzyMj036AJ3ejJLCgCSHGj4efDDo= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= +golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.40.0 
h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys 
v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto v0.0.0-20250715232539-7130f93afb79 h1:Nt6z9UHqSlIdIGJdz6KhTIs2VRx/iOsA5iE8bmQNcxs= 
+google.golang.org/genproto/googleapis/api v0.0.0-20250715232539-7130f93afb79 h1:iOye66xuaAK0WnkPuhQPUFy8eJcmwUXqGGP3om6IxX8= +google.golang.org/genproto/googleapis/api v0.0.0-20250715232539-7130f93afb79/go.mod h1:HKJDgKsFUnv5VAGeQjz8kxcgDP0HoE0iZNp0OdZNlhE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79 h1:1ZwqphdOdWYXsUHgMpU/101nCtf/kSp9hOrcvFsnl10= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= +google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= +google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/eigenix/internal/client/eigenix.pb.go b/eigenix/internal/client/eigenix.pb.go new file mode 100644 index 00000000..1544c04e --- /dev/null +++ b/eigenix/internal/client/eigenix.pb.go @@ -0,0 +1,1597 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.10 +// protoc v5.29.3 +// source: internal/client/eigenix.proto + +package client + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// DatabaseType specifies which vector database backend to use +type DatabaseType int32 + +const ( + DatabaseType_HNSW_LIB DatabaseType = 0 // Use local HNSW library (in-memory) + DatabaseType_QDRANT DatabaseType = 1 // Use Qdrant (remote vector database) +) + +// Enum value maps for DatabaseType. +var ( + DatabaseType_name = map[int32]string{ + 0: "HNSW_LIB", + 1: "QDRANT", + } + DatabaseType_value = map[string]int32{ + "HNSW_LIB": 0, + "QDRANT": 1, + } +) + +func (x DatabaseType) Enum() *DatabaseType { + p := new(DatabaseType) + *p = x + return p +} + +func (x DatabaseType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DatabaseType) Descriptor() protoreflect.EnumDescriptor { + return file_internal_client_eigenix_proto_enumTypes[0].Descriptor() +} + +func (DatabaseType) Type() protoreflect.EnumType { + return &file_internal_client_eigenix_proto_enumTypes[0] +} + +func (x DatabaseType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DatabaseType.Descriptor instead. 
+func (DatabaseType) EnumDescriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{0} +} + +type Point struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Vector []float32 `protobuf:"fixed32,2,rep,packed,name=vector,proto3" json:"vector,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Point) Reset() { + *x = Point{} + mi := &file_internal_client_eigenix_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Point) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Point) ProtoMessage() {} + +func (x *Point) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Point.ProtoReflect.Descriptor instead. 
+func (*Point) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{0} +} + +func (x *Point) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *Point) GetVector() []float32 { + if x != nil { + return x.Vector + } + return nil +} + +type SearchResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Distance float32 `protobuf:"fixed32,2,opt,name=distance,proto3" json:"distance,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchResult) Reset() { + *x = SearchResult{} + mi := &file_internal_client_eigenix_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchResult) ProtoMessage() {} + +func (x *SearchResult) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchResult.ProtoReflect.Descriptor instead. 
+func (*SearchResult) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{1} +} + +func (x *SearchResult) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *SearchResult) GetDistance() float32 { + if x != nil { + return x.Distance + } + return 0 +} + +type IndexInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Dimension int32 `protobuf:"varint,2,opt,name=dimension,proto3" json:"dimension,omitempty"` + Space string `protobuf:"bytes,3,opt,name=space,proto3" json:"space,omitempty"` + MaxElements int32 `protobuf:"varint,4,opt,name=max_elements,json=maxElements,proto3" json:"max_elements,omitempty"` + M int32 `protobuf:"varint,5,opt,name=m,proto3" json:"m,omitempty"` + EfConstruction int32 `protobuf:"varint,6,opt,name=ef_construction,json=efConstruction,proto3" json:"ef_construction,omitempty"` + AllowReplaceDeleted bool `protobuf:"varint,7,opt,name=allow_replace_deleted,json=allowReplaceDeleted,proto3" json:"allow_replace_deleted,omitempty"` + CurrentCount int32 `protobuf:"varint,8,opt,name=current_count,json=currentCount,proto3" json:"current_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexInfo) Reset() { + *x = IndexInfo{} + mi := &file_internal_client_eigenix_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexInfo) ProtoMessage() {} + +func (x *IndexInfo) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexInfo.ProtoReflect.Descriptor instead. 
+func (*IndexInfo) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{2} +} + +func (x *IndexInfo) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *IndexInfo) GetDimension() int32 { + if x != nil { + return x.Dimension + } + return 0 +} + +func (x *IndexInfo) GetSpace() string { + if x != nil { + return x.Space + } + return "" +} + +func (x *IndexInfo) GetMaxElements() int32 { + if x != nil { + return x.MaxElements + } + return 0 +} + +func (x *IndexInfo) GetM() int32 { + if x != nil { + return x.M + } + return 0 +} + +func (x *IndexInfo) GetEfConstruction() int32 { + if x != nil { + return x.EfConstruction + } + return 0 +} + +func (x *IndexInfo) GetAllowReplaceDeleted() bool { + if x != nil { + return x.AllowReplaceDeleted + } + return false +} + +func (x *IndexInfo) GetCurrentCount() int32 { + if x != nil { + return x.CurrentCount + } + return 0 +} + +type CreateIndexRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Dimension int32 `protobuf:"varint,2,opt,name=dimension,proto3" json:"dimension,omitempty"` + Space string `protobuf:"bytes,3,opt,name=space,proto3" json:"space,omitempty"` + MaxElements int32 `protobuf:"varint,4,opt,name=max_elements,json=maxElements,proto3" json:"max_elements,omitempty"` + M int32 `protobuf:"varint,5,opt,name=m,proto3" json:"m,omitempty"` + EfConstruction int32 `protobuf:"varint,6,opt,name=ef_construction,json=efConstruction,proto3" json:"ef_construction,omitempty"` + AllowReplaceDeleted bool `protobuf:"varint,7,opt,name=allow_replace_deleted,json=allowReplaceDeleted,proto3" json:"allow_replace_deleted,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CreateIndexRequest) Reset() { + *x = CreateIndexRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[3] + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CreateIndexRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateIndexRequest) ProtoMessage() {} + +func (x *CreateIndexRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateIndexRequest.ProtoReflect.Descriptor instead. +func (*CreateIndexRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{3} +} + +func (x *CreateIndexRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *CreateIndexRequest) GetDimension() int32 { + if x != nil { + return x.Dimension + } + return 0 +} + +func (x *CreateIndexRequest) GetSpace() string { + if x != nil { + return x.Space + } + return "" +} + +func (x *CreateIndexRequest) GetMaxElements() int32 { + if x != nil { + return x.MaxElements + } + return 0 +} + +func (x *CreateIndexRequest) GetM() int32 { + if x != nil { + return x.M + } + return 0 +} + +func (x *CreateIndexRequest) GetEfConstruction() int32 { + if x != nil { + return x.EfConstruction + } + return 0 +} + +func (x *CreateIndexRequest) GetAllowReplaceDeleted() bool { + if x != nil { + return x.AllowReplaceDeleted + } + return false +} + +type CreateIndexResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CreateIndexResponse) Reset() { + *x = CreateIndexResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[4] + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CreateIndexResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateIndexResponse) ProtoMessage() {} + +func (x *CreateIndexResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateIndexResponse.ProtoReflect.Descriptor instead. +func (*CreateIndexResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{4} +} + +func (x *CreateIndexResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *CreateIndexResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type GetIndexRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetIndexRequest) Reset() { + *x = GetIndexRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetIndexRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetIndexRequest) ProtoMessage() {} + +func (x *GetIndexRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetIndexRequest.ProtoReflect.Descriptor instead. 
+func (*GetIndexRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{5} +} + +func (x *GetIndexRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type GetIndexResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Index *IndexInfo `protobuf:"bytes,1,opt,name=index,proto3" json:"index,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetIndexResponse) Reset() { + *x = GetIndexResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetIndexResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetIndexResponse) ProtoMessage() {} + +func (x *GetIndexResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetIndexResponse.ProtoReflect.Descriptor instead. 
+func (*GetIndexResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{6} +} + +func (x *GetIndexResponse) GetIndex() *IndexInfo { + if x != nil { + return x.Index + } + return nil +} + +type DeleteIndexRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteIndexRequest) Reset() { + *x = DeleteIndexRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteIndexRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteIndexRequest) ProtoMessage() {} + +func (x *DeleteIndexRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteIndexRequest.ProtoReflect.Descriptor instead. 
+func (*DeleteIndexRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{7} +} + +func (x *DeleteIndexRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type DeleteIndexResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteIndexResponse) Reset() { + *x = DeleteIndexResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteIndexResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteIndexResponse) ProtoMessage() {} + +func (x *DeleteIndexResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteIndexResponse.ProtoReflect.Descriptor instead. 
+func (*DeleteIndexResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{8} +} + +func (x *DeleteIndexResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *DeleteIndexResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type ListIndicesRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListIndicesRequest) Reset() { + *x = ListIndicesRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListIndicesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListIndicesRequest) ProtoMessage() {} + +func (x *ListIndicesRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListIndicesRequest.ProtoReflect.Descriptor instead. 
+func (*ListIndicesRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{9} +} + +type ListIndicesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Indices []*IndexInfo `protobuf:"bytes,1,rep,name=indices,proto3" json:"indices,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListIndicesResponse) Reset() { + *x = ListIndicesResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListIndicesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListIndicesResponse) ProtoMessage() {} + +func (x *ListIndicesResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListIndicesResponse.ProtoReflect.Descriptor instead. 
+func (*ListIndicesResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{10} +} + +func (x *ListIndicesResponse) GetIndices() []*IndexInfo { + if x != nil { + return x.Indices + } + return nil +} + +type IndexOperationRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexOperationRequest) Reset() { + *x = IndexOperationRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexOperationRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexOperationRequest) ProtoMessage() {} + +func (x *IndexOperationRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexOperationRequest.ProtoReflect.Descriptor instead. 
+func (*IndexOperationRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{11} +} + +func (x *IndexOperationRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type IndexOperationResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexOperationResponse) Reset() { + *x = IndexOperationResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexOperationResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexOperationResponse) ProtoMessage() {} + +func (x *IndexOperationResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexOperationResponse.ProtoReflect.Descriptor instead. 
+func (*IndexOperationResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{12} +} + +func (x *IndexOperationResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *IndexOperationResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type PointsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + IndexName string `protobuf:"bytes,1,opt,name=index_name,json=indexName,proto3" json:"index_name,omitempty"` + Points []*Point `protobuf:"bytes,2,rep,name=points,proto3" json:"points,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PointsRequest) Reset() { + *x = PointsRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PointsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PointsRequest) ProtoMessage() {} + +func (x *PointsRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PointsRequest.ProtoReflect.Descriptor instead. 
+func (*PointsRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{13} +} + +func (x *PointsRequest) GetIndexName() string { + if x != nil { + return x.IndexName + } + return "" +} + +func (x *PointsRequest) GetPoints() []*Point { + if x != nil { + return x.Points + } + return nil +} + +type PointsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + AffectedCount int32 `protobuf:"varint,2,opt,name=affected_count,json=affectedCount,proto3" json:"affected_count,omitempty"` + TotalCount int32 `protobuf:"varint,3,opt,name=total_count,json=totalCount,proto3" json:"total_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PointsResponse) Reset() { + *x = PointsResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PointsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PointsResponse) ProtoMessage() {} + +func (x *PointsResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PointsResponse.ProtoReflect.Descriptor instead. 
+func (*PointsResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{14} +} + +func (x *PointsResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *PointsResponse) GetAffectedCount() int32 { + if x != nil { + return x.AffectedCount + } + return 0 +} + +func (x *PointsResponse) GetTotalCount() int32 { + if x != nil { + return x.TotalCount + } + return 0 +} + +type DeletePointsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + IndexName string `protobuf:"bytes,1,opt,name=index_name,json=indexName,proto3" json:"index_name,omitempty"` + Ids []uint64 `protobuf:"varint,2,rep,packed,name=ids,proto3" json:"ids,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeletePointsRequest) Reset() { + *x = DeletePointsRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeletePointsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeletePointsRequest) ProtoMessage() {} + +func (x *DeletePointsRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeletePointsRequest.ProtoReflect.Descriptor instead. 
+func (*DeletePointsRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{15} +} + +func (x *DeletePointsRequest) GetIndexName() string { + if x != nil { + return x.IndexName + } + return "" +} + +func (x *DeletePointsRequest) GetIds() []uint64 { + if x != nil { + return x.Ids + } + return nil +} + +type DeletePointsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + DeletedCount int32 `protobuf:"varint,2,opt,name=deleted_count,json=deletedCount,proto3" json:"deleted_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeletePointsResponse) Reset() { + *x = DeletePointsResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeletePointsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeletePointsResponse) ProtoMessage() {} + +func (x *DeletePointsResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeletePointsResponse.ProtoReflect.Descriptor instead. 
+func (*DeletePointsResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{16} +} + +func (x *DeletePointsResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *DeletePointsResponse) GetDeletedCount() int32 { + if x != nil { + return x.DeletedCount + } + return 0 +} + +type SearchRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + IndexName string `protobuf:"bytes,1,opt,name=index_name,json=indexName,proto3" json:"index_name,omitempty"` + Vector []float32 `protobuf:"fixed32,2,rep,packed,name=vector,proto3" json:"vector,omitempty"` + Limit int32 `protobuf:"varint,3,opt,name=limit,proto3" json:"limit,omitempty"` + SearchParams *SearchParams `protobuf:"bytes,4,opt,name=search_params,json=searchParams,proto3" json:"search_params,omitempty"` + DatabaseType DatabaseType `protobuf:"varint,5,opt,name=database_type,json=databaseType,proto3,enum=eigenix.DatabaseType" json:"database_type,omitempty"` // Which database backend to use + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchRequest) Reset() { + *x = SearchRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchRequest) ProtoMessage() {} + +func (x *SearchRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[17] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchRequest.ProtoReflect.Descriptor instead. 
+func (*SearchRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{17} +} + +func (x *SearchRequest) GetIndexName() string { + if x != nil { + return x.IndexName + } + return "" +} + +func (x *SearchRequest) GetVector() []float32 { + if x != nil { + return x.Vector + } + return nil +} + +func (x *SearchRequest) GetLimit() int32 { + if x != nil { + return x.Limit + } + return 0 +} + +func (x *SearchRequest) GetSearchParams() *SearchParams { + if x != nil { + return x.SearchParams + } + return nil +} + +func (x *SearchRequest) GetDatabaseType() DatabaseType { + if x != nil { + return x.DatabaseType + } + return DatabaseType_HNSW_LIB +} + +type SearchResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Results []*SearchResult `protobuf:"bytes,1,rep,name=results,proto3" json:"results,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchResponse) Reset() { + *x = SearchResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchResponse) ProtoMessage() {} + +func (x *SearchResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[18] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchResponse.ProtoReflect.Descriptor instead. 
+func (*SearchResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{18} +} + +func (x *SearchResponse) GetResults() []*SearchResult { + if x != nil { + return x.Results + } + return nil +} + +type SearchParams struct { + state protoimpl.MessageState `protogen:"open.v1"` + UseIvf bool `protobuf:"varint,1,opt,name=use_ivf,json=useIvf,proto3" json:"use_ivf,omitempty"` + CentCount int32 `protobuf:"varint,2,opt,name=cent_count,json=centCount,proto3" json:"cent_count,omitempty"` // Number of centroids to search (only for IVF mode) + SearchIndexedOnly bool `protobuf:"varint,3,opt,name=search_indexed_only,json=searchIndexedOnly,proto3" json:"search_indexed_only,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchParams) Reset() { + *x = SearchParams{} + mi := &file_internal_client_eigenix_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchParams) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchParams) ProtoMessage() {} + +func (x *SearchParams) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchParams.ProtoReflect.Descriptor instead. 
+func (*SearchParams) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{19} +} + +func (x *SearchParams) GetUseIvf() bool { + if x != nil { + return x.UseIvf + } + return false +} + +func (x *SearchParams) GetCentCount() int32 { + if x != nil { + return x.CentCount + } + return 0 +} + +func (x *SearchParams) GetSearchIndexedOnly() bool { + if x != nil { + return x.SearchIndexedOnly + } + return false +} + +type BatchSearchRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + IndexName string `protobuf:"bytes,1,opt,name=index_name,json=indexName,proto3" json:"index_name,omitempty"` + Vectors []*Vector `protobuf:"bytes,2,rep,name=vectors,proto3" json:"vectors,omitempty"` + Limit int32 `protobuf:"varint,3,opt,name=limit,proto3" json:"limit,omitempty"` + SearchParams *SearchParams `protobuf:"bytes,4,opt,name=search_params,json=searchParams,proto3" json:"search_params,omitempty"` + DatabaseType DatabaseType `protobuf:"varint,5,opt,name=database_type,json=databaseType,proto3,enum=eigenix.DatabaseType" json:"database_type,omitempty"` // Which database backend to use + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BatchSearchRequest) Reset() { + *x = BatchSearchRequest{} + mi := &file_internal_client_eigenix_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BatchSearchRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BatchSearchRequest) ProtoMessage() {} + +func (x *BatchSearchRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BatchSearchRequest.ProtoReflect.Descriptor instead. 
+func (*BatchSearchRequest) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{20} +} + +func (x *BatchSearchRequest) GetIndexName() string { + if x != nil { + return x.IndexName + } + return "" +} + +func (x *BatchSearchRequest) GetVectors() []*Vector { + if x != nil { + return x.Vectors + } + return nil +} + +func (x *BatchSearchRequest) GetLimit() int32 { + if x != nil { + return x.Limit + } + return 0 +} + +func (x *BatchSearchRequest) GetSearchParams() *SearchParams { + if x != nil { + return x.SearchParams + } + return nil +} + +func (x *BatchSearchRequest) GetDatabaseType() DatabaseType { + if x != nil { + return x.DatabaseType + } + return DatabaseType_HNSW_LIB +} + +type Vector struct { + state protoimpl.MessageState `protogen:"open.v1"` + Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Vector) Reset() { + *x = Vector{} + mi := &file_internal_client_eigenix_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Vector) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Vector) ProtoMessage() {} + +func (x *Vector) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[21] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Vector.ProtoReflect.Descriptor instead. 
+func (*Vector) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{21} +} + +func (x *Vector) GetValues() []float32 { + if x != nil { + return x.Values + } + return nil +} + +type BatchSearchResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + ResultSets []*SearchResultSet `protobuf:"bytes,1,rep,name=result_sets,json=resultSets,proto3" json:"result_sets,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BatchSearchResponse) Reset() { + *x = BatchSearchResponse{} + mi := &file_internal_client_eigenix_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BatchSearchResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BatchSearchResponse) ProtoMessage() {} + +func (x *BatchSearchResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[22] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BatchSearchResponse.ProtoReflect.Descriptor instead. 
+func (*BatchSearchResponse) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{22} +} + +func (x *BatchSearchResponse) GetResultSets() []*SearchResultSet { + if x != nil { + return x.ResultSets + } + return nil +} + +type SearchResultSet struct { + state protoimpl.MessageState `protogen:"open.v1"` + Results []*SearchResult `protobuf:"bytes,1,rep,name=results,proto3" json:"results,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchResultSet) Reset() { + *x = SearchResultSet{} + mi := &file_internal_client_eigenix_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchResultSet) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchResultSet) ProtoMessage() {} + +func (x *SearchResultSet) ProtoReflect() protoreflect.Message { + mi := &file_internal_client_eigenix_proto_msgTypes[23] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchResultSet.ProtoReflect.Descriptor instead. 
+func (*SearchResultSet) Descriptor() ([]byte, []int) { + return file_internal_client_eigenix_proto_rawDescGZIP(), []int{23} +} + +func (x *SearchResultSet) GetResults() []*SearchResult { + if x != nil { + return x.Results + } + return nil +} + +var File_internal_client_eigenix_proto protoreflect.FileDescriptor + +const file_internal_client_eigenix_proto_rawDesc = "" + + "\n" + + "\x1dinternal/client/eigenix.proto\x12\aeigenix\"/\n" + + "\x05Point\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x12\x16\n" + + "\x06vector\x18\x02 \x03(\x02R\x06vector\":\n" + + "\fSearchResult\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x12\x1a\n" + + "\bdistance\x18\x02 \x01(\x02R\bdistance\"\x86\x02\n" + + "\tIndexInfo\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n" + + "\tdimension\x18\x02 \x01(\x05R\tdimension\x12\x14\n" + + "\x05space\x18\x03 \x01(\tR\x05space\x12!\n" + + "\fmax_elements\x18\x04 \x01(\x05R\vmaxElements\x12\f\n" + + "\x01m\x18\x05 \x01(\x05R\x01m\x12'\n" + + "\x0fef_construction\x18\x06 \x01(\x05R\x0eefConstruction\x122\n" + + "\x15allow_replace_deleted\x18\a \x01(\bR\x13allowReplaceDeleted\x12#\n" + + "\rcurrent_count\x18\b \x01(\x05R\fcurrentCount\"\xea\x01\n" + + "\x12CreateIndexRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n" + + "\tdimension\x18\x02 \x01(\x05R\tdimension\x12\x14\n" + + "\x05space\x18\x03 \x01(\tR\x05space\x12!\n" + + "\fmax_elements\x18\x04 \x01(\x05R\vmaxElements\x12\f\n" + + "\x01m\x18\x05 \x01(\x05R\x01m\x12'\n" + + "\x0fef_construction\x18\x06 \x01(\x05R\x0eefConstruction\x122\n" + + "\x15allow_replace_deleted\x18\a \x01(\bR\x13allowReplaceDeleted\"I\n" + + "\x13CreateIndexResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"%\n" + + "\x0fGetIndexRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\"<\n" + + "\x10GetIndexResponse\x12(\n" + + "\x05index\x18\x01 \x01(\v2\x12.eigenix.IndexInfoR\x05index\"(\n" + + 
"\x12DeleteIndexRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\"I\n" + + "\x13DeleteIndexResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"\x14\n" + + "\x12ListIndicesRequest\"C\n" + + "\x13ListIndicesResponse\x12,\n" + + "\aindices\x18\x01 \x03(\v2\x12.eigenix.IndexInfoR\aindices\"+\n" + + "\x15IndexOperationRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\"L\n" + + "\x16IndexOperationResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"V\n" + + "\rPointsRequest\x12\x1d\n" + + "\n" + + "index_name\x18\x01 \x01(\tR\tindexName\x12&\n" + + "\x06points\x18\x02 \x03(\v2\x0e.eigenix.PointR\x06points\"r\n" + + "\x0ePointsResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12%\n" + + "\x0eaffected_count\x18\x02 \x01(\x05R\raffectedCount\x12\x1f\n" + + "\vtotal_count\x18\x03 \x01(\x05R\n" + + "totalCount\"F\n" + + "\x13DeletePointsRequest\x12\x1d\n" + + "\n" + + "index_name\x18\x01 \x01(\tR\tindexName\x12\x10\n" + + "\x03ids\x18\x02 \x03(\x04R\x03ids\"U\n" + + "\x14DeletePointsResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12#\n" + + "\rdeleted_count\x18\x02 \x01(\x05R\fdeletedCount\"\xd4\x01\n" + + "\rSearchRequest\x12\x1d\n" + + "\n" + + "index_name\x18\x01 \x01(\tR\tindexName\x12\x16\n" + + "\x06vector\x18\x02 \x03(\x02R\x06vector\x12\x14\n" + + "\x05limit\x18\x03 \x01(\x05R\x05limit\x12:\n" + + "\rsearch_params\x18\x04 \x01(\v2\x15.eigenix.SearchParamsR\fsearchParams\x12:\n" + + "\rdatabase_type\x18\x05 \x01(\x0e2\x15.eigenix.DatabaseTypeR\fdatabaseType\"A\n" + + "\x0eSearchResponse\x12/\n" + + "\aresults\x18\x01 \x03(\v2\x15.eigenix.SearchResultR\aresults\"v\n" + + "\fSearchParams\x12\x17\n" + + "\ause_ivf\x18\x01 \x01(\bR\x06useIvf\x12\x1d\n" + + "\n" + + "cent_count\x18\x02 \x01(\x05R\tcentCount\x12.\n" + + "\x13search_indexed_only\x18\x03 \x01(\bR\x11searchIndexedOnly\"\xec\x01\n" + + 
"\x12BatchSearchRequest\x12\x1d\n" + + "\n" + + "index_name\x18\x01 \x01(\tR\tindexName\x12)\n" + + "\avectors\x18\x02 \x03(\v2\x0f.eigenix.VectorR\avectors\x12\x14\n" + + "\x05limit\x18\x03 \x01(\x05R\x05limit\x12:\n" + + "\rsearch_params\x18\x04 \x01(\v2\x15.eigenix.SearchParamsR\fsearchParams\x12:\n" + + "\rdatabase_type\x18\x05 \x01(\x0e2\x15.eigenix.DatabaseTypeR\fdatabaseType\" \n" + + "\x06Vector\x12\x16\n" + + "\x06values\x18\x01 \x03(\x02R\x06values\"P\n" + + "\x13BatchSearchResponse\x129\n" + + "\vresult_sets\x18\x01 \x03(\v2\x18.eigenix.SearchResultSetR\n" + + "resultSets\"B\n" + + "\x0fSearchResultSet\x12/\n" + + "\aresults\x18\x01 \x03(\v2\x15.eigenix.SearchResultR\aresults*(\n" + + "\fDatabaseType\x12\f\n" + + "\bHNSW_LIB\x10\x00\x12\n" + + "\n" + + "\x06QDRANT\x10\x012\x9c\x06\n" + + "\x0eEigenixService\x12H\n" + + "\vCreateIndex\x12\x1b.eigenix.CreateIndexRequest\x1a\x1c.eigenix.CreateIndexResponse\x12?\n" + + "\bGetIndex\x12\x18.eigenix.GetIndexRequest\x1a\x19.eigenix.GetIndexResponse\x12H\n" + + "\vDeleteIndex\x12\x1b.eigenix.DeleteIndexRequest\x1a\x1c.eigenix.DeleteIndexResponse\x12H\n" + + "\vListIndices\x12\x1b.eigenix.ListIndicesRequest\x1a\x1c.eigenix.ListIndicesResponse\x12L\n" + + "\tSaveIndex\x12\x1e.eigenix.IndexOperationRequest\x1a\x1f.eigenix.IndexOperationResponse\x12L\n" + + "\tLoadIndex\x12\x1e.eigenix.IndexOperationRequest\x1a\x1f.eigenix.IndexOperationResponse\x12<\n" + + "\tAddPoints\x12\x16.eigenix.PointsRequest\x1a\x17.eigenix.PointsResponse\x12?\n" + + "\fUpdatePoints\x12\x16.eigenix.PointsRequest\x1a\x17.eigenix.PointsResponse\x12K\n" + + "\fDeletePoints\x12\x1c.eigenix.DeletePointsRequest\x1a\x1d.eigenix.DeletePointsResponse\x129\n" + + "\x06Search\x12\x16.eigenix.SearchRequest\x1a\x17.eigenix.SearchResponse\x12H\n" + + "\vBatchSearch\x12\x1b.eigenix.BatchSearchRequest\x1a\x1c.eigenix.BatchSearchResponseB\x11Z\x0finternal/clientb\x06proto3" + +var ( + file_internal_client_eigenix_proto_rawDescOnce sync.Once + 
file_internal_client_eigenix_proto_rawDescData []byte +) + +func file_internal_client_eigenix_proto_rawDescGZIP() []byte { + file_internal_client_eigenix_proto_rawDescOnce.Do(func() { + file_internal_client_eigenix_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_internal_client_eigenix_proto_rawDesc), len(file_internal_client_eigenix_proto_rawDesc))) + }) + return file_internal_client_eigenix_proto_rawDescData +} + +var file_internal_client_eigenix_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_internal_client_eigenix_proto_msgTypes = make([]protoimpl.MessageInfo, 24) +var file_internal_client_eigenix_proto_goTypes = []any{ + (DatabaseType)(0), // 0: eigenix.DatabaseType + (*Point)(nil), // 1: eigenix.Point + (*SearchResult)(nil), // 2: eigenix.SearchResult + (*IndexInfo)(nil), // 3: eigenix.IndexInfo + (*CreateIndexRequest)(nil), // 4: eigenix.CreateIndexRequest + (*CreateIndexResponse)(nil), // 5: eigenix.CreateIndexResponse + (*GetIndexRequest)(nil), // 6: eigenix.GetIndexRequest + (*GetIndexResponse)(nil), // 7: eigenix.GetIndexResponse + (*DeleteIndexRequest)(nil), // 8: eigenix.DeleteIndexRequest + (*DeleteIndexResponse)(nil), // 9: eigenix.DeleteIndexResponse + (*ListIndicesRequest)(nil), // 10: eigenix.ListIndicesRequest + (*ListIndicesResponse)(nil), // 11: eigenix.ListIndicesResponse + (*IndexOperationRequest)(nil), // 12: eigenix.IndexOperationRequest + (*IndexOperationResponse)(nil), // 13: eigenix.IndexOperationResponse + (*PointsRequest)(nil), // 14: eigenix.PointsRequest + (*PointsResponse)(nil), // 15: eigenix.PointsResponse + (*DeletePointsRequest)(nil), // 16: eigenix.DeletePointsRequest + (*DeletePointsResponse)(nil), // 17: eigenix.DeletePointsResponse + (*SearchRequest)(nil), // 18: eigenix.SearchRequest + (*SearchResponse)(nil), // 19: eigenix.SearchResponse + (*SearchParams)(nil), // 20: eigenix.SearchParams + (*BatchSearchRequest)(nil), // 21: eigenix.BatchSearchRequest + (*Vector)(nil), // 22: 
eigenix.Vector + (*BatchSearchResponse)(nil), // 23: eigenix.BatchSearchResponse + (*SearchResultSet)(nil), // 24: eigenix.SearchResultSet +} +var file_internal_client_eigenix_proto_depIdxs = []int32{ + 3, // 0: eigenix.GetIndexResponse.index:type_name -> eigenix.IndexInfo + 3, // 1: eigenix.ListIndicesResponse.indices:type_name -> eigenix.IndexInfo + 1, // 2: eigenix.PointsRequest.points:type_name -> eigenix.Point + 20, // 3: eigenix.SearchRequest.search_params:type_name -> eigenix.SearchParams + 0, // 4: eigenix.SearchRequest.database_type:type_name -> eigenix.DatabaseType + 2, // 5: eigenix.SearchResponse.results:type_name -> eigenix.SearchResult + 22, // 6: eigenix.BatchSearchRequest.vectors:type_name -> eigenix.Vector + 20, // 7: eigenix.BatchSearchRequest.search_params:type_name -> eigenix.SearchParams + 0, // 8: eigenix.BatchSearchRequest.database_type:type_name -> eigenix.DatabaseType + 24, // 9: eigenix.BatchSearchResponse.result_sets:type_name -> eigenix.SearchResultSet + 2, // 10: eigenix.SearchResultSet.results:type_name -> eigenix.SearchResult + 4, // 11: eigenix.EigenixService.CreateIndex:input_type -> eigenix.CreateIndexRequest + 6, // 12: eigenix.EigenixService.GetIndex:input_type -> eigenix.GetIndexRequest + 8, // 13: eigenix.EigenixService.DeleteIndex:input_type -> eigenix.DeleteIndexRequest + 10, // 14: eigenix.EigenixService.ListIndices:input_type -> eigenix.ListIndicesRequest + 12, // 15: eigenix.EigenixService.SaveIndex:input_type -> eigenix.IndexOperationRequest + 12, // 16: eigenix.EigenixService.LoadIndex:input_type -> eigenix.IndexOperationRequest + 14, // 17: eigenix.EigenixService.AddPoints:input_type -> eigenix.PointsRequest + 14, // 18: eigenix.EigenixService.UpdatePoints:input_type -> eigenix.PointsRequest + 16, // 19: eigenix.EigenixService.DeletePoints:input_type -> eigenix.DeletePointsRequest + 18, // 20: eigenix.EigenixService.Search:input_type -> eigenix.SearchRequest + 21, // 21: eigenix.EigenixService.BatchSearch:input_type -> 
eigenix.BatchSearchRequest + 5, // 22: eigenix.EigenixService.CreateIndex:output_type -> eigenix.CreateIndexResponse + 7, // 23: eigenix.EigenixService.GetIndex:output_type -> eigenix.GetIndexResponse + 9, // 24: eigenix.EigenixService.DeleteIndex:output_type -> eigenix.DeleteIndexResponse + 11, // 25: eigenix.EigenixService.ListIndices:output_type -> eigenix.ListIndicesResponse + 13, // 26: eigenix.EigenixService.SaveIndex:output_type -> eigenix.IndexOperationResponse + 13, // 27: eigenix.EigenixService.LoadIndex:output_type -> eigenix.IndexOperationResponse + 15, // 28: eigenix.EigenixService.AddPoints:output_type -> eigenix.PointsResponse + 15, // 29: eigenix.EigenixService.UpdatePoints:output_type -> eigenix.PointsResponse + 17, // 30: eigenix.EigenixService.DeletePoints:output_type -> eigenix.DeletePointsResponse + 19, // 31: eigenix.EigenixService.Search:output_type -> eigenix.SearchResponse + 23, // 32: eigenix.EigenixService.BatchSearch:output_type -> eigenix.BatchSearchResponse + 22, // [22:33] is the sub-list for method output_type + 11, // [11:22] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name +} + +func init() { file_internal_client_eigenix_proto_init() } +func file_internal_client_eigenix_proto_init() { + if File_internal_client_eigenix_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_client_eigenix_proto_rawDesc), len(file_internal_client_eigenix_proto_rawDesc)), + NumEnums: 1, + NumMessages: 24, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_internal_client_eigenix_proto_goTypes, + DependencyIndexes: file_internal_client_eigenix_proto_depIdxs, + EnumInfos: file_internal_client_eigenix_proto_enumTypes, + MessageInfos: 
file_internal_client_eigenix_proto_msgTypes, + }.Build() + File_internal_client_eigenix_proto = out.File + file_internal_client_eigenix_proto_goTypes = nil + file_internal_client_eigenix_proto_depIdxs = nil +} diff --git a/eigenix/internal/client/eigenix.proto b/eigenix/internal/client/eigenix.proto new file mode 100644 index 00000000..8c54b35b --- /dev/null +++ b/eigenix/internal/client/eigenix.proto @@ -0,0 +1,158 @@ +syntax = "proto3"; + +package eigenix; + +option go_package = "internal/client"; + +message Point { + uint64 id = 1; + repeated float vector = 2; +} + +message SearchResult { + uint64 id = 1; + float distance = 2; +} + +message IndexInfo { + string name = 1; + int32 dimension = 2; + string space = 3; + int32 max_elements = 4; + int32 m = 5; + int32 ef_construction = 6; + bool allow_replace_deleted = 7; + int32 current_count = 8; +} + +message CreateIndexRequest { + string name = 1; + int32 dimension = 2; + string space = 3; + int32 max_elements = 4; + int32 m = 5; + int32 ef_construction = 6; + bool allow_replace_deleted = 7; +} + +message CreateIndexResponse { + bool success = 1; + string message = 2; +} + +message GetIndexRequest { + string name = 1; +} + +message GetIndexResponse { + IndexInfo index = 1; +} + +message DeleteIndexRequest { + string name = 1; +} + +message DeleteIndexResponse { + bool success = 1; + string message = 2; +} + +message ListIndicesRequest {} +message ListIndicesResponse { + repeated IndexInfo indices = 1; +} + +message IndexOperationRequest { + string name = 1; +} + +message IndexOperationResponse { + bool success = 1; + string message = 2; +} + +message PointsRequest { + string index_name = 1; + repeated Point points = 2; +} + +message PointsResponse { + bool success = 1; + int32 affected_count = 2; + int32 total_count = 3; +} + +message DeletePointsRequest { + string index_name = 1; + repeated uint64 ids = 2; +} + +message DeletePointsResponse { + bool success = 1; + int32 deleted_count = 2; +} + +message 
SearchRequest { + string index_name = 1; + repeated float vector = 2; + int32 limit = 3; + SearchParams search_params = 4; + DatabaseType database_type = 5; // Which database backend to use +} + +message SearchResponse { + repeated SearchResult results = 1; +} + +// DatabaseType specifies which vector database backend to use +enum DatabaseType { + HNSW_LIB = 0; // Use local HNSW library (in-memory) + QDRANT = 1; // Use Qdrant (remote vector database) +} + +message SearchParams{ + bool use_ivf = 1; + int32 cent_count = 2; // Number of centroids to search (only for IVF mode) + bool search_indexed_only = 3; +} + +message BatchSearchRequest { + string index_name = 1; + repeated Vector vectors = 2; + int32 limit = 3; + SearchParams search_params = 4; + DatabaseType database_type = 5; // Which database backend to use +} + +message Vector { + repeated float values = 1; +} + +message BatchSearchResponse { + repeated SearchResultSet result_sets = 1; +} + +message SearchResultSet { + repeated SearchResult results = 1; +} + +service EigenixService { + + // Index management + rpc CreateIndex(CreateIndexRequest) returns (CreateIndexResponse); + rpc GetIndex(GetIndexRequest) returns (GetIndexResponse); + rpc DeleteIndex(DeleteIndexRequest) returns (DeleteIndexResponse); + rpc ListIndices(ListIndicesRequest) returns (ListIndicesResponse); + + rpc SaveIndex(IndexOperationRequest) returns (IndexOperationResponse); + rpc LoadIndex(IndexOperationRequest) returns (IndexOperationResponse); + + // Points + rpc AddPoints(PointsRequest) returns (PointsResponse); + rpc UpdatePoints(PointsRequest) returns (PointsResponse); + rpc DeletePoints(DeletePointsRequest) returns (DeletePointsResponse); + + // Search + rpc Search(SearchRequest) returns (SearchResponse); + rpc BatchSearch(BatchSearchRequest) returns (BatchSearchResponse); +} diff --git a/eigenix/internal/client/eigenix_grpc.pb.go b/eigenix/internal/client/eigenix_grpc.pb.go new file mode 100644 index 00000000..c6c5263b --- /dev/null 
+++ b/eigenix/internal/client/eigenix_grpc.pb.go @@ -0,0 +1,507 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v5.29.3 +// source: internal/client/eigenix.proto + +package client + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + EigenixService_CreateIndex_FullMethodName = "/eigenix.EigenixService/CreateIndex" + EigenixService_GetIndex_FullMethodName = "/eigenix.EigenixService/GetIndex" + EigenixService_DeleteIndex_FullMethodName = "/eigenix.EigenixService/DeleteIndex" + EigenixService_ListIndices_FullMethodName = "/eigenix.EigenixService/ListIndices" + EigenixService_SaveIndex_FullMethodName = "/eigenix.EigenixService/SaveIndex" + EigenixService_LoadIndex_FullMethodName = "/eigenix.EigenixService/LoadIndex" + EigenixService_AddPoints_FullMethodName = "/eigenix.EigenixService/AddPoints" + EigenixService_UpdatePoints_FullMethodName = "/eigenix.EigenixService/UpdatePoints" + EigenixService_DeletePoints_FullMethodName = "/eigenix.EigenixService/DeletePoints" + EigenixService_Search_FullMethodName = "/eigenix.EigenixService/Search" + EigenixService_BatchSearch_FullMethodName = "/eigenix.EigenixService/BatchSearch" +) + +// EigenixServiceClient is the client API for EigenixService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type EigenixServiceClient interface { + // Index management + CreateIndex(ctx context.Context, in *CreateIndexRequest, opts ...grpc.CallOption) (*CreateIndexResponse, error) + GetIndex(ctx context.Context, in *GetIndexRequest, opts ...grpc.CallOption) (*GetIndexResponse, error) + DeleteIndex(ctx context.Context, in *DeleteIndexRequest, opts ...grpc.CallOption) (*DeleteIndexResponse, error) + ListIndices(ctx context.Context, in *ListIndicesRequest, opts ...grpc.CallOption) (*ListIndicesResponse, error) + SaveIndex(ctx context.Context, in *IndexOperationRequest, opts ...grpc.CallOption) (*IndexOperationResponse, error) + LoadIndex(ctx context.Context, in *IndexOperationRequest, opts ...grpc.CallOption) (*IndexOperationResponse, error) + // Points + AddPoints(ctx context.Context, in *PointsRequest, opts ...grpc.CallOption) (*PointsResponse, error) + UpdatePoints(ctx context.Context, in *PointsRequest, opts ...grpc.CallOption) (*PointsResponse, error) + DeletePoints(ctx context.Context, in *DeletePointsRequest, opts ...grpc.CallOption) (*DeletePointsResponse, error) + // Search + Search(ctx context.Context, in *SearchRequest, opts ...grpc.CallOption) (*SearchResponse, error) + BatchSearch(ctx context.Context, in *BatchSearchRequest, opts ...grpc.CallOption) (*BatchSearchResponse, error) +} + +type eigenixServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewEigenixServiceClient(cc grpc.ClientConnInterface) EigenixServiceClient { + return &eigenixServiceClient{cc} +} + +func (c *eigenixServiceClient) CreateIndex(ctx context.Context, in *CreateIndexRequest, opts ...grpc.CallOption) (*CreateIndexResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(CreateIndexResponse) + err := c.cc.Invoke(ctx, EigenixService_CreateIndex_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) GetIndex(ctx context.Context, in *GetIndexRequest, opts ...grpc.CallOption) (*GetIndexResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetIndexResponse) + err := c.cc.Invoke(ctx, EigenixService_GetIndex_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) DeleteIndex(ctx context.Context, in *DeleteIndexRequest, opts ...grpc.CallOption) (*DeleteIndexResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DeleteIndexResponse) + err := c.cc.Invoke(ctx, EigenixService_DeleteIndex_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) ListIndices(ctx context.Context, in *ListIndicesRequest, opts ...grpc.CallOption) (*ListIndicesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListIndicesResponse) + err := c.cc.Invoke(ctx, EigenixService_ListIndices_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) SaveIndex(ctx context.Context, in *IndexOperationRequest, opts ...grpc.CallOption) (*IndexOperationResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(IndexOperationResponse) + err := c.cc.Invoke(ctx, EigenixService_SaveIndex_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) LoadIndex(ctx context.Context, in *IndexOperationRequest, opts ...grpc.CallOption) (*IndexOperationResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(IndexOperationResponse) + err := c.cc.Invoke(ctx, EigenixService_LoadIndex_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) AddPoints(ctx context.Context, in *PointsRequest, opts ...grpc.CallOption) (*PointsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(PointsResponse) + err := c.cc.Invoke(ctx, EigenixService_AddPoints_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) UpdatePoints(ctx context.Context, in *PointsRequest, opts ...grpc.CallOption) (*PointsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(PointsResponse) + err := c.cc.Invoke(ctx, EigenixService_UpdatePoints_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) DeletePoints(ctx context.Context, in *DeletePointsRequest, opts ...grpc.CallOption) (*DeletePointsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DeletePointsResponse) + err := c.cc.Invoke(ctx, EigenixService_DeletePoints_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) Search(ctx context.Context, in *SearchRequest, opts ...grpc.CallOption) (*SearchResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(SearchResponse) + err := c.cc.Invoke(ctx, EigenixService_Search_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *eigenixServiceClient) BatchSearch(ctx context.Context, in *BatchSearchRequest, opts ...grpc.CallOption) (*BatchSearchResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(BatchSearchResponse) + err := c.cc.Invoke(ctx, EigenixService_BatchSearch_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// EigenixServiceServer is the server API for EigenixService service. +// All implementations must embed UnimplementedEigenixServiceServer +// for forward compatibility. +type EigenixServiceServer interface { + // Index management + CreateIndex(context.Context, *CreateIndexRequest) (*CreateIndexResponse, error) + GetIndex(context.Context, *GetIndexRequest) (*GetIndexResponse, error) + DeleteIndex(context.Context, *DeleteIndexRequest) (*DeleteIndexResponse, error) + ListIndices(context.Context, *ListIndicesRequest) (*ListIndicesResponse, error) + SaveIndex(context.Context, *IndexOperationRequest) (*IndexOperationResponse, error) + LoadIndex(context.Context, *IndexOperationRequest) (*IndexOperationResponse, error) + // Points + AddPoints(context.Context, *PointsRequest) (*PointsResponse, error) + UpdatePoints(context.Context, *PointsRequest) (*PointsResponse, error) + DeletePoints(context.Context, *DeletePointsRequest) (*DeletePointsResponse, error) + // Search + Search(context.Context, *SearchRequest) (*SearchResponse, error) + BatchSearch(context.Context, *BatchSearchRequest) (*BatchSearchResponse, error) + mustEmbedUnimplementedEigenixServiceServer() +} + +// UnimplementedEigenixServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. 
+type UnimplementedEigenixServiceServer struct{} + +func (UnimplementedEigenixServiceServer) CreateIndex(context.Context, *CreateIndexRequest) (*CreateIndexResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateIndex not implemented") +} +func (UnimplementedEigenixServiceServer) GetIndex(context.Context, *GetIndexRequest) (*GetIndexResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetIndex not implemented") +} +func (UnimplementedEigenixServiceServer) DeleteIndex(context.Context, *DeleteIndexRequest) (*DeleteIndexResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeleteIndex not implemented") +} +func (UnimplementedEigenixServiceServer) ListIndices(context.Context, *ListIndicesRequest) (*ListIndicesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListIndices not implemented") +} +func (UnimplementedEigenixServiceServer) SaveIndex(context.Context, *IndexOperationRequest) (*IndexOperationResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SaveIndex not implemented") +} +func (UnimplementedEigenixServiceServer) LoadIndex(context.Context, *IndexOperationRequest) (*IndexOperationResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method LoadIndex not implemented") +} +func (UnimplementedEigenixServiceServer) AddPoints(context.Context, *PointsRequest) (*PointsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method AddPoints not implemented") +} +func (UnimplementedEigenixServiceServer) UpdatePoints(context.Context, *PointsRequest) (*PointsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method UpdatePoints not implemented") +} +func (UnimplementedEigenixServiceServer) DeletePoints(context.Context, *DeletePointsRequest) (*DeletePointsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeletePoints not implemented") +} +func (UnimplementedEigenixServiceServer) 
Search(context.Context, *SearchRequest) (*SearchResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Search not implemented") +} +func (UnimplementedEigenixServiceServer) BatchSearch(context.Context, *BatchSearchRequest) (*BatchSearchResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method BatchSearch not implemented") +} +func (UnimplementedEigenixServiceServer) mustEmbedUnimplementedEigenixServiceServer() {} +func (UnimplementedEigenixServiceServer) testEmbeddedByValue() {} + +// UnsafeEigenixServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to EigenixServiceServer will +// result in compilation errors. +type UnsafeEigenixServiceServer interface { + mustEmbedUnimplementedEigenixServiceServer() +} + +func RegisterEigenixServiceServer(s grpc.ServiceRegistrar, srv EigenixServiceServer) { + // If the following call panics, it indicates UnimplementedEigenixServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O.
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&EigenixService_ServiceDesc, srv) +} + +func _EigenixService_CreateIndex_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateIndexRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).CreateIndex(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_CreateIndex_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).CreateIndex(ctx, req.(*CreateIndexRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_GetIndex_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetIndexRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).GetIndex(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_GetIndex_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).GetIndex(ctx, req.(*GetIndexRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_DeleteIndex_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteIndexRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).DeleteIndex(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_DeleteIndex_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, 
error) { + return srv.(EigenixServiceServer).DeleteIndex(ctx, req.(*DeleteIndexRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_ListIndices_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListIndicesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).ListIndices(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_ListIndices_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).ListIndices(ctx, req.(*ListIndicesRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_SaveIndex_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(IndexOperationRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).SaveIndex(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_SaveIndex_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).SaveIndex(ctx, req.(*IndexOperationRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_LoadIndex_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(IndexOperationRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).LoadIndex(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_LoadIndex_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) 
(interface{}, error) { + return srv.(EigenixServiceServer).LoadIndex(ctx, req.(*IndexOperationRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_AddPoints_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PointsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).AddPoints(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_AddPoints_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).AddPoints(ctx, req.(*PointsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_UpdatePoints_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PointsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).UpdatePoints(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_UpdatePoints_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).UpdatePoints(ctx, req.(*PointsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_DeletePoints_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeletePointsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).DeletePoints(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_DeletePoints_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) 
(interface{}, error) { + return srv.(EigenixServiceServer).DeletePoints(ctx, req.(*DeletePointsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_Search_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SearchRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).Search(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_Search_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).Search(ctx, req.(*SearchRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _EigenixService_BatchSearch_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(BatchSearchRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EigenixServiceServer).BatchSearch(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: EigenixService_BatchSearch_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EigenixServiceServer).BatchSearch(ctx, req.(*BatchSearchRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// EigenixService_ServiceDesc is the grpc.ServiceDesc for EigenixService service. 
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var EigenixService_ServiceDesc = grpc.ServiceDesc{
+	ServiceName: "eigenix.EigenixService",
+	HandlerType: (*EigenixServiceServer)(nil),
+	Methods: []grpc.MethodDesc{
+		{
+			MethodName: "CreateIndex",
+			Handler:    _EigenixService_CreateIndex_Handler,
+		},
+		{
+			MethodName: "GetIndex",
+			Handler:    _EigenixService_GetIndex_Handler,
+		},
+		{
+			MethodName: "DeleteIndex",
+			Handler:    _EigenixService_DeleteIndex_Handler,
+		},
+		{
+			MethodName: "ListIndices",
+			Handler:    _EigenixService_ListIndices_Handler,
+		},
+		{
+			MethodName: "SaveIndex",
+			Handler:    _EigenixService_SaveIndex_Handler,
+		},
+		{
+			MethodName: "LoadIndex",
+			Handler:    _EigenixService_LoadIndex_Handler,
+		},
+		{
+			MethodName: "AddPoints",
+			Handler:    _EigenixService_AddPoints_Handler,
+		},
+		{
+			MethodName: "UpdatePoints",
+			Handler:    _EigenixService_UpdatePoints_Handler,
+		},
+		{
+			MethodName: "DeletePoints",
+			Handler:    _EigenixService_DeletePoints_Handler,
+		},
+		{
+			MethodName: "Search",
+			Handler:    _EigenixService_Search_Handler,
+		},
+		{
+			MethodName: "BatchSearch",
+			Handler:    _EigenixService_BatchSearch_Handler,
+		},
+	},
+	Streams:  []grpc.StreamDesc{},
+	Metadata: "internal/client/eigenix.proto",
+}
diff --git a/eigenix/internal/hnswlib/database.go b/eigenix/internal/hnswlib/database.go
new file mode 100644
index 00000000..7a86823c
--- /dev/null
+++ b/eigenix/internal/hnswlib/database.go
@@ -0,0 +1,725 @@
+package hnswlib
+
+/*
+#cgo CXXFLAGS: -std=c++11 -I${SRCDIR}/../../pkg/hnswlib
+#cgo LDFLAGS: -lstdc++
+#include "../../pkg/hnswlib/hnsw_wrapper.h"
+#include <stdlib.h>
+*/
+import "C"
+
+import (
+	"encoding/json"
+	"fmt"
+	"log"
+	"os"
+	"path/filepath"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"unsafe"
+
+	pb "github.com/Meesho/skye-eigenix/internal/client"
+)
+
+const (
+	// Directory to store persisted indices
+	INDICES_DIR = "./indices"
+	// Metadata file extension
+	METADATA_EXT = ".meta"
+	// Index file extension
+	INDEX_EXT = ".idx"
+)
+
+// HNSWIndex represents an in-memory HNSW index: the parameters it was
+// created with plus the handle of the C++ index that holds the data.
+type HNSWIndex struct {
+	Name                string      `json:"name"`
+	Dimension           int32       `json:"dimension"`
+	Space               string      `json:"space"` // "l2", "ip", "cosine"
+	MaxElements         int32       `json:"max_elements"`
+	M                   int32       `json:"m"`
+	EfConstruct         int32       `json:"ef_construction"`
+	AllowReplaceDeleted bool        `json:"allow_replace_deleted"`
+	CIndex              C.HNSWIndex `json:"-"` // C++ HNSW index pointer
+}
+
+// searchBuffer holds pre-allocated C memory for search operations.
+// labels and distances each have room for k entries.
+type searchBuffer struct {
+	k         int
+	labels    *C.ulonglong
+	distances *C.float
+}
+
+// job is one unit of work for BatchSearch: a query and its position in the
+// request, so results can be written back in request order.
+type job struct {
+	idx   int
+	query []float32
+}
+
+// result is the outcome of one job, including the error if the query failed.
+type result struct {
+	idx     int
+	results []SearchResult
+	err     error
+}
+
+// Database implements the Manager interface
+type Database struct {
+	indices     map[string]*HNSWIndex
+	mutex       sync.RWMutex       // guards indices
+	bufferPools map[int]*sync.Pool // Pools of search buffers by k value
+	poolsMutex  sync.RWMutex       // guards bufferPools
+}
+
+var (
+	database Manager
+	once     sync.Once
+)
+
+// Init initializes and returns the singleton database instance.
+//
+// The instance is created inside once.Do and only read after once.Do has
+// returned, so concurrent first calls never observe a partially built value.
+func Init() Manager {
+	once.Do(func() {
+		database = &Database{
+			indices:     make(map[string]*HNSWIndex),
+			bufferPools: make(map[int]*sync.Pool),
+		}
+	})
+	return database
+}
+
+// newSearchBuffer allocates C-side label and distance arrays for k results.
+//
+// The arrays come from C.malloc, which the Go garbage collector does not
+// track. sync.Pool may drop idle buffers at any collection, so a finalizer
+// releases the C memory once the Go wrapper becomes unreachable.
+func newSearchBuffer(k int) *searchBuffer {
+	buf := &searchBuffer{
+		k:         k,
+		labels:    (*C.ulonglong)(C.malloc(C.size_t(k) * C.size_t(unsafe.Sizeof(C.ulonglong(0))))),
+		distances: (*C.float)(C.malloc(C.size_t(k) * C.size_t(unsafe.Sizeof(C.float(0))))),
+	}
+	runtime.SetFinalizer(buf, func(b *searchBuffer) {
+		C.free(unsafe.Pointer(b.labels))
+		C.free(unsafe.Pointer(b.distances))
+	})
+	return buf
+}
+
+// getSearchBuffer retrieves or creates a search buffer from the pool for k.
+//
+// One pool is kept per distinct k and pools are never removed.
+// NOTE(review): callers are expected to pass k > 0; a non-positive k would
+// produce a zero-size or wrapped-around allocation size.
+func (db *Database) getSearchBuffer(k int) *searchBuffer {
+	db.poolsMutex.RLock()
+	pool, exists := db.bufferPools[k]
+	db.poolsMutex.RUnlock()
+
+	if !exists {
+		db.poolsMutex.Lock()
+		// Look again under the write lock: another goroutine may have
+		// created the pool between RUnlock and Lock.
+		pool, exists = db.bufferPools[k]
+		if !exists {
+			pool = &sync.Pool{
+				New: func() interface{} { return newSearchBuffer(k) },
+			}
+			db.bufferPools[k] = pool
+		}
+		db.poolsMutex.Unlock()
+	}
+
+	return pool.Get().(*searchBuffer)
+}
+
+// putSearchBuffer returns a search buffer to the pool it was taken from.
+// A buffer whose k has no pool is simply dropped.
+func (db *Database) putSearchBuffer(buf *searchBuffer) {
+	db.poolsMutex.RLock()
+	pool, exists := db.bufferPools[buf.k]
+	db.poolsMutex.RUnlock()
+
+	if exists {
+		pool.Put(buf)
+	}
+}
+
+// CreateIndex creates a new HNSW index and registers it under name.
+//
+// It returns an error if an index with that name already exists or if the
+// C++ layer fails to construct the index.
+func (db *Database) CreateIndex(name, space string, dimension int32, maxElements int32, m int32, efConstruct int32, allowReplaceDeleted bool) error {
+	db.mutex.Lock()
+	defer db.mutex.Unlock()
+
+	// Check if index already exists
+	if _, exists := db.indices[name]; exists {
+		return fmt.Errorf("index already exists")
+	}
+
+	// Create new HNSW index using C++ implementation
+	cSpaceName := C.CString(space)
+	defer C.free(unsafe.Pointer(cSpaceName))
+
+	// The C API takes the flag as an int; the zero value already means false.
+	var cAllowReplace C.int
+	if allowReplaceDeleted {
+		cAllowReplace = 1
+	}
+
+	cIndex := C.hnsw_create_index(
+		cSpaceName,
+		C.int(dimension),
+		C.int(maxElements),
+		C.int(m),
+		C.int(efConstruct),
+		C.int(100), // random seed
+		cAllowReplace,
+	)
+
+	if cIndex == nil {
+		return fmt.Errorf("failed to create HNSW index")
+	}
+
+	db.indices[name] = &HNSWIndex{
+		Name:                name,
+		Dimension:           dimension,
+		Space:               space,
+		MaxElements:         maxElements,
+		M:                   m,
+		EfConstruct:         efConstruct,
+		AllowReplaceDeleted: allowReplaceDeleted,
+		CIndex:              cIndex,
+	}
+	return nil
+}
+
+// GetIndex retrieves an index by name; the bool reports whether it exists.
+func (db *Database) GetIndex(name string) (*HNSWIndex, bool) {
+	db.mutex.RLock()
+	defer db.mutex.RUnlock()
+	index, exists := db.indices[name]
+	return index, exists
+}
+
+// ListIndices returns all registered indices, in map iteration order.
+func (db *Database) ListIndices() []*HNSWIndex {
+	db.mutex.RLock()
+	defer db.mutex.RUnlock()
+
+	indices := make([]*HNSWIndex, 0, len(db.indices))
+	for _, index := range db.indices {
+		indices = append(indices, index)
+	}
+	return indices
+}
+
+// DeleteIndex removes an index from
+// memory, releasing the underlying C++ index before the map entry is dropped.
+// It returns an error if no index with that name is registered.
+func (db *Database) DeleteIndex(name string) error {
+	db.mutex.Lock()
+	defer db.mutex.Unlock()
+
+	index, exists := db.indices[name]
+	if !exists {
+		return fmt.Errorf("index not found")
+	}
+
+	// Clean up C++ resources
+	if index.CIndex != nil {
+		C.hnsw_delete_index(index.CIndex)
+	}
+
+	delete(db.indices, name)
+	return nil
+}
+
+// AddPoint adds a single point to an index (pure database operation).
+//
+// It returns an error if the index does not exist, if vector is empty or
+// shorter than the index dimension, or if the C++ layer rejects the point.
+func (db *Database) AddPoint(indexName string, id uint64, vector []float32) error {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return fmt.Errorf("index not found")
+	}
+
+	// &vector[0] panics on an empty slice, and a slice shorter than the
+	// index dimension would hand C fewer floats than the index was created
+	// for, so both are rejected before crossing into C.
+	if len(vector) == 0 || len(vector) < int(index.Dimension) {
+		return fmt.Errorf("vector has %d values, index dimension is %d", len(vector), index.Dimension)
+	}
+
+	// Pass the Go slice's backing array to C without copying
+	cVector := (*C.float)(unsafe.Pointer(&vector[0]))
+
+	// Add point to HNSW index
+	result := C.hnsw_add_point(index.CIndex, cVector, C.ulonglong(id))
+	if result != 0 {
+		return fmt.Errorf("failed to add point")
+	}
+
+	return nil
+}
+
+// AddPoints adds multiple points to an index (pure database operation).
+//
+// Points are inserted by a pool of runtime.NumCPU()/4 workers (at least one).
+// Nil points and points whose vector is empty or shorter than the index
+// dimension are skipped. The returned count is the number of points the C++
+// layer accepted; the error is non-nil only when the index does not exist.
+func (db *Database) AddPoints(indexName string, points []*pb.Point) (int, error) {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return 0, fmt.Errorf("index not found")
+	}
+
+	// Integer division yields 0 on machines with fewer than 4 CPUs; without
+	// this floor no worker would start and every point would be dropped
+	// while the call still reported success.
+	numWorkers := runtime.NumCPU() / 4
+	if numWorkers < 1 {
+		numWorkers = 1
+	}
+
+	var wg sync.WaitGroup
+	var successCount atomic.Int32
+
+	// The queue is buffered for every point so the producer never blocks
+	jobs := make(chan *pb.Point, len(points))
+
+	// Start workers
+	for i := 0; i < numWorkers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for point := range jobs {
+				if point == nil || len(point.Vector) == 0 || len(point.Vector) < int(index.Dimension) {
+					continue
+				}
+				cVector := (*C.float)(unsafe.Pointer(&point.Vector[0]))
+				if C.hnsw_add_point(index.CIndex, cVector, C.ulonglong(point.Id)) == 0 {
+					successCount.Add(1)
+				}
+			}
+		}()
+	}
+
+	for _, point := range points {
+		jobs <- point
+	}
+	close(jobs)
+
+	wg.Wait()
+
+	return int(successCount.Load()), nil
+}
+
+// Search performs k-NN search on an index (pure
+// database operation).
+//
+// It returns an error if the index does not exist or is not initialized, if
+// k is not positive, or if query is empty or shorter than the index dimension.
+func (db *Database) Search(indexName string, query []float32, k int) ([]SearchResult, error) {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return nil, fmt.Errorf("index not found")
+	}
+
+	if index.CIndex == nil {
+		return []SearchResult{}, fmt.Errorf("index not initialized")
+	}
+
+	return db.knnSearch(index, query, k)
+}
+
+// knnSearch runs one k-NN query against index and converts the C result
+// arrays into Go values. It is shared by Search and BatchSearch.
+//
+// For the "cosine" space the value reported by the C++ layer is returned as
+// 1 - value; other spaces are passed through unchanged.
+func (db *Database) knnSearch(index *HNSWIndex, query []float32, k int) ([]SearchResult, error) {
+	if k <= 0 {
+		return nil, fmt.Errorf("k must be positive, got %d", k)
+	}
+	// &query[0] panics on an empty slice, and a short slice would let C read
+	// past the end of the Go allocation.
+	if len(query) == 0 || len(query) < int(index.Dimension) {
+		return nil, fmt.Errorf("query has %d values, index dimension is %d", len(query), index.Dimension)
+	}
+
+	// Get reusable buffer from pool
+	buf := db.getSearchBuffer(k)
+	defer db.putSearchBuffer(buf)
+
+	cQuery := (*C.float)(unsafe.Pointer(&query[0]))
+	resultCount := int(C.hnsw_search_knn(index.CIndex, cQuery, C.int(k), buf.labels, buf.distances))
+	if resultCount < 0 {
+		// A negative count cannot size a slice. NOTE(review): presumably this
+		// is the wrapper's failure signal; confirm against hnsw_wrapper.h.
+		return nil, fmt.Errorf("search failed with code %d", resultCount)
+	}
+	if resultCount > k {
+		// The buffers only hold k entries; never read beyond them.
+		resultCount = k
+	}
+
+	// View the C arrays as Go slices without copying
+	labels := unsafe.Slice(buf.labels, resultCount)
+	distances := unsafe.Slice(buf.distances, resultCount)
+
+	results := make([]SearchResult, resultCount)
+	for i := range results {
+		distance := distances[i]
+		if index.Space == "cosine" {
+			distance = 1 - distance
+		}
+		results[i] = SearchResult{
+			ID:       uint64(labels[i]),
+			Distance: float32(distance),
+		}
+	}
+
+	return results, nil
+}
+
+// BatchSearch performs k-NN search for multiple query vectors in parallel (pure database operation).
+//
+// results[i] holds the neighbours of queries[i]. The first failing query
+// (nil or too-short vector, non-positive k, or a C++ failure) aborts the
+// call with an error naming its position.
+func (db *Database) BatchSearch(indexName string, queries []*pb.Vector, k int) ([][]SearchResult, error) {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return nil, fmt.Errorf("index not found")
+	}
+
+	if index.CIndex == nil {
+		return nil, fmt.Errorf("index not initialized")
+	}
+
+	numQueries := len(queries)
+	results := make([][]SearchResult, numQueries)
+
+	// One worker per CPU, but never more workers than queries
+	numWorkers := runtime.NumCPU()
+	if numWorkers < 1 {
+		numWorkers = 1
+	}
+	if numWorkers > numQueries {
+		numWorkers = numQueries
+	}
+
+	var wg sync.WaitGroup
+	// Both channels are buffered for every query, so neither the producer
+	// nor a worker can block even when the collector returns early.
+	jobs := make(chan job, numQueries)
+	resultsChan := make(chan result, numQueries)
+
+	// Start workers
+	for i := 0; i < numWorkers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := range jobs {
+				queryResults, err := db.knnSearch(index, j.query, k)
+				resultsChan <- result{idx: j.idx, results: queryResults, err: err}
+			}
+		}()
+	}
+
+	// Send jobs to workers. A nil entry becomes an empty query, which
+	// knnSearch reports as an error instead of a nil dereference here.
+	for i, v := range queries {
+		var values []float32
+		if v != nil {
+			values = v.Values
+		}
+		jobs <- job{idx: i, query: values}
+	}
+	close(jobs)
+
+	// Close the results channel once every worker has finished
+	go func() {
+		wg.Wait()
+		close(resultsChan)
+	}()
+
+	// Collect results
+	for r := range resultsChan {
+		if r.err != nil {
+			return nil, fmt.Errorf("query %d failed: %v", r.idx, r.err)
+		}
+		results[r.idx] = r.results
+	}
+
+	return results, nil
+}
+
+// GetIndexStats returns the current element count reported by the C++ index.
+func (db *Database) GetIndexStats(indexName string) (int, error) {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return 0, fmt.Errorf("index not found")
+	}
+
+	return int(C.hnsw_get_current_count(index.CIndex)), nil
+}
+
+// UpdatePoint updates a single point in an index (pure database operation).
+//
+// It returns an error if the index does not exist, if vector is empty or
+// shorter than the index dimension, if the C++ layer returns -2 ("point not
+// found"), or if it returns any other non-zero code.
+func (db *Database) UpdatePoint(indexName string, id uint64, vector []float32) error {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return fmt.Errorf("index not found")
+	}
+
+	// Reject vectors that would panic on &vector[0] or be over-read by C
+	if len(vector) == 0 || len(vector) < int(index.Dimension) {
+		return fmt.Errorf("vector has %d values, index dimension is %d", len(vector), index.Dimension)
+	}
+
+	// Pass the Go slice's backing array to C without copying
+	cVector := (*C.float)(unsafe.Pointer(&vector[0]))
+
+	// Update point in HNSW index
+	result := C.hnsw_update_point(index.CIndex, cVector, C.ulonglong(id))
+	if result == -2 {
+		return fmt.Errorf("point not found")
+	}
+	if result != 0 {
+		return fmt.Errorf("failed to update point")
+	}
+
+	return nil
+}
+
+// UpdatePoints updates multiple points in an index (pure database operation).
+//
+// Work is spread over runtime.NumCPU()/4 workers (at least one). Nil points
+// and points whose vector is empty or shorter than the index dimension are
+// skipped, matching AddPoints. The returned count is the number of points
+// the C++ layer updated successfully.
+func (db *Database) UpdatePoints(indexName string, points []*pb.Point) (int, error) {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return 0, fmt.Errorf("index not found")
+	}
+
+	numWorkers := runtime.NumCPU() / 4
+	if numWorkers < 1 {
+		numWorkers = 1
+	}
+
+	var wg sync.WaitGroup
+	var successCount atomic.Int32
+
+	// Create work queue
+	jobs := make(chan *pb.Point, len(points))
+
+	// Start workers
+	for i := 0; i < numWorkers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for point := range jobs {
+				// A panic in a worker goroutine would take down the whole
+				// process, so malformed points are skipped, not dereferenced.
+				if point == nil || len(point.Vector) == 0 || len(point.Vector) < int(index.Dimension) {
+					continue
+				}
+				cVector := (*C.float)(unsafe.Pointer(&point.Vector[0]))
+				if C.hnsw_update_point(index.CIndex, cVector, C.ulonglong(point.Id)) == 0 {
+					successCount.Add(1)
+				}
+			}
+		}()
+	}
+
+	// Send jobs to workers
+	for _, point := range points {
+		jobs <- point
+	}
+	close(jobs)
+
+	wg.Wait()
+
+	return int(successCount.Load()), nil
+}
+
+// DeletePoint marks a single point as deleted in an index (soft delete, pure database operation)
+func (db *Database) DeletePoint(indexName string, id uint64) error {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return fmt.Errorf("index not found")
+	}
+
+	// Mark point as deleted in HNSW index
+	result := C.hnsw_mark_deleted(index.CIndex, C.ulonglong(id))
+	if result != 0 {
+		return fmt.Errorf("failed to delete point")
+	}
+
+	return nil
+}
+
+// DeletePoints marks multiple points as deleted in an index (soft
+// delete, pure database operation).
+//
+// Work is spread over runtime.NumCPU()/4 workers (at least one). The returned
+// count is the number of ids the C++ layer marked successfully.
+func (db *Database) DeletePoints(indexName string, ids []uint64) (int, error) {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return 0, fmt.Errorf("index not found")
+	}
+
+	numWorkers := runtime.NumCPU() / 4
+	if numWorkers < 1 {
+		numWorkers = 1
+	}
+
+	var wg sync.WaitGroup
+	var successCount atomic.Int32
+
+	// Create work queue
+	jobs := make(chan uint64, len(ids))
+
+	// Start workers
+	for i := 0; i < numWorkers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for id := range jobs {
+				if C.hnsw_mark_deleted(index.CIndex, C.ulonglong(id)) == 0 {
+					successCount.Add(1)
+				}
+			}
+		}()
+	}
+
+	// Send jobs to workers
+	for _, id := range ids {
+		jobs <- id
+	}
+	close(jobs)
+
+	wg.Wait()
+
+	return int(successCount.Load()), nil
+}
+
+// validateIndexName rejects names that would resolve outside INDICES_DIR
+// when joined into a file path (empty, ".", "..", or containing a path
+// separator).
+func validateIndexName(name string) error {
+	if name == "" || name == "." || name == ".." || filepath.Base(name) != name {
+		return fmt.Errorf("invalid index name %q", name)
+	}
+	return nil
+}
+
+// SaveIndex saves an index to disk as two files under INDICES_DIR:
+// <name>.idx with the index data and <name>.meta with its JSON metadata.
+func (db *Database) SaveIndex(indexName string) error {
+	db.mutex.RLock()
+	index, exists := db.indices[indexName]
+	db.mutex.RUnlock()
+
+	if !exists {
+		return fmt.Errorf("index %s not found", indexName)
+	}
+
+	// The name becomes part of a file path below; refuse anything that
+	// could escape the indices directory.
+	if err := validateIndexName(indexName); err != nil {
+		return err
+	}
+
+	// Ensure indices directory exists
+	if err := os.MkdirAll(INDICES_DIR, 0755); err != nil {
+		return fmt.Errorf("failed to create indices directory: %v", err)
+	}
+
+	// Save index data
+	indexPath := filepath.Join(INDICES_DIR, indexName+INDEX_EXT)
+	cIndexPath := C.CString(indexPath)
+	defer C.free(unsafe.Pointer(cIndexPath))
+
+	if C.hnsw_save_index(index.CIndex, cIndexPath) != 0 {
+		return fmt.Errorf("failed to save index data")
+	}
+
+	// Save metadata
+	metadata := IndexMetadata{
+		Name:                index.Name,
+		Dimension:           index.Dimension,
+		Space:               index.Space,
+		MaxElements:         index.MaxElements,
+		M:                   index.M,
+		EfConstruct:         index.EfConstruct,
+		AllowReplaceDeleted: index.AllowReplaceDeleted,
+	}
+
+	metadataPath := filepath.Join(INDICES_DIR, indexName+METADATA_EXT)
+	metadataFile, err := os.Create(metadataPath)
+	if err != nil {
+		return fmt.Errorf("failed to create metadata file: %v", err)
+	}
+
+	if err := json.NewEncoder(metadataFile).Encode(metadata); err != nil {
+		metadataFile.Close()
+		return fmt.Errorf("failed to save metadata: %v", err)
+	}
+	// Close can report a write that failed late; treat it as a failed save
+	// instead of reporting success.
+	if err := metadataFile.Close(); err != nil {
+		return fmt.Errorf("failed to save metadata: %v", err)
+	}
+
+	log.Printf("Successfully saved index %s", indexName)
+	return nil
+}
+
+// LoadIndex loads an index from disk and registers it under indexName.
+//
+// It reads <name>.meta for the parameters and <name>.idx for the data. If an
+// index with the same name is already registered it is replaced and its C++
+// resources are released.
+func (db *Database) LoadIndex(indexName string) error {
+	if err := validateIndexName(indexName); err != nil {
+		return err
+	}
+
+	// Load metadata
+	metadataPath := filepath.Join(INDICES_DIR, indexName+METADATA_EXT)
+	metadataFile, err := os.Open(metadataPath)
+	if os.IsNotExist(err) {
+		return fmt.Errorf("metadata file not found for index %s", indexName)
+	}
+	if err != nil {
+		return fmt.Errorf("failed to open metadata file: %v", err)
+	}
+	defer metadataFile.Close()
+
+	var metadata IndexMetadata
+	if err := json.NewDecoder(metadataFile).Decode(&metadata); err != nil {
+		return fmt.Errorf("failed to decode metadata: %v", err)
+	}
+
+	// Check if index file exists
+	indexPath := filepath.Join(INDICES_DIR, indexName+INDEX_EXT)
+	if _, err := os.Stat(indexPath); os.IsNotExist(err) {
+		return fmt.Errorf("index file not found for index %s", indexName)
+	}
+
+	// Load index using C++ implementation
+	cSpaceName := C.CString(metadata.Space)
+	cIndexPath := C.CString(indexPath)
+	defer C.free(unsafe.Pointer(cSpaceName))
+	defer C.free(unsafe.Pointer(cIndexPath))
+
+	cIndex := C.hnsw_load_index(
+		cSpaceName,
+		C.int(metadata.Dimension),
+		cIndexPath,
+		C.int(metadata.MaxElements),
+	)
+
+	if cIndex == nil {
+		return fmt.Errorf("failed to load HNSW index from file")
+	}
+
+	// Create new index object
+	index := &HNSWIndex{
+		Name:                metadata.Name,
+		Dimension:           metadata.Dimension,
+		Space:               metadata.Space,
+		MaxElements:         metadata.MaxElements,
+		M:                   metadata.M,
+		EfConstruct:         metadata.EfConstruct,
+		AllowReplaceDeleted: metadata.AllowReplaceDeleted,
+		CIndex:              cIndex,
+	}
+
+	db.mutex.Lock()
+	previous := db.indices[indexName]
+	db.indices[indexName] = index
+	// Overwriting the map entry would otherwise orphan the old C++ index;
+	// release it the same way DeleteIndex does.
+	// NOTE(review): as with DeleteIndex, callers that fetched the old handle
+	// before this point may still be using it.
+	if previous != nil && previous.CIndex != nil {
+		C.hnsw_delete_index(previous.CIndex)
+	}
+	db.mutex.Unlock()
+
+	log.Printf("Successfully loaded index %s", indexName)
+	return nil
+}
+
+// SaveAllIndices saves all indices to disk, logging (not returning) failures.
+func (db *Database) SaveAllIndices() {
+	// Snapshot the names so the lock is not held during file I/O
+	db.mutex.RLock()
+	indexNames := make([]string, 0, len(db.indices))
+	for name := range db.indices {
+		indexNames = append(indexNames, name)
+	}
+	db.mutex.RUnlock()
+
+	for _, name := range indexNames {
+		if err := db.SaveIndex(name); err != nil {
+			log.Printf("Failed to save index %s: %v", name, err)
+		}
+	}
+}
+
+// LoadAllIndices loads every index that has a metadata file in INDICES_DIR,
+// logging (not returning) failures.
+func (db *Database) LoadAllIndices() {
+	// Check if indices directory exists
+	if _, err := os.Stat(INDICES_DIR); os.IsNotExist(err) {
+		log.Println("No indices directory found, starting with empty state")
+		return
+	}
+
+	// Read directory contents
+	files, err := os.ReadDir(INDICES_DIR)
+	if err != nil {
+		log.Printf("Failed to read indices directory: %v", err)
+		return
+	}
+
+	// Find all metadata files and load corresponding indices
+	for _, file := range files {
+		if file.IsDir() || filepath.Ext(file.Name()) != METADATA_EXT {
+			continue
+		}
+		indexName := file.Name()[:len(file.Name())-len(METADATA_EXT)]
+		if err := db.LoadIndex(indexName); err != nil {
+			log.Printf("Failed to load index %s: %v", indexName, err)
+		}
+	}
+}
diff --git a/eigenix/internal/hnswlib/hnsw_bridge.cc b/eigenix/internal/hnswlib/hnsw_bridge.cc
new file mode 100644
index 00000000..254ca201
--- /dev/null
+++ b/eigenix/internal/hnswlib/hnsw_bridge.cc
@@ -0,0 +1,2 @@
+// This file bridges the C++ HNSW implementation for CGO
+#include "../../pkg/hnswlib/hnsw_wrapper.cc"
diff --git a/eigenix/internal/hnswlib/manager.go b/eigenix/internal/hnswlib/manager.go
new file mode 100644
index 00000000..7ea18f00
--- /dev/null
+++ b/eigenix/internal/hnswlib/manager.go
@@ -0,0 +1,31 @@
+package hnswlib
+
+import (
+	pb "github.com/Meesho/skye-eigenix/internal/client"
+)
+
+// Manager defines the interface for database operations on HNSW indices
+type Manager interface {
+	// Index management (pure CRUD operations)
+	CreateIndex(name, space string, dimension int32, maxElements int32, m int32, efConstruct int32, allowReplaceDeleted bool) error
+	GetIndex(name string) (*HNSWIndex, bool)
+	ListIndices() []*HNSWIndex
+	DeleteIndex(name string) error
+
+	// Data operations (pure database operations)
+	AddPoint(indexName string, id uint64, vector []float32) error
+	AddPoints(indexName string, points []*pb.Point) (int, error)
+	UpdatePoint(indexName string, id uint64, vector []float32) error
+	UpdatePoints(indexName string, points []*pb.Point) (int, error)
+	DeletePoint(indexName string, id uint64) error
+	DeletePoints(indexName string, ids []uint64) (int, error)
+	Search(indexName string, query []float32, k int) ([]SearchResult, error)
+	BatchSearch(indexName string, queries []*pb.Vector, k int) ([][]SearchResult, error)
+	GetIndexStats(indexName string) (int, error)
+
+	// Persistence (pure file I/O operations)
+	SaveIndex(indexName string) error
+	LoadIndex(indexName string) error
+	SaveAllIndices()
+	LoadAllIndices()
+}
diff --git a/eigenix/internal/hnswlib/model.go b/eigenix/internal/hnswlib/model.go
new file mode 100644
index 00000000..613d351a
--- /dev/null
+++ b/eigenix/internal/hnswlib/model.go
@@ -0,0 +1,113 @@
+package hnswlib
+
+// IndexMetadata stores configuration needed to reload an index
+type IndexMetadata struct {
+	Name                string `json:"name"`
+	Dimension           int32  `json:"dimension"`
+	Space               string `json:"space"`
+	MaxElements         int32  `json:"max_elements"`
+	M                   int32  `json:"m"`
+	EfConstruct         int32  `json:"ef_construction"`
+	AllowReplaceDeleted bool   `json:"allow_replace_deleted"`
+}
+
+// Point represents a vector point
+type Point struct {
+	ID     uint64    `json:"id"`
+	Vector []float32 `json:"vector"`
+}
+
+// SearchRequest represents a search request
+type SearchRequest struct {
+	Vector []float32 `json:"vector"`
+	K      int       `json:"k"`
+}
+
+// SearchResult represents a search result
+type SearchResult struct {
+	ID       uint64  `json:"id"`
+	Distance float32 `json:"distance"`
+}
+
+// SearchResponse represents the
+// response from a search
+type SearchResponse struct {
+	Results []SearchResult `json:"results"`
+}
+
+// Request and response payloads, declared together as one group.
+type (
+	// CreateIndexRequest carries the name and construction parameters of
+	// an index to create.
+	CreateIndexRequest struct {
+		Name                string `json:"name"`
+		Dimension           int    `json:"dimension"`
+		Space               string `json:"space"`
+		MaxElements         int    `json:"max_elements"`
+		M                   int    `json:"m"`
+		EfConstruct         int    `json:"ef_construction"`
+		AllowReplaceDeleted bool   `json:"allow_replace_deleted"`
+	}
+
+	// GetIndexRequest names the index to fetch.
+	GetIndexRequest struct {
+		Name string `json:"name"`
+	}
+
+	// DeleteIndexRequest names the index to delete.
+	DeleteIndexRequest struct {
+		Name string `json:"name"`
+	}
+
+	// SaveIndexRequest names the index to save.
+	SaveIndexRequest struct {
+		Name string `json:"name"`
+	}
+
+	// LoadIndexRequest names the index to load.
+	LoadIndexRequest struct {
+		Name string `json:"name"`
+	}
+
+	// AddPointRequest adds one point to the named index.
+	AddPointRequest struct {
+		IndexName string `json:"index_name"`
+		Point     Point  `json:"point"`
+	}
+
+	// AddPointsRequest adds several points to the named index.
+	AddPointsRequest struct {
+		IndexName string  `json:"index_name"`
+		Points    []Point `json:"points"`
+	}
+
+	// BatchSearchRequest holds several query vectors and the shared k.
+	// The Queries field is serialized under the JSON key "vectors".
+	BatchSearchRequest struct {
+		Queries [][]float32 `json:"vectors"`
+		K       int         `json:"k"`
+	}
+
+	// BatchSearchResponse holds one result list per query.
+	BatchSearchResponse struct {
+		Results [][]SearchResult `json:"results"`
+	}
+
+	// UpdatePointRequest updates one point in the named index.
+	UpdatePointRequest struct {
+		IndexName string `json:"index_name"`
+		Point     Point  `json:"point"`
+	}
+
+	// UpdatePointsRequest updates several points in the named index.
+	UpdatePointsRequest struct {
+		IndexName string  `json:"index_name"`
+		Points    []Point `json:"points"`
+	}
+)
+
+// DeletePointRequest represents a request to delete a single point
+type DeletePointRequest struct {
+	IndexName string `json:"index_name"`
+	ID        uint64 `json:"id"`
+}
+
+// DeletePointsRequest represents a request to delete multiple points
+type DeletePointsRequest struct {
+	IndexName string   `json:"index_name"`
+	IDs       []uint64 `json:"ids"`
+}
diff --git a/eigenix/internal/qdrant/database.go b/eigenix/internal/qdrant/database.go
new file mode 100644
index 00000000..11dc6277
--- /dev/null
+++ b/eigenix/internal/qdrant/database.go
@@ -0,0 +1,433 @@
+package qdrant
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log"
+	"math"
+	"os"
+	"sort"
+	"sync"
+
+	"github.com/qdrant/go-client/qdrant"
+	"github.com/spf13/viper"
+)
+
+const (
+	MaxWorkers = 32
+)
+
+// Database implements the Manager interface for Qdrant operations
+type Database struct {
+	client       *qdrant.Client
+	centers      [][]float32 // Centroids stored as float32
+	centersMutex sync.RWMutex
+}
+
+var (
+	database Manager
+	once     sync.Once
+)
+
+// Init initializes and returns the singleton database instance.
+//
+// The client is built from the QDRANT_HOST and QDRANT_PORT settings; the
+// process exits via log.Fatalf if it cannot be created. The instance is only
+// read after once.Do has returned, so concurrent first calls are safe.
+func Init() Manager {
+	once.Do(func() {
+		client, err := qdrant.NewClient(&qdrant.Config{
+			Host:   viper.GetString("QDRANT_HOST"),
+			Port:   viper.GetInt("QDRANT_PORT"),
+			UseTLS: false,
+		})
+		if err != nil {
+			log.Fatalf("failed to create Qdrant client: %v", err)
+		}
+		database = &Database{
+			client: client,
+		}
+	})
+	return database
+}
+
+// Close closes the Qdrant client connection and clears the reference.
+// Calling it again after the client has been released is a no-op.
+func (db *Database) Close() error {
+	if db.client == nil {
+		return nil
+	}
+	// Dropping the reference alone would leave the underlying connection
+	// open; close it explicitly first.
+	err := db.client.Close()
+	db.client = nil
+	return err
+}
+
+// LoadCentroids loads K-means centroids from a JSON file.
+//
+// The file must contain a non-empty JSON array of equally sized number
+// arrays. Each centroid is converted to float32 and passed through
+// l2Normalize. The stored centroids are replaced only after the whole file
+// has been validated, so a failed load leaves the previous set untouched.
+func (db *Database) LoadCentroids(filepath string) error {
+	log.Printf("[qdrant] Loading centroids from %s...", filepath)
+
+	// Load file
+	data, err := os.ReadFile(filepath)
+	if err != nil {
+		return fmt.Errorf("failed to read centroids file: %w", err)
+	}
+
+	// JSON numbers decode as float64
+	var centersFloat64 [][]float64
+	if err := json.Unmarshal(data, &centersFloat64); err != nil {
+		return fmt.Errorf("failed to unmarshal centroids: %w", err)
+	}
+
+	if len(centersFloat64) == 0 {
+		return fmt.Errorf("no centroids found in file")
+	}
+
+	rows := len(centersFloat64)
+	cols := len(centersFloat64[0])
+	if cols == 0 {
+		return fmt.Errorf("centroids have zero dimensions")
+	}
+
+	// Convert to float32 and normalize into a local slice first
+	centers := make([][]float32, rows)
+	for i, center := range centersFloat64 {
+		if len(center) != cols {
+			return fmt.Errorf("inconsistent centroid dimensions at index %d", i)
+		}
+
+		centerFloat32 := make([]float32, cols)
+		for j, v := range center {
+			centerFloat32[j] = float32(v)
+		}
+
+		// L2 normalize the center
+		centers[i] = l2Normalize(centerFloat32)
+	}
+
+	db.centersMutex.Lock()
+	db.centers = centers
+	db.centersMutex.Unlock()
+
+	log.Printf("[qdrant] Loaded %d centroids with %d dimensions (normalized)", rows, cols)
+
+	return nil
+}
+
+// GetTopCentroids finds the top N closest centroids for a given embedding.
+//
+// The embedding is normalized with l2Normalize and scored against every
+// stored centroid by dot product; indices are returned in descending score
+// order. At most len(centroids) indices are returned, so asking for more
+// than exist never yields padding entries. It returns an error if centroids
+// are not loaded, topN is negative, or the embedding length differs from a
+// centroid's length.
+func (db *Database) GetTopCentroids(embedding []float32, topN int) ([]int, error) {
+	db.centersMutex.RLock()
+	defer db.centersMutex.RUnlock()
+
+	if db.centers == nil {
+		return nil, fmt.Errorf("centroids not loaded")
+	}
+	if topN < 0 {
+		return nil, fmt.Errorf("topN must be non-negative, got %d", topN)
+	}
+
+	// Normalize query vector
+	q := l2Normalize(embedding)
+	rows := len(db.centers)
+
+	type scored struct {
+		idx int
+		sim float32
+	}
+	sims := make([]scored, rows)
+
+	for i, center := range db.centers {
+		// A longer query would index past the centroid; a shorter one would
+		// silently score on a prefix only.
+		if len(center) != len(q) {
+			return nil, fmt.Errorf("embedding has %d dimensions, centroid %d has %d", len(q), i, len(center))
+		}
+		sim := float32(0.0)
+		for j, v := range q {
+			sim += center[j] * v
+		}
+		sims[i] = scored{idx: i, sim: sim}
+	}
+
+	// Sort by similarity descending
+	sort.Slice(sims, func(i, j int) bool {
+		return sims[i].sim > sims[j].sim
+	})
+
+	// Never report more entries than there are centroids
+	if topN > rows {
+		topN = rows
+	}
+	result := make([]int, topN)
+	for i := range result {
+		result[i] = sims[i].idx
+	}
+
+	return result, nil
+}
+
+// Search performs intelligent
routing based on SearchParams +// - If UseIVF is true: routes to multiple centroid-based collections (_c0001, _c0002, etc.) +// - If UseIVF is false: queries single collection with index_name +func (db *Database) Search(ctx context.Context, req SearchRequest) ([]SearchResponse, error) { + if db.client == nil { + return nil, fmt.Errorf("client not initialized") + } + + if len(req.Vector) == 0 { + return nil, fmt.Errorf("empty query vector") + } + + // Route based on IVF configuration + if req.Params.UseIVF { + // IVF mode: route to multiple centroid-based collections + return db.searchIVF(ctx, req, req.Vector) + } + + // Single collection mode: query the base collection directly + return db.searchSingle(ctx, req, req.Vector) +} + +// BatchSearch performs batch search across multiple query vectors in parallel +func (db *Database) BatchSearch(ctx context.Context, req BatchSearchRequest) ([][]SearchResponse, error) { + fmt.Println("BatchSearch", req) + if db.client == nil { + return nil, fmt.Errorf("client not initialized") + } + + if len(req.Vectors) == 0 { + return nil, fmt.Errorf("empty query vectors") + } + + numVectors := len(req.Vectors) + results := make([][]SearchResponse, numVectors) + + // Use worker pool for parallel processing + type job struct { + idx int + vec []float32 + } + + type result struct { + idx int + results []SearchResponse + err error + } + + // Create job and result channels + jobs := make(chan job, numVectors) + resultsChan := make(chan result, numVectors) + + // Determine number of workers (use MaxWorkers or number of CPUs) + numWorkers := MaxWorkers + if numVectors < numWorkers { + numWorkers = numVectors + } + + // Start worker pool + var wg sync.WaitGroup + for w := 0; w < numWorkers; w++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := range jobs { + // Create a single search request + searchReq := SearchRequest{ + IndexName: req.IndexName, + Vector: j.vec, + Limit: req.Limit, + Params: req.Params, + } + + searchResults, err := 
db.Search(ctx, searchReq) + if err != nil { + resultsChan <- result{ + idx: j.idx, + err: fmt.Errorf("search failed for vector %d: %w", j.idx, err), + } + continue + } + + resultsChan <- result{ + idx: j.idx, + results: searchResults, + err: nil, + } + } + }() + } + + // Send jobs to workers + go func() { + for i, vec := range req.Vectors { + if vec == nil || len(vec.Values) == 0 { + resultsChan <- result{ + idx: i, + err: fmt.Errorf("empty query vector at index %d", i), + } + continue + } + + jobs <- job{ + idx: i, + vec: vec.Values, + } + } + close(jobs) + }() + + // Wait for all workers to complete + go func() { + wg.Wait() + close(resultsChan) + }() + + // Collect results + for res := range resultsChan { + if res.err != nil { + return nil, res.err + } + results[res.idx] = res.results + } + + return results, nil +} + +// searchSingle performs search on a single collection +func (db *Database) searchSingle(ctx context.Context, req SearchRequest, queryVec []float32) ([]SearchResponse, error) { + results, err := db.queryCollection(ctx, req.IndexName, queryVec, req.Limit) + if err != nil { + return nil, fmt.Errorf("search failed for collection %s: %w", req.IndexName, err) + } + + return results, nil +} + +// searchIVF performs IVF-routed search across multiple centroid-based collections +func (db *Database) searchIVF(ctx context.Context, req SearchRequest, queryVec []float32) ([]SearchResponse, error) { + // Get top centroids for routing + topCentroids, err := db.GetTopCentroids(req.Vector, req.Params.CentroidCount) + if err != nil { + return nil, fmt.Errorf("failed to get top centroids: %w", err) + } + + // Build collection names from centroids (format: indexname_c0001, indexname_c0002, etc.) 
+ collections := make([]string, len(topCentroids)) + for i, cid := range topCentroids { + collections[i] = fmt.Sprintf("%s_c%04d", req.IndexName, cid+1) + } + + // Search across multiple collections concurrently + allResults := db.searchMultipleCollections(ctx, collections, queryVec, req.Limit) + + // Dedupe and merge results + merged := db.dedupeAndMerge(allResults, req.Limit) + + return merged, nil +} + +// queryCollection performs a search on a single Qdrant collection +func (db *Database) queryCollection(ctx context.Context, collection string, queryVec []float32, limit uint64) ([]SearchResponse, error) { + var results []SearchResponse + + queryPoints, err := db.client.Query(ctx, &qdrant.QueryPoints{ + CollectionName: collection, + Query: &qdrant.Query{ + Variant: &qdrant.Query_Nearest{ + Nearest: &qdrant.VectorInput{ + Variant: &qdrant.VectorInput_Dense{ + Dense: &qdrant.DenseVector{ + Data: queryVec, + }, + }, + }, + }, + }, + Limit: &limit, + WithPayload: &qdrant.WithPayloadSelector{ + SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: false}, + }, + WithVectors: &qdrant.WithVectorsSelector{ + SelectorOptions: &qdrant.WithVectorsSelector_Enable{Enable: false}, + }, + }) + if err != nil { + return nil, err + } + + results = make([]SearchResponse, 0, len(queryPoints)) + for _, point := range queryPoints { + results = append(results, SearchResponse{ + ID: point.Id.GetNum(), + Score: point.Score, + }) + } + + return results, nil +} + +// searchMultipleCollections searches across multiple collections concurrently +func (db *Database) searchMultipleCollections(ctx context.Context, collections []string, queryVec []float32, perCollectionLimit uint64) []SearchResponse { + var wg sync.WaitGroup + resultsChan := make(chan []SearchResponse, len(collections)) + + // Limit concurrency + sem := make(chan struct{}, MaxWorkers) + + for _, collection := range collections { + wg.Add(1) + go func(coll string) { + defer wg.Done() + sem <- struct{}{} + defer func() { 
<-sem }() + + results, err := db.queryCollection(ctx, coll, queryVec, perCollectionLimit) + if err != nil { + log.Printf("[warn] search failed for collection %s: %v", coll, err) + return + } + resultsChan <- results + }(collection) + } + + wg.Wait() + close(resultsChan) + + // Merge results from all collections + allResults := []SearchResponse{} + for results := range resultsChan { + allResults = append(allResults, results...) + } + + return allResults +} + +// dedupeAndMerge deduplicates results by ID, keeping the best score, and limits to finalLimit +func (db *Database) dedupeAndMerge(results []SearchResponse, finalLimit uint64) []SearchResponse { + // Dedupe by ID, keep best score + byID := make(map[uint64]SearchResponse) + for _, result := range results { + if existing, ok := byID[result.ID]; !ok || result.Score > existing.Score { + byID[result.ID] = result + } + } + + // Convert to slice + merged := make([]SearchResponse, 0, len(byID)) + for _, result := range byID { + merged = append(merged, result) + } + + // Sort by score descending + sort.Slice(merged, func(i, j int) bool { + return merged[i].Score > merged[j].Score + }) + + // Limit to finalLimit + if len(merged) > int(finalLimit) { + merged = merged[:finalLimit] + } + + return merged +} + +// ==================== HELPER FUNCTIONS ==================== + +// l2Normalize normalizes a vector using L2 norm (float32 version) +func l2Normalize(vec []float32) []float32 { + norm := float32(0.0) + for _, v := range vec { + norm += v * v + } + norm = float32(math.Sqrt(float64(norm))) + 1e-12 + + result := make([]float32, len(vec)) + for i, v := range vec { + result[i] = v / norm + } + return result +} diff --git a/eigenix/internal/qdrant/manager.go b/eigenix/internal/qdrant/manager.go new file mode 100644 index 00000000..2842f3f6 --- /dev/null +++ b/eigenix/internal/qdrant/manager.go @@ -0,0 +1,13 @@ +package qdrant + +import ( + "context" +) + +// Manager defines the interface for Qdrant database operations +type 
Manager interface { + // Client initialization + LoadCentroids(filepath string) error + Search(ctx context.Context, req SearchRequest) ([]SearchResponse, error) + BatchSearch(ctx context.Context, req BatchSearchRequest) ([][]SearchResponse, error) +} diff --git a/eigenix/internal/qdrant/model.go b/eigenix/internal/qdrant/model.go new file mode 100644 index 00000000..4ad1d83f --- /dev/null +++ b/eigenix/internal/qdrant/model.go @@ -0,0 +1,44 @@ +package qdrant + +import pb "github.com/Meesho/skye-eigenix/internal/client" + +// QdrantConfig holds configuration for Qdrant client initialization +type QdrantConfig struct { + Host string `json:"host"` + Port int `json:"port"` +} + +// SearchParams defines search behavior parameters from proto +type SearchParams struct { + UseIVF bool `json:"use_ivf"` // If true, use IVF routing to multiple collections + CentroidCount int `json:"centroid_count"` // Number of centroids to search (only for IVF mode) + SearchIndexedOnly bool `json:"search_indexed_only"` // Only search indexed vectors +} + +// SearchRequest represents a search request with optional IVF routing +type SearchRequest struct { + IndexName string `json:"index_name"` // Base collection/index name + Vector []float32 `json:"vector"` // Query vector + Limit uint64 `json:"limit"` // Number of results to return + Params SearchParams `json:"params"` // Search parameters (IVF, exact search, etc.) +} + +type BatchSearchRequest struct { + IndexName string `json:"index_name"` // Base collection/index name + Vectors []*pb.Vector `json:"vectors"` // Query vector + Limit uint64 `json:"limit"` // Number of results to return + Params SearchParams `json:"params"` // Search parameters (IVF, exact search, etc.) 
+} + +// SearchResponse represents the response from a search operation +type SearchResponse struct { + ID uint64 `json:"id"` + Score float32 `json:"score"` +} + +// CentroidInfo stores information about K-means centroids for IVF +type CentroidInfo struct { + Centers [][]float32 `json:"centers"` // Normalized centroid vectors + TotalCount int `json:"total_count"` // Total number of centroids + Dimension int `json:"dimension"` // Dimension of vectors +} diff --git a/eigenix/internal/qdrant/utils.go b/eigenix/internal/qdrant/utils.go new file mode 100644 index 00000000..d818436b --- /dev/null +++ b/eigenix/internal/qdrant/utils.go @@ -0,0 +1,87 @@ +package qdrant + +import ( + "encoding/json" + "fmt" + "math" + "os" + "sort" +) + +// LoadEmbeddingsFromFile loads embeddings from a JSON file +// Supports both wrapped format {"embeddings": [...]} and direct array format +func LoadEmbeddingsFromFile(filepath string) ([][]float64, error) { + data, err := os.ReadFile(filepath) + if err != nil { + return nil, fmt.Errorf("failed to read embeddings file: %w", err) + } + + // Try to unmarshal as wrapped format first + var wrappedEmbeddings struct { + Embeddings [][]float64 `json:"embeddings"` + } + if err := json.Unmarshal(data, &wrappedEmbeddings); err == nil && len(wrappedEmbeddings.Embeddings) > 0 { + return wrappedEmbeddings.Embeddings, nil + } + + // Fallback to direct array format + var embeddings [][]float64 + if err := json.Unmarshal(data, &embeddings); err != nil { + return nil, fmt.Errorf("failed to unmarshal embeddings: %w", err) + } + + return embeddings, nil +} + +// CalculateStatistics computes mean, std dev, min, max, and median of a float64 slice +func CalculateStatistics(values []float64) (mean, std, min, max, median float64) { + if len(values) == 0 { + return 0, 0, 0, 0, 0 + } + + // Mean + sum := 0.0 + for _, v := range values { + sum += v + } + mean = sum / float64(len(values)) + + // Std dev + sumSq := 0.0 + for _, v := range values { + diff := v - mean + 
sumSq += diff * diff + } + std = math.Sqrt(sumSq / float64(len(values))) + + // Min/Max + min = values[0] + max = values[0] + for _, v := range values { + if v < min { + min = v + } + if v > max { + max = v + } + } + + // Median + sorted := make([]float64, len(values)) + copy(sorted, values) + sort.Float64s(sorted) + + midIdx := len(sorted) / 2 + if len(sorted)%2 == 0 { + median = (sorted[midIdx-1] + sorted[midIdx]) / 2 + } else { + median = sorted[midIdx] + } + + return +} + +// CollectionNameFromCentroid generates a collection name from a centroid ID +func CollectionNameFromCentroid(prefix string, centroidID int) string { + return fmt.Sprintf("%s_c%04d", prefix, centroidID+1) +} diff --git a/eigenix/internal/serving/serving.go b/eigenix/internal/serving/serving.go new file mode 100644 index 00000000..05043a96 --- /dev/null +++ b/eigenix/internal/serving/serving.go @@ -0,0 +1,251 @@ +package serving + +import ( + "context" + "sync" + + pb "github.com/Meesho/skye-eigenix/internal/client" + "github.com/Meesho/skye-eigenix/internal/hnswlib" + "github.com/Meesho/skye-eigenix/internal/qdrant" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type HandlerV1 struct { + pb.EigenixServiceServer + hnswlib hnswlib.Manager + qdrant qdrant.Manager +} + +var ( + handlerV1 HandlerV1 + once sync.Once +) + +func Init() pb.EigenixServiceServer { + once.Do(func() { + handlerV1 = HandlerV1{ + hnswlib: hnswlib.Init(), + qdrant: qdrant.Init(), + } + }) + return &handlerV1 +} + +func (h *HandlerV1) CreateIndex(ctx context.Context, req *pb.CreateIndexRequest) (*pb.CreateIndexResponse, error) { + if err := ValidateCreateIndexRequest(req); err != nil { + return nil, status.Errorf(codes.InvalidArgument, err.Error()) + } + + if err := h.hnswlib.CreateIndex(req.Name, req.Space, req.Dimension, req.MaxElements, req.M, req.EfConstruction, req.AllowReplaceDeleted); err != nil { + return nil, err + } + return &pb.CreateIndexResponse{Success: true, Message: "Index created 
successfully"}, nil +} + +func (h *HandlerV1) GetIndex(ctx context.Context, req *pb.GetIndexRequest) (*pb.GetIndexResponse, error) { + if req.Name == "" { + return nil, status.Errorf(codes.InvalidArgument, "index name is required") + } + + index, exists := h.hnswlib.GetIndex(req.Name) + if !exists { + return nil, status.Errorf(codes.NotFound, "index not found") + } + count, err := h.hnswlib.GetIndexStats(index.Name) + if err != nil { + return nil, err + } + + return &pb.GetIndexResponse{Index: &pb.IndexInfo{ + Name: index.Name, + Dimension: index.Dimension, + Space: index.Space, + MaxElements: index.MaxElements, + M: index.M, + EfConstruction: index.EfConstruct, + AllowReplaceDeleted: index.AllowReplaceDeleted, + CurrentCount: int32(count), + }}, nil +} + +func (h *HandlerV1) DeleteIndex(ctx context.Context, req *pb.DeleteIndexRequest) (*pb.DeleteIndexResponse, error) { + if req.Name == "" { + return nil, status.Errorf(codes.InvalidArgument, "index name is required") + } + + if err := h.hnswlib.DeleteIndex(req.Name); err != nil { + return nil, err + } + return &pb.DeleteIndexResponse{Success: true, Message: "Index deleted successfully"}, nil +} + +func (h *HandlerV1) ListIndices(ctx context.Context, req *pb.ListIndicesRequest) (*pb.ListIndicesResponse, error) { + indices := h.hnswlib.ListIndices() + indicesInfo := make([]*pb.IndexInfo, len(indices)) + for i, index := range indices { + count, err := h.hnswlib.GetIndexStats(index.Name) + if err != nil { + return nil, err + } + indicesInfo[i] = &pb.IndexInfo{ + Name: index.Name, + Dimension: index.Dimension, + Space: index.Space, + MaxElements: index.MaxElements, + M: index.M, + EfConstruction: index.EfConstruct, + AllowReplaceDeleted: index.AllowReplaceDeleted, + CurrentCount: int32(count), + } + } + return &pb.ListIndicesResponse{Indices: indicesInfo}, nil +} + +func (h *HandlerV1) SaveIndex(ctx context.Context, req *pb.IndexOperationRequest) (*pb.IndexOperationResponse, error) { + if req.Name == "" { + return nil, 
status.Errorf(codes.InvalidArgument, "index name is required") + } + + if err := h.hnswlib.SaveIndex(req.Name); err != nil { + return nil, err + } + return &pb.IndexOperationResponse{Success: true, Message: "Index saved successfully"}, nil +} + +func (h *HandlerV1) LoadIndex(ctx context.Context, req *pb.IndexOperationRequest) (*pb.IndexOperationResponse, error) { + if req.Name == "" { + return nil, status.Errorf(codes.InvalidArgument, "index name is required") + } + + if err := h.hnswlib.LoadIndex(req.Name); err != nil { + return nil, err + } + return &pb.IndexOperationResponse{Success: true, Message: "Index loaded successfully"}, nil +} + +func (h *HandlerV1) AddPoints(ctx context.Context, req *pb.PointsRequest) (*pb.PointsResponse, error) { + if req.IndexName == "" { + return nil, status.Errorf(codes.InvalidArgument, "index name is required") + } + if len(req.Points) == 0 { + return nil, status.Errorf(codes.InvalidArgument, "points are required") + } + count, err := h.hnswlib.AddPoints(req.IndexName, req.Points) + if err != nil { + return nil, err + } + return &pb.PointsResponse{Success: true, AffectedCount: int32(count), TotalCount: int32(len(req.Points))}, nil +} + +func (h *HandlerV1) UpdatePoints(ctx context.Context, req *pb.PointsRequest) (*pb.PointsResponse, error) { + if req.IndexName == "" { + return nil, status.Errorf(codes.InvalidArgument, "index name is required") + } + if len(req.Points) == 0 { + return nil, status.Errorf(codes.InvalidArgument, "points are required") + } + count, err := h.hnswlib.UpdatePoints(req.IndexName, req.Points) + if err != nil { + return nil, err + } + return &pb.PointsResponse{Success: true, AffectedCount: int32(count), TotalCount: int32(len(req.Points))}, nil +} + +func (h *HandlerV1) DeletePoints(ctx context.Context, req *pb.DeletePointsRequest) (*pb.DeletePointsResponse, error) { + if req.IndexName == "" { + return nil, status.Errorf(codes.InvalidArgument, "index name is required") + } + if len(req.Ids) == 0 { + return 
nil, status.Errorf(codes.InvalidArgument, "ids are required") + } + count, err := h.hnswlib.DeletePoints(req.IndexName, req.Ids) + if err != nil { + return nil, err + } + return &pb.DeletePointsResponse{Success: true, DeletedCount: int32(count)}, nil +} + +func (h *HandlerV1) Search(ctx context.Context, req *pb.SearchRequest) (*pb.SearchResponse, error) { + if err := ValidateSearchRequest(req); err != nil { + return nil, status.Errorf(codes.InvalidArgument, err.Error()) + } + switch req.DatabaseType { + case pb.DatabaseType_HNSW_LIB: + results, err := h.hnswlib.Search(req.IndexName, req.Vector, int(req.Limit)) + if err != nil { + return nil, err + } + resultsPb := make([]*pb.SearchResult, len(results)) + for i, result := range results { + resultsPb[i] = &pb.SearchResult{Id: result.ID, Distance: result.Distance} + } + return &pb.SearchResponse{Results: resultsPb}, nil + case pb.DatabaseType_QDRANT: + results, err := h.qdrant.Search(ctx, qdrant.SearchRequest{ + IndexName: req.IndexName, + Vector: req.Vector, + Limit: uint64(req.Limit), + Params: qdrant.SearchParams{ + UseIVF: req.SearchParams.UseIvf, + CentroidCount: int(req.SearchParams.CentCount), + SearchIndexedOnly: req.SearchParams.SearchIndexedOnly, + }, + }) + if err != nil { + return nil, err + } + resultsPb := make([]*pb.SearchResult, len(results)) + for i, result := range results { + resultsPb[i] = &pb.SearchResult{Id: result.ID, Distance: result.Score} + } + return &pb.SearchResponse{Results: resultsPb}, nil + default: + return nil, status.Errorf(codes.InvalidArgument, "invalid database type") + } +} + +func (h *HandlerV1) BatchSearch(ctx context.Context, req *pb.BatchSearchRequest) (*pb.BatchSearchResponse, error) { + switch req.DatabaseType { + case pb.DatabaseType_HNSW_LIB: + results, err := h.hnswlib.BatchSearch(req.IndexName, req.Vectors, int(req.Limit)) + if err != nil { + return nil, err + } + resultSets := make([]*pb.SearchResultSet, len(results)) + for i, set := range results { + sr := 
make([]*pb.SearchResult, len(set)) + for j, r := range set { + sr[j] = &pb.SearchResult{Id: r.ID, Distance: r.Distance} + } + resultSets[i] = &pb.SearchResultSet{Results: sr} + } + return &pb.BatchSearchResponse{ResultSets: resultSets}, nil + case pb.DatabaseType_QDRANT: + results, err := h.qdrant.BatchSearch(ctx, qdrant.BatchSearchRequest{ + IndexName: req.IndexName, + Vectors: req.Vectors, + Limit: uint64(req.Limit), + Params: qdrant.SearchParams{ + UseIVF: req.SearchParams.UseIvf, + CentroidCount: int(req.SearchParams.CentCount), + SearchIndexedOnly: req.SearchParams.SearchIndexedOnly, + }, + }) + if err != nil { + return nil, err + } + resultSets := make([]*pb.SearchResultSet, len(results)) + for i, set := range results { + sr := make([]*pb.SearchResult, len(set)) + for j, r := range set { + sr[j] = &pb.SearchResult{Id: r.ID, Distance: r.Score} + } + resultSets[i] = &pb.SearchResultSet{Results: sr} + } + return &pb.BatchSearchResponse{ResultSets: resultSets}, nil + default: + return nil, status.Errorf(codes.InvalidArgument, "invalid database type") + } +} diff --git a/eigenix/internal/serving/validator.go b/eigenix/internal/serving/validator.go new file mode 100644 index 00000000..eb002d4d --- /dev/null +++ b/eigenix/internal/serving/validator.go @@ -0,0 +1,77 @@ +package serving + +import ( + "errors" + + "github.com/Meesho/skye-eigenix/internal/client" +) + +func ValidateCreateIndexRequest(req *client.CreateIndexRequest) error { + if req.Name == "" { + return errors.New("name is required") + } + if req.Dimension == 0 { + return errors.New("dimension is required") + } + if req.Space == "" { + return errors.New("space is required") + } + if req.MaxElements == 0 { + return errors.New("max_elements is required") + } + if req.M == 0 { + return errors.New("m is required") + } + if req.EfConstruction == 0 { + return errors.New("ef_construction is required") + } + return nil +} + +func ValidateSearchRequest(req *client.SearchRequest) error { + if req.IndexName == "" 
{ + return errors.New("index name is required") + } + if req.Limit <= 0 { + return errors.New("limit must be greater than 0") + } + if len(req.Vector) == 0 { + return errors.New("vector is required") + } + + // Validate SearchParams if provided + if req.SearchParams != nil { + if req.SearchParams.UseIvf { + // If IVF is enabled, centroid count must be specified + if req.SearchParams.CentCount <= 0 { + return errors.New("cent_count must be greater than 0 when use_ivf is true") + } + } + } + + return nil +} + +func ValidateBatchSearchRequest(req *client.BatchSearchRequest) error { + if req.IndexName == "" { + return errors.New("index name is required") + } + if req.Limit <= 0 { + return errors.New("limit must be greater than 0") + } + if len(req.Vectors) == 0 { + return errors.New("vectors are required") + } + + // Validate SearchParams if provided + if req.SearchParams != nil { + if req.SearchParams.UseIvf { + // If IVF is enabled, centroid count must be specified + if req.SearchParams.CentCount <= 0 { + return errors.New("cent_count must be greater than 0 when use_ivf is true") + } + } + } + + return nil +} diff --git a/eigenix/pkg/hnswlib/bruteforce.h b/eigenix/pkg/hnswlib/bruteforce.h new file mode 100644 index 00000000..8727cc8a --- /dev/null +++ b/eigenix/pkg/hnswlib/bruteforce.h @@ -0,0 +1,173 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace hnswlib { +template +class BruteforceSearch : public AlgorithmInterface { + public: + char *data_; + size_t maxelements_; + size_t cur_element_count; + size_t size_per_element_; + + size_t data_size_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::mutex index_lock; + + std::unordered_map dict_external_to_internal; + + + BruteforceSearch(SpaceInterface *s) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + } + + + BruteforceSearch(SpaceInterface *s, const std::string &location) + : data_(nullptr), + 
maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + loadIndex(location, s); + } + + + BruteforceSearch(SpaceInterface *s, size_t maxElements) { + maxelements_ = maxElements; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxElements * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + cur_element_count = 0; + } + + + ~BruteforceSearch() { + free(data_); + } + + + void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { + int idx; + { + std::unique_lock lock(index_lock); + + auto search = dict_external_to_internal.find(label); + if (search != dict_external_to_internal.end()) { + idx = search->second; + } else { + if (cur_element_count >= maxelements_) { + throw std::runtime_error("The number of elements exceeds the specified limit\n"); + } + idx = cur_element_count; + dict_external_to_internal[label] = idx; + cur_element_count++; + } + } + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + } + + + void removePoint(labeltype cur_external) { + std::unique_lock lock(index_lock); + + auto found = dict_external_to_internal.find(cur_external); + if (found == dict_external_to_internal.end()) { + return; + } + + dict_external_to_internal.erase(found); + + size_t cur_c = found->second; + labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + dict_external_to_internal[label] = cur_c; + memcpy(data_ + size_per_element_ * cur_c, + data_ + size_per_element_ * (cur_element_count-1), + data_size_+sizeof(labeltype)); + cur_element_count--; + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, 
BaseFilterFunctor* isIdAllowed = nullptr) const { + assert(k <= cur_element_count); + std::priority_queue> topResults; + if (cur_element_count == 0) return topResults; + for (int i = 0; i < k; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + if ((!isIdAllowed) || (*isIdAllowed)(label)) { + topResults.emplace(dist, label); + } + } + dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; + for (int i = k; i < cur_element_count; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + if (dist <= lastdist) { + labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + if ((!isIdAllowed) || (*isIdAllowed)(label)) { + topResults.emplace(dist, label); + } + if (topResults.size() > k) + topResults.pop(); + + if (!topResults.empty()) { + lastdist = topResults.top().first; + } + } + } + return topResults; + } + + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, maxelements_); + writeBinaryPOD(output, size_per_element_); + writeBinaryPOD(output, cur_element_count); + + output.write(data_, maxelements_ * size_per_element_); + + output.close(); + } + + + void loadIndex(const std::string &location, SpaceInterface *s) { + std::ifstream input(location, std::ios::binary); + std::streampos position; + + readBinaryPOD(input, maxelements_); + readBinaryPOD(input, size_per_element_); + readBinaryPOD(input, cur_element_count); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxelements_ * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to 
allocate data"); + + input.read(data_, maxelements_ * size_per_element_); + + input.close(); + } +}; +} // namespace hnswlib diff --git a/eigenix/pkg/hnswlib/hnsw_wrapper.cc b/eigenix/pkg/hnswlib/hnsw_wrapper.cc new file mode 100644 index 00000000..b3e3dd1e --- /dev/null +++ b/eigenix/pkg/hnswlib/hnsw_wrapper.cc @@ -0,0 +1,316 @@ +#include "hnsw_wrapper.h" +#include "hnswlib.h" +#include +#include + +struct HNSWIndexImpl { + hnswlib::SpaceInterface* space; + hnswlib::HierarchicalNSW* alg_hnsw; + std::string space_name; + int dimension; + + HNSWIndexImpl(const std::string& space_name, int dim) + : space_name(space_name), dimension(dim), alg_hnsw(nullptr) { + + if (space_name == "l2") { + space = new hnswlib::L2Space(dim); + } else if (space_name == "ip") { + space = new hnswlib::InnerProductSpace(dim); + } else if (space_name == "cosine") { + space = new hnswlib::InnerProductSpace(dim); + } else { + throw std::runtime_error("Unknown space name"); + } + } + + ~HNSWIndexImpl() { + if (alg_hnsw) delete alg_hnsw; + if (space) delete space; + } +}; + +extern "C" { + +HNSWIndex hnsw_create_index(const char* space_name, int dim, int max_elements, int M, int ef_construction, int random_seed, int allow_replace_deleted) { + try { + auto* impl = new HNSWIndexImpl(space_name, dim); + bool allow_replace = (allow_replace_deleted != 0); + impl->alg_hnsw = new hnswlib::HierarchicalNSW(impl->space, max_elements, M, ef_construction, random_seed, allow_replace); + return static_cast(impl); + } catch (...) 
{ + return nullptr; + } +} + +void hnsw_delete_index(HNSWIndex index) { + if (index) { + delete static_cast(index); + } +} + +int hnsw_add_point(HNSWIndex index, const float* data, unsigned long long label) { + if (!index || !data) return -1; + + try { + auto* impl = static_cast(index); + if (!impl->alg_hnsw) return -1; + + // For cosine similarity, we need to normalize the vector + if (impl->space_name == "cosine") { + std::vector normalized_data(impl->dimension); + float norm = 0.0f; + for (int i = 0; i < impl->dimension; i++) { + norm += data[i] * data[i]; + } + norm = std::sqrt(norm); + + if (norm > 0) { + for (int i = 0; i < impl->dimension; i++) { + normalized_data[i] = data[i] / norm; + } + impl->alg_hnsw->addPoint(normalized_data.data(), label); + } else { + return -1; // Cannot normalize zero vector + } + } else { + impl->alg_hnsw->addPoint(data, label); + } + return 0; + } catch (...) { + return -1; + } +} + +int hnsw_search_knn(HNSWIndex index, const float* query_data, int k, unsigned long long* labels, float* distances) { + if (!index || !query_data || !labels || !distances) return -1; + + try { + auto* impl = static_cast(index); + if (!impl->alg_hnsw) return -1; + + std::vector query_vector(impl->dimension); + + // For cosine similarity, normalize the query vector + if (impl->space_name == "cosine") { + float norm = 0.0f; + for (int i = 0; i < impl->dimension; i++) { + norm += query_data[i] * query_data[i]; + } + norm = std::sqrt(norm); + + if (norm > 0) { + for (int i = 0; i < impl->dimension; i++) { + query_vector[i] = query_data[i] / norm; + } + } else { + return -1; // Cannot normalize zero vector + } + } else { + for (int i = 0; i < impl->dimension; i++) { + query_vector[i] = query_data[i]; + } + } + + auto result = impl->alg_hnsw->searchKnn(query_vector.data(), k); + + int count = 0; + while (!result.empty() && count < k) { + distances[k - 1 - count] = result.top().first; + labels[k - 1 - count] = result.top().second; + result.pop(); + count++; + 
} + + return count; + } catch (...) { + return -1; + } +} + +void hnsw_set_ef(HNSWIndex index, int ef) { + if (!index) return; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + impl->alg_hnsw->ef_ = ef; + } +} + +int hnsw_get_current_count(HNSWIndex index) { + if (!index) return -1; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + return impl->alg_hnsw->cur_element_count; + } + return -1; +} + +int hnsw_save_index(HNSWIndex index, const char* path) { + if (!index || !path) return -1; + + try { + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + impl->alg_hnsw->saveIndex(path); + return 0; + } + return -1; + } catch (...) { + return -1; + } +} + +HNSWIndex hnsw_load_index(const char* space_name, int dim, const char* path, int max_elements) { + if (!space_name || !path) return nullptr; + + try { + auto* impl = new HNSWIndexImpl(space_name, dim); + impl->alg_hnsw = new hnswlib::HierarchicalNSW(impl->space, path, false, max_elements); + return static_cast(impl); + } catch (...) { + return nullptr; + } +} + +void hnsw_clear(HNSWIndex index) { + if (!index) return; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + impl->alg_hnsw->clear(); + } +} + +int hnsw_resize_index(HNSWIndex index, int new_max_elements) { + if (!index) return -1; + + try { + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + impl->alg_hnsw->resizeIndex(new_max_elements); + return 0; + } + return -1; + } catch (...) 
{
+        return -1;
+    }
+}
+
+// Replaces the vector stored under an existing label.
+// Returns 0 on success, -2 if the label is unknown, and -1 on invalid arguments, a
+// zero-norm cosine vector, or any exception raised by hnswlib.
+int hnsw_update_point(HNSWIndex index, const float* data, unsigned long long label) {
+    if (!index || !data) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        // Resolve label -> internal id; the lookup lock is held only for this scope.
+        hnswlib::tableint internal_id;
+        {
+            std::unique_lock<std::mutex> lock(impl->alg_hnsw->label_lookup_lock);
+            const auto found = impl->alg_hnsw->label_lookup_.find(label);
+            if (found == impl->alg_hnsw->label_lookup_.end()) {
+                return -2;  // Label not found
+            }
+            internal_id = found->second;
+        }
+
+        if (impl->space_name != "cosine") {
+            impl->alg_hnsw->updatePoint(data, internal_id, 1.0);
+            return 0;
+        }
+
+        // Cosine vectors are stored L2-normalised.
+        float norm = 0.0f;
+        for (int i = 0; i < impl->dimension; i++) {
+            norm += data[i] * data[i];
+        }
+        norm = std::sqrt(norm);
+        if (!(norm > 0)) return -1;  // a zero (or NaN) vector cannot be normalised
+
+        std::vector<float> unit(impl->dimension);
+        for (int i = 0; i < impl->dimension; i++) {
+            unit[i] = data[i] / norm;
+        }
+        impl->alg_hnsw->updatePoint(unit.data(), internal_id, 1.0);
+        return 0;
+    } catch (...) {
+        return -1;
+    }
+}
+
+// Soft-deletes the element stored under `label`.
+// Returns 0 on success, -1 for a null or uninitialised handle or when hnswlib throws.
+int hnsw_mark_deleted(HNSWIndex index, unsigned long long label) {
+    if (!index) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        impl->alg_hnsw->markDelete(label);
+        return 0;
+    } catch (...) {
+        return -1;
+    }
+}
+
+// Restores an element previously soft-deleted under `label`.
+// Returns 0 on success, -1 for a null or uninitialised handle or when hnswlib throws.
+int hnsw_unmark_deleted(HNSWIndex index, unsigned long long label) {
+    if (!index) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        impl->alg_hnsw->unmarkDelete(label);
+        return 0;
+    } catch (...) 
{ + return -1; + } +} + +int hnsw_is_marked_deleted(HNSWIndex index, unsigned long long label) { + if (!index) return -1; + + try { + auto* impl = static_cast(index); + if (!impl->alg_hnsw) return -1; + + // Find the internal ID for the label + std::unique_lock lock(impl->alg_hnsw->label_lookup_lock); + auto search = impl->alg_hnsw->label_lookup_.find(label); + if (search == impl->alg_hnsw->label_lookup_.end()) { + return -2; // Label not found + } + hnswlib::tableint internal_id = search->second; + lock.unlock(); + + return impl->alg_hnsw->isMarkedDeleted(internal_id) ? 1 : 0; + } catch (...) { + return -1; + } +} + +int hnsw_get_deleted_count(HNSWIndex index) { + if (!index) return -1; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + return impl->alg_hnsw->num_deleted_; + } + return -1; +} + +int hnsw_get_max_elements(HNSWIndex index) { + if (!index) return -1; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + return impl->alg_hnsw->max_elements_; + } + return -1; +} + +} // extern "C" diff --git a/eigenix/pkg/hnswlib/hnsw_wrapper.cpp b/eigenix/pkg/hnswlib/hnsw_wrapper.cpp new file mode 100644 index 00000000..b3e3dd1e --- /dev/null +++ b/eigenix/pkg/hnswlib/hnsw_wrapper.cpp @@ -0,0 +1,316 @@ +#include "hnsw_wrapper.h" +#include "hnswlib.h" +#include +#include + +struct HNSWIndexImpl { + hnswlib::SpaceInterface* space; + hnswlib::HierarchicalNSW* alg_hnsw; + std::string space_name; + int dimension; + + HNSWIndexImpl(const std::string& space_name, int dim) + : space_name(space_name), dimension(dim), alg_hnsw(nullptr) { + + if (space_name == "l2") { + space = new hnswlib::L2Space(dim); + } else if (space_name == "ip") { + space = new hnswlib::InnerProductSpace(dim); + } else if (space_name == "cosine") { + space = new hnswlib::InnerProductSpace(dim); + } else { + throw std::runtime_error("Unknown space name"); + } + } + + ~HNSWIndexImpl() { + if (alg_hnsw) delete alg_hnsw; + if (space) delete space; + } +}; + +extern "C" { + +HNSWIndex 
hnsw_create_index(const char* space_name, int dim, int max_elements, int M, int ef_construction, int random_seed, int allow_replace_deleted) { + try { + auto* impl = new HNSWIndexImpl(space_name, dim); + bool allow_replace = (allow_replace_deleted != 0); + impl->alg_hnsw = new hnswlib::HierarchicalNSW(impl->space, max_elements, M, ef_construction, random_seed, allow_replace); + return static_cast(impl); + } catch (...) { + return nullptr; + } +} + +void hnsw_delete_index(HNSWIndex index) { + if (index) { + delete static_cast(index); + } +} + +int hnsw_add_point(HNSWIndex index, const float* data, unsigned long long label) { + if (!index || !data) return -1; + + try { + auto* impl = static_cast(index); + if (!impl->alg_hnsw) return -1; + + // For cosine similarity, we need to normalize the vector + if (impl->space_name == "cosine") { + std::vector normalized_data(impl->dimension); + float norm = 0.0f; + for (int i = 0; i < impl->dimension; i++) { + norm += data[i] * data[i]; + } + norm = std::sqrt(norm); + + if (norm > 0) { + for (int i = 0; i < impl->dimension; i++) { + normalized_data[i] = data[i] / norm; + } + impl->alg_hnsw->addPoint(normalized_data.data(), label); + } else { + return -1; // Cannot normalize zero vector + } + } else { + impl->alg_hnsw->addPoint(data, label); + } + return 0; + } catch (...) 
{
+        return -1;
+    }
+}
+
+// Finds up to k nearest neighbours of query_data.
+// On success the results occupy labels[0..n) / distances[0..n), ordered from the last
+// element popped to the first, and n (which may be smaller than k) is returned.
+// Returns -1 on invalid arguments, a zero-norm cosine query, or any exception raised
+// by hnswlib.
+int hnsw_search_knn(HNSWIndex index, const float* query_data, int k, unsigned long long* labels, float* distances) {
+    if (!index || !query_data || !labels || !distances) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+        if (k <= 0) return 0;
+
+        // Only cosine queries need a private (normalised) copy; the other spaces read
+        // the caller's buffer directly.
+        const float* query = query_data;
+        std::vector<float> unit;
+        if (impl->space_name == "cosine") {
+            float norm = 0.0f;
+            for (int i = 0; i < impl->dimension; i++) {
+                norm += query_data[i] * query_data[i];
+            }
+            norm = std::sqrt(norm);
+            if (!(norm > 0)) return -1;  // a zero (or NaN) vector cannot be normalised
+
+            unit.resize(impl->dimension);
+            for (int i = 0; i < impl->dimension; i++) {
+                unit[i] = query_data[i] / norm;
+            }
+            query = unit.data();
+        }
+
+        auto result = impl->alg_hnsw->searchKnn(query, k);
+
+        // Never write more than the k slots the caller provided.
+        while (result.size() > static_cast<size_t>(k)) {
+            result.pop();
+        }
+
+        // Fill back-to-front starting from the real result count, so that slots
+        // [0, found) are the ones populated even when fewer than k neighbours exist.
+        const int found = static_cast<int>(result.size());
+        for (int slot = found - 1; slot >= 0; --slot) {
+            distances[slot] = result.top().first;
+            labels[slot] = result.top().second;
+            result.pop();
+        }
+
+        return found;
+    } catch (...) {
+        return -1;
+    }
+}
+
+// Sets the query-time ef. Null handles and negative values are ignored.
+void hnsw_set_ef(HNSWIndex index, int ef) {
+    if (!index || ef < 0) return;  // a negative int would wrap to a huge size_t
+
+    auto* impl = static_cast<HNSWIndexImpl*>(index);
+    if (impl->alg_hnsw) {
+        impl->alg_hnsw->ef_ = static_cast<size_t>(ef);
+    }
+}
+
+// Returns the current element count, or -1 for a null or uninitialised handle.
+int hnsw_get_current_count(HNSWIndex index) {
+    if (!index) return -1;
+
+    auto* impl = static_cast<HNSWIndexImpl*>(index);
+    return impl->alg_hnsw ? static_cast<int>(impl->alg_hnsw->cur_element_count) : -1;
+}
+
+// Writes the index to `path`. Returns 0 when saveIndex returns normally, -1 on invalid
+// arguments or when it throws.
+// NOTE(review): saveIndex does not appear to check the output stream state, so 0 may
+// not guarantee that the file was actually written - confirm before relying on it.
+int hnsw_save_index(HNSWIndex index, const char* path) {
+    if (!index || !path) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        impl->alg_hnsw->saveIndex(path);
+        return 0;
+    } catch (...) 
{
+        return -1;
+    }
+}
+
+// Loads an index previously written by hnsw_save_index.
+// Returns a new handle, or nullptr on invalid arguments or any failure while loading.
+HNSWIndex hnsw_load_index(const char* space_name, int dim, const char* path, int max_elements) {
+    if (!space_name || !path || dim <= 0) return nullptr;
+
+    HNSWIndexImpl* impl = nullptr;
+    try {
+        impl = new HNSWIndexImpl(space_name, dim);
+        impl->alg_hnsw = new hnswlib::HierarchicalNSW<float>(impl->space, path, false, max_elements);
+        return static_cast<HNSWIndex>(impl);
+    } catch (...) {
+        delete impl;  // do not leak the wrapper when loading throws (e.g. missing file)
+        return nullptr;
+    }
+}
+
+// Empties the index while keeping it usable: the graph is rebuilt with the same
+// capacity, M, ef_construction, ef and allow_replace_deleted settings.
+// HierarchicalNSW::clear() on its own releases the element storage without
+// reallocating it, so a later insert would write through a null pointer.
+// The construction seed is not retained by hnswlib, so the library default (100) is
+// used for the rebuilt graph. If the rebuild fails the handle stays valid but
+// uninitialised, and the other calls report -1.
+void hnsw_clear(HNSWIndex index) {
+    if (!index) return;
+
+    auto* impl = static_cast<HNSWIndexImpl*>(index);
+    if (!impl->alg_hnsw) return;
+
+    // Capture the settings, then release the old graph first so that two
+    // full-capacity allocations are never alive at the same time.
+    const size_t capacity = impl->alg_hnsw->max_elements_;
+    const size_t m = impl->alg_hnsw->M_;
+    const size_t ef_construction = impl->alg_hnsw->ef_construction_;
+    const size_t ef = impl->alg_hnsw->ef_;
+    const bool allow_replace = impl->alg_hnsw->allow_replace_deleted_;
+
+    delete impl->alg_hnsw;
+    impl->alg_hnsw = nullptr;
+
+    try {
+        impl->alg_hnsw = new hnswlib::HierarchicalNSW<float>(impl->space, capacity, m, ef_construction, 100, allow_replace);
+        impl->alg_hnsw->ef_ = ef;
+    } catch (...) {
+        impl->alg_hnsw = nullptr;
+    }
+}
+
+// Changes the index capacity to new_max_elements. Returns 0 on success, -1 on invalid
+// arguments or when hnswlib throws (for example when the new capacity is below the
+// current element count).
+int hnsw_resize_index(HNSWIndex index, int new_max_elements) {
+    if (!index || new_max_elements < 0) return -1;  // negative would wrap to a huge size_t
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        impl->alg_hnsw->resizeIndex(static_cast<size_t>(new_max_elements));
+        return 0;
+    } catch (...) {
+        return -1;
+    }
+}
+
+// Replaces the vector stored under an existing label.
+// Returns 0 on success, -2 if the label is unknown, and -1 on invalid arguments, a
+// zero-norm cosine vector, or any exception raised by hnswlib.
+int hnsw_update_point(HNSWIndex index, const float* data, unsigned long long label) {
+    if (!index || !data) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        // Resolve label -> internal id; the lookup lock is held only for this scope.
+        hnswlib::tableint internal_id;
+        {
+            std::unique_lock<std::mutex> lock(impl->alg_hnsw->label_lookup_lock);
+            const auto found = impl->alg_hnsw->label_lookup_.find(label);
+            if (found == impl->alg_hnsw->label_lookup_.end()) {
+                return -2;  // Label not found
+            }
+            internal_id = found->second;
+        }
+
+        if (impl->space_name != "cosine") {
+            impl->alg_hnsw->updatePoint(data, internal_id, 1.0);
+            return 0;
+        }
+
+        // Cosine vectors are stored L2-normalised.
+        float norm = 0.0f;
+        for (int i = 0; i < impl->dimension; i++) {
+            norm += data[i] * data[i];
+        }
+        norm = std::sqrt(norm);
+        if (!(norm > 0)) return -1;  // a zero (or NaN) vector cannot be normalised
+
+        std::vector<float> unit(impl->dimension);
+        for (int i = 0; i < impl->dimension; i++) {
+            unit[i] = data[i] / norm;
+        }
+        impl->alg_hnsw->updatePoint(unit.data(), internal_id, 1.0);
+        return 0;
+    } catch (...) 
{
+        return -1;
+    }
+}
+
+// Soft-deletes the element stored under `label`.
+// Returns 0 on success, -1 for a null or uninitialised handle or when hnswlib throws.
+int hnsw_mark_deleted(HNSWIndex index, unsigned long long label) {
+    if (!index) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        impl->alg_hnsw->markDelete(label);
+        return 0;
+    } catch (...) {
+        return -1;
+    }
+}
+
+// Restores an element previously soft-deleted under `label`.
+// Returns 0 on success, -1 for a null or uninitialised handle or when hnswlib throws.
+int hnsw_unmark_deleted(HNSWIndex index, unsigned long long label) {
+    if (!index) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        impl->alg_hnsw->unmarkDelete(label);
+        return 0;
+    } catch (...) {
+        return -1;
+    }
+}
+
+// Reports whether `label` is soft-deleted: 1 = deleted, 0 = not deleted,
+// -2 = unknown label, -1 = null/uninitialised handle or exception.
+int hnsw_is_marked_deleted(HNSWIndex index, unsigned long long label) {
+    if (!index) return -1;
+
+    try {
+        auto* impl = static_cast<HNSWIndexImpl*>(index);
+        if (!impl->alg_hnsw) return -1;
+
+        // Resolve label -> internal id; the lookup lock is held only for this scope.
+        hnswlib::tableint internal_id;
+        {
+            std::unique_lock<std::mutex> lock(impl->alg_hnsw->label_lookup_lock);
+            const auto found = impl->alg_hnsw->label_lookup_.find(label);
+            if (found == impl->alg_hnsw->label_lookup_.end()) {
+                return -2;  // Label not found
+            }
+            internal_id = found->second;
+        }
+
+        return impl->alg_hnsw->isMarkedDeleted(internal_id) ? 1 : 0;
+    } catch (...) 
{ + return -1; + } +} + +int hnsw_get_deleted_count(HNSWIndex index) { + if (!index) return -1; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + return impl->alg_hnsw->num_deleted_; + } + return -1; +} + +int hnsw_get_max_elements(HNSWIndex index) { + if (!index) return -1; + + auto* impl = static_cast(index); + if (impl->alg_hnsw) { + return impl->alg_hnsw->max_elements_; + } + return -1; +} + +} // extern "C" diff --git a/eigenix/pkg/hnswlib/hnsw_wrapper.h b/eigenix/pkg/hnswlib/hnsw_wrapper.h new file mode 100644 index 00000000..5ad77aea --- /dev/null +++ b/eigenix/pkg/hnswlib/hnsw_wrapper.h @@ -0,0 +1,63 @@ +#ifndef HNSW_WRAPPER_H +#define HNSW_WRAPPER_H + +#ifdef __cplusplus +extern "C" { +#endif + +// Opaque pointer to the C++ HNSW index +typedef void* HNSWIndex; + +// Create a new HNSW index +HNSWIndex hnsw_create_index(const char* space_name, int dim, int max_elements, int M, int ef_construction, int random_seed, int allow_replace_deleted); + +// Delete the HNSW index +void hnsw_delete_index(HNSWIndex index); + +// Add a point to the index +int hnsw_add_point(HNSWIndex index, const float* data, unsigned long long label); + +// Search for k nearest neighbors +int hnsw_search_knn(HNSWIndex index, const float* query_data, int k, unsigned long long* labels, float* distances); + +// Set ef parameter for search +void hnsw_set_ef(HNSWIndex index, int ef); + +// Get current element count +int hnsw_get_current_count(HNSWIndex index); + +// Save index to file +int hnsw_save_index(HNSWIndex index, const char* path); + +// Load index from file +HNSWIndex hnsw_load_index(const char* space_name, int dim, const char* path, int max_elements); + +// Clear all data from the index +void hnsw_clear(HNSWIndex index); + +// Resize the index to new maximum capacity +int hnsw_resize_index(HNSWIndex index, int new_max_elements); + +// Update an existing point's vector +int hnsw_update_point(HNSWIndex index, const float* data, unsigned long long label); + +// Mark a 
point as deleted (soft delete) +int hnsw_mark_deleted(HNSWIndex index, unsigned long long label); + +// Unmark a deleted point (restore) +int hnsw_unmark_deleted(HNSWIndex index, unsigned long long label); + +// Check if a point is marked as deleted +int hnsw_is_marked_deleted(HNSWIndex index, unsigned long long label); + +// Get number of deleted elements +int hnsw_get_deleted_count(HNSWIndex index); + +// Get maximum capacity +int hnsw_get_max_elements(HNSWIndex index); + +#ifdef __cplusplus +} +#endif + +#endif // HNSW_WRAPPER_H diff --git a/eigenix/pkg/hnswlib/hnswalg.h b/eigenix/pkg/hnswlib/hnswalg.h new file mode 100644 index 00000000..7614f3fc --- /dev/null +++ b/eigenix/pkg/hnswlib/hnswalg.h @@ -0,0 +1,1412 @@ +#pragma once + +#include "visited_list_pool.h" +#include "hnswlib.h" +#include +#include +#include +#include +#include +#include +#include + +namespace hnswlib { +typedef unsigned int tableint; +typedef unsigned int linklistsizeint; + +template +class HierarchicalNSW : public AlgorithmInterface { + public: + static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; + static const unsigned char DELETE_MARK = 0x01; + + size_t max_elements_{0}; + mutable std::atomic cur_element_count{0}; // current number of elements + size_t size_data_per_element_{0}; + size_t size_links_per_element_{0}; + mutable std::atomic num_deleted_{0}; // number of deleted elements + size_t M_{0}; + size_t maxM_{0}; + size_t maxM0_{0}; + size_t ef_construction_{0}; + size_t ef_{ 0 }; + + double mult_{0.0}, revSize_{0.0}; + int maxlevel_{0}; + + std::unique_ptr visited_list_pool_{nullptr}; + + // Locks operations with element by label value + mutable std::vector label_op_locks_; + + std::mutex global; + std::vector link_list_locks_; + + tableint enterpoint_node_{0}; + + size_t size_links_level0_{0}; + size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; + + char *data_level0_memory_{nullptr}; + char **linkLists_{nullptr}; + std::vector element_levels_; // keeps level of 
each element + + size_t data_size_{0}; + + DISTFUNC fstdistfunc_; + void *dist_func_param_{nullptr}; + + mutable std::mutex label_lookup_lock; // lock for label_lookup_ + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; + + mutable std::atomic metric_distance_computations{0}; + mutable std::atomic metric_hops{0}; + + bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions + + std::mutex deleted_elements_lock; // lock for deleted_elements + std::unordered_set deleted_elements; // contains internal ids of deleted elements + + + HierarchicalNSW(SpaceInterface *s) { + } + + + HierarchicalNSW( + SpaceInterface *s, + const std::string &location, + bool nmslib = false, + size_t max_elements = 0, + bool allow_replace_deleted = false) + : allow_replace_deleted_(allow_replace_deleted) { + loadIndex(location, s, max_elements); + } + + + HierarchicalNSW( + SpaceInterface *s, + size_t max_elements, + size_t M = 16, + size_t ef_construction = 200, + size_t random_seed = 100, + bool allow_replace_deleted = false) + : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + link_list_locks_(max_elements), + element_levels_(max_elements), + allow_replace_deleted_(allow_replace_deleted) { + max_elements_ = max_elements; + num_deleted_ = 0; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + if ( M <= 10000 ) { + M_ = M; + } else { + HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; + HNSWERR << " Cap to 10000 will be applied for the rest of the processing." 
<< std::endl; + M_ = 10000; + } + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction, M_); + ef_ = 10; + + level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = std::unique_ptr(new VisitedListPool(1, max_elements)); + + // initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + + ~HierarchicalNSW() { + clear(); + } + + void clear() { + free(data_level0_memory_); + data_level0_memory_ = nullptr; + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + linkLists_ = nullptr; + cur_element_count = 0; + visited_list_pool_.reset(nullptr); + } + + + struct CompareByFirst { + constexpr bool operator()(std::pair const& a, + std::pair const& b) const noexcept { + return a.first < b.first; + } + }; + + + void setEf(size_t ef) { + ef_ = ef; + } + + + inline std::mutex& getLabelOpMutex(labeltype label) const { + // calculate hash + size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + return label_op_locks_[lock_id]; + } + + + inline labeltype getExternalLabel(tableint internal_id) const 
{ + labeltype return_label; + memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } + + + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } + + + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } + + + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } + + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + size_t getMaxElements() { + return max_elements_; + } + + size_t getCurrentElementCount() { + return cur_element_count; + } + + size_t getDeletedCount() { + return num_deleted_; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > 
lowerBound && top_candidates.size() == ef_construction_) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); +// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + + // bare_bone_search means there is no 
check for deletions and stop condition is ignored in return of extra performance + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST( + tableint ep_id, + const void *data_point, + size_t ef, + BaseFilterFunctor* isIdAllowed = nullptr, + BaseSearchStopCondition* stop_condition = nullptr) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; + if (bare_bone_search || + (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { + char* ep_data = getDataByInternalId(ep_id); + dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist); + } + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + std::pair current_node_pair = candidate_set.top(); + dist_t candidate_dist = -current_node_pair.first; + + bool flag_stop_search; + if (bare_bone_search) { + flag_stop_search = candidate_dist > lowerBound; + } else { + if (stop_condition) { + flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound); + } else { + flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef; + } + } + if (flag_stop_search) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); +// bool cur_node_deleted = isMarkedDeleted(current_node_id); + if 
(collect_metrics) { + metric_hops++; + metric_distance_computations+=size; + } + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0); //////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + visited_array[candidate_id] = visited_array_tag; + + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + + bool flag_consider_candidate; + if (!bare_bone_search && stop_condition) { + flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound); + } else { + flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist; + } + + if (flag_consider_candidate) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_, /////////// + _MM_HINT_T0); //////////////////////// +#endif + + if (bare_bone_search || + (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { + top_candidates.emplace(dist, candidate_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist); + } + } + + bool flag_remove_extra = false; + if (!bare_bone_search && stop_condition) { + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = 
top_candidates.size() > ef; + } + while (flag_remove_extra) { + tableint id = top_candidates.top().second; + top_candidates.pop(); + if (!bare_bone_search && stop_condition) { + stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist); + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + } + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } + + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_); + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + } + } + + for (std::pair curent_pair : return_list) { + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } + + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } + + + 
linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + } + + + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + } + + + tableint mutuallyConnectNewElement( + const void *data_point, + tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, + bool isUpdate) { + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + tableint next_closest_entry_point = selectedNeighbors.back(); + + { + // lock only during the update + // because during the addition the lock for cur_c is already acquired + std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); + if (isUpdate) { + lock.lock(); + } + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur && !isUpdate) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur, selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx] && !isUpdate) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + } + } + + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + std::unique_lock 
lock(link_list_locks_[selectedNeighbors[idx]]); + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + + bool is_cur_c_present = false; + if (isUpdate) { + for (size_t j = 0; j < sz_link_list_other; j++) { + if (data[j] == cur_c) { + is_cur_c_present = true; + break; + } + } + } + + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = 
fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + } + } + + return next_closest_entry_point; + } + + + void resizeIndex(size_t new_max_elements) { + if (new_max_elements < cur_element_count) + throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); + + visited_list_pool_.reset(new VisitedListPool(1, new_max_elements)); + + element_levels_.resize(new_max_elements); + + std::vector(new_max_elements).swap(link_list_locks_); + + // Reallocate base layer + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + linkLists_ = linkLists_new; + + max_elements_ = new_max_elements; + } + + size_t indexFileSize() const { + size_t size = 0; + size += sizeof(offsetLevel0_); + size += sizeof(max_elements_); + size += sizeof(cur_element_count); + size += sizeof(size_data_per_element_); + size += sizeof(label_offset_); + size += sizeof(offsetData_); + size += sizeof(maxlevel_); + size += sizeof(enterpoint_node_); + size += sizeof(maxM_); + + size += sizeof(maxM0_); + size += sizeof(M_); + size += sizeof(mult_); + size += sizeof(ef_construction_); + + size += cur_element_count * size_data_per_element_; + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; + size += sizeof(linkListSize); + size += linkListSize; + } + return size; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + clear(); + // get file size: + input.seekg(0, input.end); + std::streampos total_filesize = input.tellg(); + input.seekg(0, input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements = max_elements_i; + if (max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + 
readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos = input.tellg(); + + /// Optional - check if index is ok: + input.seekg(cur_element_count * size_data_per_element_, input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if (input.tellg() < 0 || input.tellg() >= total_filesize) { + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize, input.cur); + } + } + + // throw exception if it either corrupted or old index + if (input.tellg() != total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + /// Optional check end + + input.seekg(pos, input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); + + visited_list_pool_.reset(new VisitedListPool(1, max_elements)); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)] = i; 
+ unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + for (size_t i = 0; i < cur_element_count; i++) { + if (isMarkedDeleted(i)) { + num_deleted_ += 1; + if (allow_replace_deleted_) deleted_elements.insert(i); + } + } + + input.close(); + + return; + } + + + template + std::vector getDataByLabel(labeltype label) const { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + char* data_ptrv = getDataByInternalId(internalId); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (size_t i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + + + /* + * Marks an element with the given label deleted, does NOT really change the current graph. 
+ */ + void markDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + markDeletedInternal(internalId); + } + + + /* + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ + void markDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); + } + } else { + throw std::runtime_error("The requested to delete element is already deleted"); + } + } + + + /* + * Removes the deleted mark of the node, does NOT really change the current graph. + * + * Note: the method is not safe to use when replacement of deleted elements is enabled, + * because elements marked as deleted can be completely removed by addPoint + */ + void unmarkDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + unmarkDeletedInternal(internalId); + } + + + + /* + * Remove the deleted mark of the node. 
+ */ + void unmarkDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); + } + } else { + throw std::runtime_error("The requested to undelete element is not deleted"); + } + } + + + /* + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; + return *ll_cur & DELETE_MARK; + } + + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + + /* + * Adds point. Updates the point if it is already in the index. 
+ * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point + */ + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { + throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); + } + + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + if (!replace_deleted) { + addPoint(data_point, label, -1); + return; + } + // check if there is vacant place + tableint internal_id_replaced; + std::unique_lock lock_deleted_elements(deleted_elements_lock); + bool is_vacant_place = !deleted_elements.empty(); + if (is_vacant_place) { + internal_id_replaced = *deleted_elements.begin(); + deleted_elements.erase(internal_id_replaced); + } + lock_deleted_elements.unlock(); + + // if there is no vacant place then add or update point + // else add point to vacant place + if (!is_vacant_place) { + addPoint(data_point, label, -1); + } else { + // we assume that there are no concurrent operations on deleted element + labeltype label_replaced = getExternalLabel(internal_id_replaced); + setExternalLabel(internal_id_replaced, label); + + std::unique_lock lock_table(label_lookup_lock); + label_lookup_.erase(label_replaced); + label_lookup_[label] = internal_id_replaced; + lock_table.unlock(); + + unmarkDeletedInternal(internal_id_replaced); + updatePoint(data_point, internal_id_replaced, 1.0); + } + } + + + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); + + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element then just return. 
+ if (entryPointCopy == internalId && cur_element_count == 1) + return; + + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set sCand; + std::unordered_set sNeigh; + std::vector listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; + + sCand.insert(internalId); + + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); + + if (distribution(update_probability_generator_) > updateNeighborProbability) + continue; + + sNeigh.insert(elOneHop); + + std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); + } + } + + for (auto&& neigh : sNeigh) { + // if (neigh == internalId) + // continue; + + std::priority_queue, std::vector>, CompareByFirst> candidates; + size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 + size_t elementsToKeep = std::min(ef_construction_, size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; + + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); + candidates.emplace(distance, cand); + } + } + } + + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? 
maxM0_ : maxM_); + + { + std::unique_lock lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + size_t candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); + } + } + } + } + + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + } + + + void repairConnectionsForUpdate( + const void *dataPoint, + tableint entryPointInternalId, + tableint dataPointInternalId, + int dataPointLevel, + int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); +#endif + for (int i = 0; i < size; i++) { +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); +#endif + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); + + std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if 
(topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); + + topCandidates.pop(); + } + + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); + } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + } + } + } + + + std::vector getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll, size * sizeof(tableint)); + return result; + } + + + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. 
+ std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + if (allow_replace_deleted_) { + if (isMarkedDeleted(existingInternalId)) { + throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); + } + } + lock_table.unlock(); + + if (isMarkedDeleted(existingInternalId)) { + unmarkDeletedInternal(existingInternalId); + } + updatePoint(data_point, existingInternalId, 1.0); + + return existingInternalId; + } + + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + } + + cur_c = cur_element_count; + cur_element_count++; + label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + if ((signed)currObj != -1) { + if (curlevel < maxlevelcopy) { + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + 
bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj, level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); + } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } + + // Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) 
get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + bool bare_bone_search = !num_deleted_ && !isIdAllowed; + if (bare_bone_search) { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } else { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + } + + + std::vector> + searchStopConditionClosest( + const void *query_data, + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + std::vector> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, 
getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + top_candidates = searchBaseLayerST(currObj, query_data, 0, isIdAllowed, &stop_condition); + + size_t sz = top_candidates.size(); + result.resize(sz); + while (!top_candidates.empty()) { + result[--sz] = top_candidates.top(); + top_candidates.pop(); + } + + stop_condition.filter_results(result); + + return result; + } + + + void checkIntegrity() { + int connections_checked = 0; + std::vector inbound_connections_num(cur_element_count, 0); + for (int i = 0; i < cur_element_count; i++) { + for (int l = 0; l <= element_levels_[i]; l++) { + linklistsizeint *ll_cur = get_linklist_at_level(i, l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set s; + for (int j = 0; j < size; j++) { + assert(data[j] < cur_element_count); + assert(data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; + } + assert(s.size() == size); + } + } + if (cur_element_count > 1) { + int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; + for (int i=0; i < cur_element_count; i++) { + assert(inbound_connections_num[i] > 0); + min1 = std::min(inbound_connections_num[i], min1); + max1 = std::max(inbound_connections_num[i], max1); + } + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + } +}; +} // namespace hnswlib diff --git a/eigenix/pkg/hnswlib/hnswlib.h b/eigenix/pkg/hnswlib/hnswlib.h new file mode 100644 index 00000000..7ccfbba5 --- /dev/null +++ b/eigenix/pkg/hnswlib/hnswlib.h @@ -0,0 +1,228 @@ +#pragma once + +// https://github.com/nmslib/hnswlib/pull/508 +// This allows others to provide their own error stream (e.g. 
RcppHNSW) +#ifndef HNSWLIB_ERR_OVERRIDE + #define HNSWERR std::cerr +#else + #define HNSWERR HNSWLIB_ERR_OVERRIDE +#endif + +#ifndef NO_MANUAL_VECTORIZATION +#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#ifdef __AVX512F__ +#define USE_AVX512 +#endif +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { + __cpuidex(out, eax, ecx); +} +static __int64 xgetbv(unsigned int x) { + return _xgetbv(x); +} +#else +#include +#include +#include +static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { + __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); +} +static uint64_t xgetbv(unsigned int index) { + uint32_t eax, edx; + __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); + return ((uint64_t)edx << 32) | eax; +} +#endif + +#if defined(USE_AVX512) +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#define PORTABLE_ALIGN64 __attribute__((aligned(64))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#define PORTABLE_ALIGN64 __declspec(align(64)) +#endif + +// Adapted from https://github.com/Mysticial/FeatureDetector +#define _XCR_XFEATURE_ENABLED_MASK 0 + +static bool AVXCapable() { + int cpuInfo[4]; + + // CPU support + cpuid(cpuInfo, 0, 0); + int nIds = cpuInfo[0]; + + bool HW_AVX = false; + if (nIds >= 0x00000001) { + cpuid(cpuInfo, 0x00000001, 0); + HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0; + } + + // OS support + cpuid(cpuInfo, 1, 0); + + bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; + bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; + + bool avxSupported = false; + if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { + uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); + avxSupported = (xcrFeatureMask & 0x6) == 0x6; + } + return HW_AVX && avxSupported; +} 
+ +static bool AVX512Capable() { + if (!AVXCapable()) return false; + + int cpuInfo[4]; + + // CPU support + cpuid(cpuInfo, 0, 0); + int nIds = cpuInfo[0]; + + bool HW_AVX512F = false; + if (nIds >= 0x00000007) { // AVX512 Foundation + cpuid(cpuInfo, 0x00000007, 0); + HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; + } + + // OS support + cpuid(cpuInfo, 1, 0); + + bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; + bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; + + bool avx512Supported = false; + if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { + uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); + avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6; + } + return HW_AVX512F && avx512Supported; +} +#endif + +#include +#include +#include +#include + +namespace hnswlib { +typedef size_t labeltype; + +// This can be extended to store state for filtering (e.g. from a std::set) +class BaseFilterFunctor { + public: + virtual bool operator()(hnswlib::labeltype id) { return true; } + virtual ~BaseFilterFunctor() {}; +}; + +template +class BaseSearchStopCondition { + public: + virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_remove_extra() = 0; + + virtual void filter_results(std::vector> &candidates) = 0; + + virtual ~BaseSearchStopCondition() {} +}; + +template +class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } +}; + +template +static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); +} + +template +static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); +} + +template +using 
DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + +template +class SpaceInterface { + public: + // virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + virtual ~SpaceInterface() {} +}; + +template +class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; + + virtual std::priority_queue> + searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; + + // Return k nearest neighbor in the order of closer fist + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; + + virtual void saveIndex(const std::string &location) = 0; + virtual ~AlgorithmInterface(){ + } +}; + +template +std::vector> +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + BaseFilterFunctor* isIdAllowed) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k, isIdAllowed); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); + } + } + + return result; +} +} // namespace hnswlib + +#include "space_l2.h" +#include "space_ip.h" +#include "stop_condition.h" +#include "bruteforce.h" +#include "hnswalg.h" diff --git a/eigenix/pkg/hnswlib/space_ip.h b/eigenix/pkg/hnswlib/space_ip.h new file mode 100644 index 00000000..0e6834c1 --- /dev/null +++ b/eigenix/pkg/hnswlib/space_ip.h @@ -0,0 +1,400 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + +static float +InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; + } + return res; +} + +static float 
+InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { + return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); +} + +#if defined(USE_AVX) + +// Favor using AVX if available. +static float +InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return sum; +} + +static float +InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_SSE) + +static float +InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float 
*pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return sum; +} + +static float +InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); +} + +#endif + + +#if defined(USE_AVX512) + +static float +InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN64 TmpRes[16]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m512 sum512 = _mm512_set1_ps(0); + + size_t loop = qty16 / 4; + + while (loop--) { + __m512 v1 = _mm512_loadu_ps(pVect1); + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + __m512 v3 = _mm512_loadu_ps(pVect1); + __m512 v4 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + __m512 v5 = _mm512_loadu_ps(pVect1); + __m512 v6 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + 
pVect2 += 16; + + __m512 v7 = _mm512_loadu_ps(pVect1); + __m512 v8 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + sum512 = _mm512_fmadd_ps(v1, v2, sum512); + sum512 = _mm512_fmadd_ps(v3, v4, sum512); + sum512 = _mm512_fmadd_ps(v5, v6, sum512); + sum512 = _mm512_fmadd_ps(v7, v8, sum512); + } + + while (pVect1 < pEnd1) { + __m512 v1 = _mm512_loadu_ps(pVect1); + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + sum512 = _mm512_fmadd_ps(v1, v2, sum512); + } + + float sum = _mm512_reduce_add_ps(sum512); + return sum; +} + +static float +InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_AVX) + +static float +InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return sum; +} + +static float +InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_SSE) + +static float +InnerProductSIMD16ExtSSE(const 
void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return sum; +} + +static float +InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; +static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; +static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; +static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; + +static float +InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - 
qty16; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + return 1.0f - (res + res_tail); +} + +static float +InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + + return 1.0f - (res + res_tail); +} +#endif + +class InnerProductSpace : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #endif + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; + } + #endif + + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; 
+ } + + void *get_dist_func_param() { + return &dim_; + } + +~InnerProductSpace() {} +}; + +} // namespace hnswlib diff --git a/eigenix/pkg/hnswlib/space_l2.h b/eigenix/pkg/hnswlib/space_l2.h new file mode 100644 index 00000000..834d19f7 --- /dev/null +++ b/eigenix/pkg/hnswlib/space_l2.h @@ -0,0 +1,324 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + +static float +L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + float res = 0; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; + res += t * t; + } + return (res); +} + +#if defined(USE_AVX512) + +// Favor using AVX512 if available. +static float +L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN64 TmpRes[16]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m512 diff, v1, v2; + __m512 sum = _mm512_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + diff = _mm512_sub_ps(v1, v2); + // sum = _mm512_fmadd_ps(diff, diff, sum); + sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); + } + + _mm512_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + + TmpRes[13] + TmpRes[14] + TmpRes[15]; + + return (res); +} +#endif + +#if defined(USE_AVX) + +// Favor using AVX if available. 
+static float
+L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);
+    float PORTABLE_ALIGN32 TmpRes[8];
+    size_t qty16 = qty >> 4;
+    // Squared L2 distance over the first 16 * (qty / 16) floats; any remainder is not visited here.
+    const float *pEnd1 = pVect1 + (qty16 << 4);
+
+    __m256 diff, v1, v2;
+    __m256 sum = _mm256_set1_ps(0);
+
+    while (pVect1 < pEnd1) {
+        v1 = _mm256_loadu_ps(pVect1);  // two 8-float steps per iteration = 16 floats
+        pVect1 += 8;
+        v2 = _mm256_loadu_ps(pVect2);
+        pVect2 += 8;
+        diff = _mm256_sub_ps(v1, v2);
+        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+
+        v1 = _mm256_loadu_ps(pVect1);
+        pVect1 += 8;
+        v2 = _mm256_loadu_ps(pVect2);
+        pVect2 += 8;
+        diff = _mm256_sub_ps(v1, v2);
+        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+    }
+    // Spill the 8 lane-wise partial sums and add them up.
+    _mm256_store_ps(TmpRes, sum);
+    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
+}
+
+#endif
+
+#if defined(USE_SSE)
+// SSE variant of the 16-float-stride squared L2 kernel: four 4-float steps per iteration.
+static float
+L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);
+    float PORTABLE_ALIGN32 TmpRes[8];
+    size_t qty16 = qty >> 4;
+    // Only the first 16 * (qty / 16) floats are visited.
+    const float *pEnd1 = pVect1 + (qty16 << 4);
+
+    __m128 diff, v1, v2;
+    __m128 sum = _mm_set1_ps(0);
+
+    while (pVect1 < pEnd1) {
+        //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+    }
+    // Only the low 4 slots of TmpRes are written and summed by the SSE path.
+    _mm_store_ps(TmpRes, sum);
+    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+}
+#endif
+
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;  // default; space constructors may rebind it to an AVX/AVX512 kernel
+// SIMD kernel for the largest multiple-of-16 prefix, scalar L2Sqr for the remaining tail.
+static float
+L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    size_t qty = *((size_t *) qty_ptr);
+    size_t qty16 = qty >> 4 << 4;
+    float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
+    float *pVect1 = (float *) pVect1v + qty16;
+    float *pVect2 = (float *) pVect2v + qty16;
+
+    size_t qty_left = qty - qty16;
+    float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+    return (res + res_tail);
+}
+#endif
+
+// Squared L2 distance over the first 4 * (qty / 4) floats, one SSE register per step.
+#if defined(USE_SSE)
+static float
+L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float PORTABLE_ALIGN32 TmpRes[8];
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);
+
+
+    size_t qty4 = qty >> 2;
+
+    const float *pEnd1 = pVect1 + (qty4 << 2);
+
+    __m128 diff, v1, v2;
+    __m128 sum = _mm_set1_ps(0);
+
+    while (pVect1 < pEnd1) {
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+    }
+    _mm_store_ps(TmpRes, sum);
+    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+}
+// 4-float SIMD kernel for the multiple-of-4 prefix, scalar L2Sqr for the remaining tail.
+static float
+L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    size_t qty = *((size_t *) qty_ptr);
+    size_t qty4 = qty >> 2 << 2;
+
+    float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
+    size_t qty_left = qty - qty4;
+
+    float *pVect1 = (float *) pVect1v + qty4;
+    float *pVect2 = (float *) pVect2v + qty4;
+    float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+
+    return (res + res_tail);
+}
+#endif
+// Float squared-L2 space; the constructor picks the distance kernel from dim and CPU capabilities.
+class L2Space : public SpaceInterface<float> {
+    DISTFUNC<float> fstdistfunc_;
+    size_t data_size_;
+    size_t dim_;
+
+ public:
+    L2Space(size_t dim) {
+        fstdistfunc_ = L2Sqr;
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+    #if defined(USE_AVX512)
+        if (AVX512Capable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
+        else if (AVXCapable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+    #elif defined(USE_AVX)
+        if (AVXCapable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+    #endif
+        // Exact multiples of 16 or 4 use a pure SIMD kernel; other dims above 4 add a scalar tail.
+        if (dim % 16 == 0)
+            fstdistfunc_ = L2SqrSIMD16Ext;
+        else if (dim % 4 == 0)
+            fstdistfunc_ = L2SqrSIMD4Ext;
+        else if (dim > 16)
+            fstdistfunc_ = L2SqrSIMD16ExtResiduals;
+        else if (dim > 4)
+            fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+#endif
+        dim_ = dim;
+        data_size_ = dim * sizeof(float);
+    }
+
+    size_t get_data_size() {
+        return data_size_;
+    }
+
+    DISTFUNC<float> get_dist_func() {
+        return fstdistfunc_;
+    }
+
+    void *get_dist_func_param() {
+        return &dim_;
+    }
+
+    ~L2Space() {}
+};
+// Integer squared L2 over unsigned bytes, unrolled by 4; visits only the first 4 * (qty / 4) bytes.
+static int
+L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
+    size_t qty = *((size_t *) qty_ptr);
+    int res = 0;
+    unsigned char *a = (unsigned char *) pVect1;
+    unsigned char *b = (unsigned char *) pVect2;
+
+    qty = qty >> 2;
+    for (size_t i = 0; i < qty; i++) {
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+    }
+    return (res);
+}
+// Integer squared L2 over unsigned bytes, one element per iteration.
+static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
+    size_t qty = *((size_t*)qty_ptr);
+    int res = 0;
+    unsigned char* a = (unsigned char*)pVect1;
+    unsigned char* b = (unsigned char*)pVect2;
+
+    for (size_t i = 0; i < qty; i++) {
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+    }
+    return (res);
+}
+// Squared-L2 space for unsigned char vectors with int distances.
+class L2SpaceI : public SpaceInterface<int> {
+    DISTFUNC<int> fstdistfunc_;
+    size_t data_size_;
+    size_t dim_;
+
+ public:
+    L2SpaceI(size_t dim) {
+        if (dim % 4 == 0) {
+            fstdistfunc_ = L2SqrI4x;
+        } else {
+            fstdistfunc_ = L2SqrI;
+        }
+        dim_ = dim;
+        data_size_ = dim * sizeof(unsigned char);
+    }
+
+    size_t get_data_size() {
+        return data_size_;
+    }
+
+    DISTFUNC<int> get_dist_func() {
+        return fstdistfunc_;
+    }
+
+    void *get_dist_func_param() {
+        return &dim_;
+    }
+
+    ~L2SpaceI() {}
+};
+}  // namespace hnswlib
diff --git a/eigenix/pkg/hnswlib/stop_condition.h b/eigenix/pkg/hnswlib/stop_condition.h
new file mode 100644
index 00000000..acc80ebe
--- /dev/null
+++ b/eigenix/pkg/hnswlib/stop_condition.h
@@ -0,0 +1,287 @@
+#pragma once
+#include "space_l2.h"
+#include "space_ip.h"
+#include <assert.h>
+#include <unordered_map>
+
+namespace hnswlib {
+
+// Interface for spaces whose data points expose a document id alongside the vector.
+template<typename DOCIDTYPE>
+class BaseMultiVectorSpace : public SpaceInterface<float> {
+ public:
+    virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;
+
+    virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
+};
+
+
+// Squared-L2 space for (vector, doc id) records: dim floats followed by one DOCIDTYPE.
+template<typename DOCIDTYPE>
+class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
+    DISTFUNC<float> fstdistfunc_;
+    size_t data_size_;
+    size_t vector_size_;
+    size_t dim_;
+
+ public:
+    MultiVectorL2Space(size_t dim) {
+        fstdistfunc_ = L2Sqr;
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+    #if defined(USE_AVX512)
+        if (AVX512Capable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
+        else if (AVXCapable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+    #elif defined(USE_AVX)
+        if (AVXCapable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+    #endif
+
+        if (dim % 16 == 0)
+            fstdistfunc_ = L2SqrSIMD16Ext;
+        else if (dim % 4 == 0)
+            fstdistfunc_ = L2SqrSIMD4Ext;
+        else if (dim > 16)
+            fstdistfunc_ = L2SqrSIMD16ExtResiduals;
+        else if (dim > 4)
+            fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+#endif
+        dim_ = dim;
+        vector_size_ = dim * sizeof(float);
+        data_size_ = vector_size_ + sizeof(DOCIDTYPE);
+    }
+
+    size_t get_data_size() override {
+        return data_size_;
+    }
+
+    DISTFUNC<float> get_dist_func() override {
+        return fstdistfunc_;
+    }
+
+    void *get_dist_func_param() override {
+        return &dim_;
+    }
+
+    DOCIDTYPE get_doc_id(const void *datapoint) override {
+        // The doc id is stored immediately after the vector_size_ bytes of float data.
+        return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
+    }
+
+    void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
+        *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
+    }
+
+    ~MultiVectorL2Space() {}
+};
+
+
+// Inner-product counterpart of MultiVectorL2Space; same (vector, doc id) record layout.
+template<typename DOCIDTYPE>
+class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
+    DISTFUNC<float> fstdistfunc_;
+    size_t data_size_;
+    size_t vector_size_;
+    size_t dim_;
+
+ public:
+    MultiVectorInnerProductSpace(size_t dim) {
+        fstdistfunc_ = InnerProductDistance;
+#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
+    #if defined(USE_AVX512)
+        if (AVX512Capable()) {
+            InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
+            InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
+        } else if (AVXCapable()) {
+            InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
+            InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
+        }
+    #elif defined(USE_AVX)
+        if (AVXCapable()) {
+            InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
+            InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
+        }
+    #endif
+    #if defined(USE_AVX)
+        if (AVXCapable()) {
+            InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
+            InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
+        }
+    #endif
+
+        if (dim % 16 == 0)
+            fstdistfunc_ = InnerProductDistanceSIMD16Ext;
+        else if (dim % 4 == 0)
+            fstdistfunc_ = InnerProductDistanceSIMD4Ext;
+        else if (dim > 16)
+            fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
+        else if (dim > 4)
+            fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
+#endif
+        // get_dist_func_param() returns &dim_, so it must be initialised (as MultiVectorL2Space does).
+        dim_ = dim;
+        vector_size_ = dim * sizeof(float);
+        data_size_ = vector_size_ + sizeof(DOCIDTYPE);
+    }
+
+    size_t get_data_size() override {
+        return data_size_;
+    }
+
+    DISTFUNC<float> get_dist_func() override {
+        return fstdistfunc_;
+    }
+
+    void *get_dist_func_param() override {
+        return &dim_;
+    }
+
+    DOCIDTYPE get_doc_id(const void *datapoint) override {
+        return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
+    }
+
+    void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
+        *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
+    }
+
+    ~MultiVectorInnerProductSpace() {}
+};
+
+
+// Stop condition that counts distinct documents (not individual points) among the current results.
+template<typename DOCIDTYPE, typename dist_t>
+class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
+    size_t curr_num_docs_;
+    size_t num_docs_to_search_;
+    size_t ef_collection_;
+    std::unordered_map<DOCIDTYPE, size_t> doc_counter_;
+    std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;
+    BaseMultiVectorSpace<DOCIDTYPE>& space_;
+
+ public:
+    MultiVectorSearchStopCondition(
+        BaseMultiVectorSpace<DOCIDTYPE>& space,
+        size_t num_docs_to_search,
+        size_t ef_collection = 10)
+        : space_(space) {
+            curr_num_docs_ = 0;
+            num_docs_to_search_ = num_docs_to_search;
+            ef_collection_ = std::max(ef_collection, num_docs_to_search);
+        }
+
+    void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
+        DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
+        // First point seen for this document: one more distinct document in the result set.
+        if (doc_counter_[doc_id] == 0) {
+            curr_num_docs_ += 1;
+        }
+        search_results_.emplace(dist, doc_id);
+        doc_counter_[doc_id] += 1;
+    }
+
+    void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
+        DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
+        doc_counter_[doc_id] -= 1;
+        if (doc_counter_[doc_id] == 0) {
+            curr_num_docs_ -= 1;
+        }
+        search_results_.pop();
+    }
+
+    bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
+        bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
+        return stop_search;
+    }
+
+    bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
+        bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
+        return flag_consider_candidate;
+    }
+
+    bool should_remove_extra() override {
+        bool flag_remove_extra = curr_num_docs_ > ef_collection_;
+        return flag_remove_extra;
+    }
+
+    void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
+        // Pop the top of search_results_ and the back of candidates together until at most num_docs_to_search_ documents remain.
+        while (curr_num_docs_ > num_docs_to_search_) {
+            dist_t dist_cand = candidates.back().first;
+            dist_t dist_res = search_results_.top().first;
+            assert(dist_cand == dist_res);
+            DOCIDTYPE doc_id = search_results_.top().second;
+            doc_counter_[doc_id] -= 1;
+            if (doc_counter_[doc_id] == 0) {
+                curr_num_docs_ -= 1;
+            }
+            search_results_.pop();
+            candidates.pop_back();
+        }
+    }
+
+    ~MultiVectorSearchStopCondition() {}
+};
+
+
+// Stop condition bounded by a distance threshold (epsilon) and min/max result counts.
+template<typename dist_t>
+class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
+    float epsilon_;
+    size_t min_num_candidates_;
+    size_t max_num_candidates_;
+    size_t curr_num_items_;
+
+ public:
+    EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
+        assert(min_num_candidates <= max_num_candidates);
+        epsilon_ = epsilon;
+        min_num_candidates_ = min_num_candidates;
+        max_num_candidates_ = max_num_candidates;
+        curr_num_items_ = 0;
+    }
+
+    void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
+        curr_num_items_ += 1;
+    }
+
+    void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
+        curr_num_items_ -= 1;
+    }
+
+    bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
+        if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
+            // new candidate can't improve found results
+            return true;
+        }
+        if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
+            // new candidate is out of epsilon region and
+            // minimum number of candidates is checked
+            return true;
+        }
+        return false;
+    }
+
+    bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
+        bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
+        return flag_consider_candidate;
+    }
+
+    bool should_remove_extra() override {
+        bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
+        return flag_remove_extra;
+    }
+
+    void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
+        // Trailing entries farther than epsilon_ are discarded first, then the list is capped at max_num_candidates_.
+        while (!candidates.empty() && candidates.back().first > epsilon_) {
+            candidates.pop_back();
+        }
+        while (candidates.size() > max_num_candidates_) {
+            candidates.pop_back();
+        }
+    }
+
+    ~EpsilonSearchStopCondition() {}
+};
+}  // namespace hnswlib
diff --git a/eigenix/pkg/hnswlib/visited_list_pool.h b/eigenix/pkg/hnswlib/visited_list_pool.h
new file mode 100644
index 00000000..2e201ec4
--- /dev/null
+++ b/eigenix/pkg/hnswlib/visited_list_pool.h
@@ -0,0 +1,81 @@
+#pragma once
+
+#include <mutex>
+#include <string.h>
+#include <deque>
+
+namespace hnswlib {
+typedef unsigned short int vl_type;
+
+// Tag array reused across searches; reset() advances the tag curV instead of clearing mass every time.
+class VisitedList {
+ public:
+    vl_type curV;
+    vl_type *mass;
+    unsigned int numelements;
+
+    VisitedList(int numelements1) {
+        curV = -1;  // wraps to the largest vl_type value, so the first reset() rolls over to 0 and clears mass
+        numelements = numelements1;
+        mass = new vl_type[numelements];
+    }
+
+    void reset() {
+        curV++;
+        if (curV == 0) {  // tag wrapped around: clear all marks and skip tag 0
+            memset(mass, 0, sizeof(vl_type) * numelements);
+            curV++;
+        }
+    }
+
+    ~VisitedList() { delete[] mass; }
+};
+///////////////////////////////////////////////////////////
+//
+// Class for multi-threaded pool-management of VisitedLists
+//
+/////////////////////////////////////////////////////////
+
+class VisitedListPool {
+    std::deque<VisitedList *> pool;
+    std::mutex poolguard;
+    int numelements;
+
+ public:
+    VisitedListPool(int initmaxpools, int numelements1) {
+        numelements = numelements1;
+        for (int i = 0; i < initmaxpools; i++)
+            pool.push_front(new VisitedList(numelements));
+    }
+
+    // Pops a pooled list (or allocates one when the pool is empty) and resets it before returning.
+    VisitedList *getFreeVisitedList() {
+        VisitedList *rez;
+        {
+            std::unique_lock <std::mutex> lock(poolguard);
+            if (pool.size() > 0) {
+                rez = pool.front();
+                pool.pop_front();
+            } else {
+                rez = new VisitedList(numelements);
+            }
+        }
+        rez->reset();
+        return rez;
+    }
+
+    // Puts vl back at the front of the pool under the pool lock.
+    void releaseVisitedList(VisitedList *vl) {
+        std::unique_lock <std::mutex> lock(poolguard);
+        pool.push_front(vl);
+    }
+
+    ~VisitedListPool() {
+        while (pool.size()) {
+            VisitedList *rez = pool.front();
+            pool.pop_front();
+            delete rez;
+        }
+    }
+};
+}  // namespace hnswlib