diff --git a/go-sdk/VERSION b/go-sdk/VERSION index 60453e69..992977ad 100644 --- a/go-sdk/VERSION +++ b/go-sdk/VERSION @@ -1 +1 @@ -v1.0.0 \ No newline at end of file +v1.1.0 \ No newline at end of file diff --git a/inferflow/PREDICT_APIS_AND_FEATURE_LOGGING.md b/inferflow/PREDICT_APIS_AND_FEATURE_LOGGING.md new file mode 100644 index 00000000..b8916147 --- /dev/null +++ b/inferflow/PREDICT_APIS_AND_FEATURE_LOGGING.md @@ -0,0 +1,153 @@ +# Predict APIs and Feature Logging + +This document describes the Predict gRPC APIs exposed by Inferflow, and the high-level request/logging flow for PointWise, PairWise, and SlateWise inference. + +## APIs Exposed + +The `Predict` service exposes three RPCs: + +- `InferPointWise(PointWiseRequest) returns (PointWiseResponse)` +- `InferPairWise(PairWiseRequest) returns (PairWiseResponse)` +- `InferSlateWise(SlateWiseRequest) returns (SlateWiseResponse)` + +Reference: `inferflow/server/proto/predict.proto`. + +## API Intent + +### PointWise (sometimes referred to as "pintwise") + +- Use when you need per-target scoring. +- Input: + - `context_features` (request-level features) + - `target_input_schema` (schema for `Target.feature_values`) + - `targets` (entities/items to score) +- Output: + - `target_output_schema` + - `target_scores` (one score row per target) + - `request_error` (request-level error) + +### PairWise + +- Use when you need pair-level scoring/ranking and optional per-target outputs. +- Input: + - `targets` (base target pool) + - `pairs` (`first_target_index` + `second_target_index`) + - `pair_input_schema` for pair-level features + - `target_input_schema` for target-level features +- Output: + - `pair_scores` aligned with `pair_output_schema` + - `target_scores` aligned with `target_output_schema` + - `request_error` + +### SlateWise + +- Use when you need slate-level scoring plus optional per-target outputs. 
+- Input: + - `targets` (base target pool) + - `slates` (`target_indices` per slate) + - `slate_input_schema` for slate-level features + - `target_input_schema` for target-level features +- Output: + - `slate_scores` aligned with `slate_output_schema` + - `target_scores` aligned with `target_output_schema` + - `request_error` + +## High-Level Flow + +The runtime path is common across all three APIs with adapter-specific shaping: + +1. Receive gRPC request in `PredictService` (`predict_handler.go`). +2. Load model config via `config.GetModelConfig(model_config_id)`. +3. Adapt request into `components.ComponentRequest` (`predict_adapter.go`): + - Build `ComponentData` (target matrix). + - For PairWise/SlateWise, also build `SlateData` (slate matrix). +4. Execute DAG components via `executor.Execute(...)`. +5. Build RPC response from matrices (`predict_response.go`): + - PointWise from target matrix. + - PairWise/SlateWise from target + slate matrices. +6. Emit metrics (`request.total`, `latency`, `batch.size`). +7. Optionally trigger feature logging asynchronously (`maybeLogInferenceResponse`). + +## Matrix Model Used by Predict + +- `ComponentData`: + - Main per-target matrix. + - Always present. +- `SlateData`: + - Per-slate matrix. + - Present for PairWise and SlateWise. + - Contains `slate_target_indices` and slate-level features. + +## Feature Logging: Current Behavior + +Feature logging is implemented in `inferflow/handlers/inferflow/feature_logging.go`. + +### Trigger Conditions + +Logging is attempted only when all are true: + +- `conf.ResponseConfig.LoggingPerc > 0` +- Random sampling check passes: `rand.Intn(100)+1 <= LoggingPerc` +- `tracking_id` is non-empty + +Reference: `maybeLogInferenceResponse` in `predict_handler.go`. 
+ +### Logging Format + +V2 format is selected using: + +- `config.GetModelConfigMap().ServiceConfig.V2LoggingType` + +Supported format values: + +- `proto` +- `arrow` +- `parquet` + +### Logged Message Shape + +Logged payload uses `InferflowLog` (`inferflow_logging.proto`): + +- `user_id` +- `tracking_id` +- `model_config_id` +- `entities` +- `features` (`PerEntityFeatures.encoded_features`) +- `metadata` +- `parent_entity` + +### What Data Is Logged Today + +Current logging functions (`logInferflowResponseBytes`, `logInferflowResponseArrow`, `logInferflowResponseParquet`) read from: + +- `compRequest.ComponentData` (target matrix) + +They do not currently read from: + +- `compRequest.SlateData` (slate matrix) + +So for PairWise/SlateWise requests, logging currently captures target-matrix features, not slate-matrix rows. + +### Metadata and Transport + +- Metadata byte packs: + - compression-enabled bit + - cache version + - format type (proto/arrow/parquet) +- Logs are sent to Kafka through prism logger (`PublishInferenceInsightsLog`). +- Event name used: `inferflow_inference_logs`. + +### Batching + +- Logs are batched before Kafka publish. +- Batch size uses `ResponseConfig.LogBatchSize`, default `500`. + +## Notes for Future Slate Logging + +If slate logging is required, the cleanest approach is to add a parallel slate logging path that: + +- reads from `compRequest.SlateData` +- builds slate-oriented encoded payloads +- publishes with same `tracking_id` and `model_config_id` for correlation + +This avoids breaking existing target-log consumers. 
\ No newline at end of file diff --git a/inferflow/cmd/inferflow/Dockerfile b/inferflow/cmd/inferflow/Dockerfile index e873e8c9..9aa15839 100644 --- a/inferflow/cmd/inferflow/Dockerfile +++ b/inferflow/cmd/inferflow/Dockerfile @@ -1,5 +1,5 @@ # Stage 1: Build the Go binary -FROM golang:1.24.4-bullseye AS builder +FROM golang:1.24.9-bookworm AS builder ARG TARGETOS ARG TARGETARCH diff --git a/inferflow/go.mod b/inferflow/go.mod index 3bac3746..a63d7d6d 100644 --- a/inferflow/go.mod +++ b/inferflow/go.mod @@ -1,17 +1,20 @@ module github.com/Meesho/BharatMLStack/inferflow -go 1.24.4 +go 1.24.9 require ( github.com/DataDog/datadog-go/v5 v5.5.0 github.com/Meesho/BharatMLStack/helix-client v1.0.0-alpha-649f16 + github.com/apache/arrow/go/v16 v16.1.0 github.com/cockroachdb/cmux v0.0.0-20170110192607-30d10be49292 github.com/coocood/freecache v1.2.4 github.com/dgraph-io/ristretto v0.2.0 github.com/emirpasic/gods v1.18.1 github.com/h2so5/half v1.0.0 github.com/knadh/koanf v1.5.0 + github.com/parquet-go/parquet-go v0.27.0 github.com/rs/zerolog v1.34.0 + github.com/segmentio/kafka-go v0.4.50 github.com/spaolacci/murmur3 v1.1.0 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.11.1 @@ -23,23 +26,32 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/klauspost/compress v1.17.9 // 
indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/parquet-go/bitpack v1.0.0 // indirect + github.com/parquet-go/jsonlite v1.0.0 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect @@ -50,15 +62,21 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/twpayne/go-geom v1.6.1 // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect go.etcd.io/etcd/api/v3 v3.5.17 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.17 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect + golang.org/x/mod v0.26.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.28.0 // indirect + golang.org/x/tools v0.35.0 // indirect + golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect google.golang.org/grpc/examples v0.0.0-20230318005552-70c52915099a // indirect diff --git a/inferflow/go.sum b/inferflow/go.sum index 9cf3c4b6..aecb62b6 100644 --- 
a/inferflow/go.sum +++ b/inferflow/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DataDog/datadog-go/v5 v5.5.0 h1:G5KHeB8pWBNXT4Jtw0zAkhdxEAWSpWH00geHI6LDrKU= github.com/DataDog/datadog-go/v5 v5.5.0/go.mod h1:K9kcYBlxkcPP8tvvjZZKs/m1edNAUFzBbdpTUKfCsuw= github.com/Meesho/BharatMLStack/helix-client v1.0.0-alpha-649f16 h1:fj2wKwLi5fL6fgG+oOI/J1nixQqMstVx36xAI4ddu9c= @@ -8,12 +10,20 @@ github.com/Meesho/BharatMLStack/helix-client v1.0.0-alpha-649f16/go.mod h1:Pw9Ry github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY= +github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= 
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/apache/arrow/go/v16 v16.1.0 h1:dwgfOya6s03CzH9JrjCBx6bkVb4yPD4ma3haj9p7FXI= +github.com/apache/arrow/go/v16 v16.1.0/go.mod h1:9wnc9mn6vEDTRIm4+27pEjQpRKuTvBaessPoEXQzxWA= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -91,6 +101,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-test/deep v1.0.2-0.20181118220953-042da051cf31/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -116,6 +128,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.1/go.mod 
h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= +github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -169,6 +183,8 @@ github.com/hashicorp/vault/api v1.0.4/go.mod h1:gDcqh3WGcR1cpF5AJz/B1UFheUEneMoI github.com/hashicorp/vault/sdk v0.1.13/go.mod h1:B+hVj7TpuQY1Y/GPbCpffmgd+tSEwvhkWnjtSYCaS2M= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hjson/hjson-go/v4 v4.0.0 h1:wlm6IYYqHjOdXH1gHev4VoXCaW20HdQAGCxdOEEg2cs= github.com/hjson/hjson-go/v4 v4.0.0/go.mod h1:KaYt3bTw3zhBjYqnXkYywcYctk0A2nxeEFTse3rH13E= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= @@ -183,6 +199,10 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= 
+github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/knadh/koanf v1.5.0 h1:q2TSd/3Pyc/5yP9ldIrSdIz26MCcyNQzW0pEAugLPNs= github.com/knadh/koanf v1.5.0/go.mod h1:Hgyjp4y8v44hpZtPzs7JZfRAW5AhN7KfZcwv1RYggDs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -240,6 +260,12 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/npillmayer/nestext v0.1.3/go.mod h1:h2lrijH8jpicr25dFY+oAJLyzlya6jhnuG+zWp9L0Uk= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA= +github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs= +github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU= +github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0= +github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g= +github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= @@ -248,6 +274,8 @@ github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCko github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= 
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= +github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -287,6 +315,8 @@ github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgY github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/segmentio/kafka-go v0.4.50 h1:mcyC3tT5WeyWzrFbd6O374t+hmcu1NKt2Pu1L3QaXmc= +github.com/segmentio/kafka-go v0.4.50/go.mod h1:Y1gn60kzLEEaW28YshXyk2+VCUKbJ3Qr6DrnT3i4+9E= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= @@ -321,11 +351,25 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4= +github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028= 
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.etcd.io/etcd/api/v3 v3.5.4/go.mod h1:5GB2vv4A4AOn3yk7MftYGHkUfGtDHnEraIjym4dYz5A= go.etcd.io/etcd/api/v3 v3.5.17 h1:cQB8eb8bxwuxOilBpMJAEo8fAONyrdXTHUNcMd8yT1w= go.etcd.io/etcd/api/v3 v3.5.17/go.mod h1:d1hvkRuXkts6PmaYk2Vrgqbv7H4ADfAKhyJqHNLJCB4= @@ -375,6 +419,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod 
h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -406,6 +452,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -439,8 +487,8 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.38.0 
h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -464,10 +512,14 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/inferflow/handlers/components/constants.go b/inferflow/handlers/components/constants.go index c2192603..243fc733 100644 --- 
a/inferflow/handlers/components/constants.go +++ b/inferflow/handlers/components/constants.go @@ -1,11 +1,12 @@ package components const ( - DataTypeString = "string" - errorType = "error-type" - compConfigErr = "invalid-component-config" - modelId = "model-id" - component = "component-name" - defaultCacheVersion = 1 - defaultCacheTtl = 100 + DataTypeString = "string" + errorType = "error-type" + compConfigErr = "invalid-component-config" + modelId = "model-id" + component = "component-name" + defaultCacheVersion = 1 + defaultCacheTtl = 100 + SlateTargetIndicesColumn = "slate_target_indices" ) diff --git a/inferflow/handlers/components/feature_init_component.go b/inferflow/handlers/components/feature_init_component.go index b85e40c9..a8a40950 100644 --- a/inferflow/handlers/components/feature_init_component.go +++ b/inferflow/handlers/components/feature_init_component.go @@ -42,6 +42,7 @@ func (fiComponent *FeatureInitComponent) Run(request interface{}) { componentMatrix := req.ComponentData componentMatrix.InitComponentMatrix(rowCount, stringColumnIndexMap, byteColumnIndexMap) populateStringData(componentMatrix, req) + initSlateMatrix(req) metrics.Timing("inferflow.component.execution.latency", time.Since(startTime), metricTags) } @@ -259,6 +260,9 @@ func extractPredatorColumns(byteColumnIndexMap map[string]matrix.Column, compCon if !ok { continue } + if predatorComp.SlateComponent { + continue + } for _, output := range predatorComp.Outputs { for scoreIdx, modelScore := range output.ModelScores { dataType := output.DataType @@ -284,6 +288,10 @@ func extractNumerixColumns(byteColumnIndexMap map[string]matrix.Column, compConf for _, comp := range compConfig.NumerixComponentConfig.Values() { numerixComp, ok := comp.(config.NumerixComponentConfig) if ok { + + if numerixComp.SlateComponent { + continue + } byteColumnIndexMap[numerixComp.ScoreColumn] = matrix.Column{ Name: numerixComp.ScoreColumn, DataType: numerixComp.DataType, @@ -312,3 +320,136 @@ func 
extractEntityColumns(stringColumnIndexMap map[string]matrix.Column, req Com stringColumnIndexMap[entity] = matrix.Column{Name: entity, DataType: DataTypeString, Index: idx} } } + +func initSlateMatrix(req ComponentRequest) { + if req.SlateData == nil { + return + } + slateRowCount := len(req.SlateData.Rows) + if slateRowCount == 0 { + return + } + + // Preserve adapter-populated slate string features before matrix re-init. + // InitComponentMatrix reallocates rows, so we need to restore these columns. + preservedSlateStrings := make(map[string][]string, len(req.SlateData.StringColumnIndexMap)) + for colName, col := range req.SlateData.StringColumnIndexMap { + values := make([]string, slateRowCount) + for i := 0; i < slateRowCount && i < len(req.SlateData.Rows); i++ { + if col.Index < len(req.SlateData.Rows[i].StringData) { + values[i] = req.SlateData.Rows[i].StringData[col.Index] + } + } + preservedSlateStrings[colName] = values + } + // Preserve adapter-populated slate byte features as well. + preservedSlateBytes := make(map[string][][]byte, len(req.SlateData.ByteColumnIndexMap)) + for colName, col := range req.SlateData.ByteColumnIndexMap { + values := make([][]byte, slateRowCount) + for i := 0; i < slateRowCount && i < len(req.SlateData.Rows); i++ { + if col.Index < len(req.SlateData.Rows[i].ByteData) { + values[i] = req.SlateData.Rows[i].ByteData[col.Index] + } + } + preservedSlateBytes[colName] = values + } + + slateByteColumns := buildSlateByteDataSchema(req.ComponentConfig, req.SlateData) + slateStringColumns := buildSlateStringDataSchema(req.SlateData) + + req.SlateData.InitComponentMatrix(slateRowCount, slateStringColumns, slateByteColumns) + + // Restore adapter-populated slate string features (including slate_target_indices and + // any additional slate features passed in the request). 
+ for colName, values := range preservedSlateStrings { + req.SlateData.PopulateStringData(colName, values) + } + for colName, values := range preservedSlateBytes { + req.SlateData.PopulateByteData(colName, values) + } +} + +// buildSlateByteDataSchema creates byte columns for slate-level predator and iris outputs. +func buildSlateByteDataSchema(compConfig *config.ComponentConfig, slateData *matrix.ComponentMatrix) map[string]matrix.Column { + byteColumnIndexMap := make(map[string]matrix.Column) + index := 0 + // Preserve adapter-provided slate byte feature columns first. + if slateData != nil { + for colName, col := range slateData.ByteColumnIndexMap { + byteColumnIndexMap[colName] = matrix.Column{ + Name: colName, + DataType: col.DataType, + Index: index, + } + index++ + } + } + + // Slate predator output columns + for _, comp := range compConfig.PredatorComponentConfig.Values() { + predatorComp, ok := comp.(config.PredatorComponentConfig) + if !ok || !predatorComp.SlateComponent { + continue + } + for _, output := range predatorComp.Outputs { + for scoreIdx, modelScore := range output.ModelScores { + dataType := output.DataType + if len(output.ModelScoresDims) > scoreIdx && len(output.ModelScoresDims[scoreIdx]) > 0 { + dataType = extactPredatorOutputDataType(&output, output.ModelScoresDims[scoreIdx]) + } + if _, exists := byteColumnIndexMap[modelScore]; exists { + continue + } + byteColumnIndexMap[modelScore] = matrix.Column{ + Name: modelScore, + DataType: dataType, + Index: index, + } + index++ + } + } + } + + // Slate numerix output columns + for _, comp := range compConfig.NumerixComponentConfig.Values() { + irisComp, ok := comp.(config.NumerixComponentConfig) + if !ok || !irisComp.SlateComponent { + continue + } + if _, exists := byteColumnIndexMap[irisComp.ScoreColumn]; exists { + continue + } + byteColumnIndexMap[irisComp.ScoreColumn] = matrix.Column{ + Name: irisComp.ScoreColumn, + DataType: irisComp.DataType, + Index: index, + } + index++ + } + + return 
byteColumnIndexMap +} + +// buildSlateStringDataSchema creates string columns for SlateData. +// It always includes slate_target_indices and preserves any adapter-provided slate feature columns. +func buildSlateStringDataSchema(slateData *matrix.ComponentMatrix) map[string]matrix.Column { + stringColumnIndexMap := make(map[string]matrix.Column, len(slateData.StringColumnIndexMap)+1) + stringColumnIndexMap[SlateTargetIndicesColumn] = matrix.Column{ + Name: SlateTargetIndicesColumn, + DataType: DataTypeString, + Index: 0, + } + index := 1 + for colName := range slateData.StringColumnIndexMap { + if colName == SlateTargetIndicesColumn { + continue + } + stringColumnIndexMap[colName] = matrix.Column{ + Name: colName, + DataType: DataTypeString, + Index: index, + } + index++ + } + return stringColumnIndexMap +} diff --git a/inferflow/handlers/components/models.go b/inferflow/handlers/components/models.go index 5244f9f4..325c1572 100644 --- a/inferflow/handlers/components/models.go +++ b/inferflow/handlers/components/models.go @@ -7,6 +7,7 @@ import ( type ComponentRequest struct { ComponentData *matrix.ComponentMatrix + SlateData *matrix.ComponentMatrix // Slate-level matrix (one row per slate); nil when no slate components Entities *[]string EntityIds *[][]string ComponentConfig *config.ComponentConfig diff --git a/inferflow/handlers/components/numerix_component.go b/inferflow/handlers/components/numerix_component.go index b9eac888..fc1ddd4e 100644 --- a/inferflow/handlers/components/numerix_component.go +++ b/inferflow/handlers/components/numerix_component.go @@ -4,6 +4,9 @@ import ( "fmt" "time" + "strconv" + "strings" + "github.com/Meesho/BharatMLStack/inferflow/handlers/models" "github.com/Meesho/BharatMLStack/inferflow/pkg/matrix" @@ -54,6 +57,12 @@ func (iComponent *NumerixComponent) Run(request interface{}) { return } + if iConfig.SlateComponent && componentRequest.SlateData != nil { + iComponent.runSlate(componentRequest, iConfig, &matrixUtil, 
errLoggingPercent, metricTags) + metrics.Timing("inferflow.component.execution.latency", time.Duration(time.Since(t)), metricTags) + return + } + numerixComponentBuilder := &models.NumerixComponentBuilder{} initializeNumerixComponentBuilder(numerixComponentBuilder, &iConfig, &matrixUtil) @@ -72,6 +81,88 @@ func (iComponent *NumerixComponent) Run(request interface{}) { } } +// runSlate handles the slate-aware iris path. +// For each slate it makes a **separate inference request**: +// 1. Read slate_target_indices from SlateData (comma-separated row indices into the target matrix). +// 2. Build a per-slate matrix view containing only that slate's target rows. +// 3. Run the standard iris flow (init builder → populate matrix → score). +// 4. Write the per-slate score into the corresponding SlateData row. +func (iComponent *NumerixComponent) runSlate( + req ComponentRequest, + iConfig config.NumerixComponentConfig, + targetMatrix *matrix.ComponentMatrix, + errLoggingPercent int, + metricTags []string, +) { + slateMatrix := req.SlateData + slateRows := slateMatrix.Rows + numSlates := len(slateRows) + if numSlates == 0 { + return + } + + indicesCol, ok := slateMatrix.StringColumnIndexMap[SlateTargetIndicesColumn] + if !ok { + logger.Error("slate_target_indices column not found in SlateData", nil) + metrics.Count("inferflow.component.execution.error", 1, append(metricTags, errorType, compConfigErr)) + return + } + + // Accumulate one score per slate + slateScores := make([][]byte, numSlates) + + // Process each slate independently + for s, slateRow := range slateRows { + // Parse target indices for this slate + idxStr := slateRow.StringData[indicesCol.Index] + parts := strings.Split(idxStr, ",") + targetRows := make([]matrix.Row, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + idx, err := strconv.Atoi(p) + if err != nil || idx < 0 || idx >= len(targetMatrix.Rows) { + continue + } + targetRows = append(targetRows, 
targetMatrix.Rows[idx]) + } + if len(targetRows) == 0 { + continue + } + + // Per-slate matrix view (reuses target column maps, only this slate's target rows) + perSlateMatrix := &matrix.ComponentMatrix{ + StringColumnIndexMap: targetMatrix.StringColumnIndexMap, + ByteColumnIndexMap: targetMatrix.ByteColumnIndexMap, + Rows: targetRows, + } + + // Standard iris flow on the per-slate matrix + builder := &models.NumerixComponentBuilder{} + initializeNumerixComponentBuilder(builder, &iConfig, perSlateMatrix) + + builder.MatrixColumns = append(builder.MatrixColumns, iConfig.ComponentId) + builder.Schema = append(builder.Schema, iConfig.ComponentId) + for key, value := range iConfig.ScoreMapping { + builder.Schema = append(builder.Schema, key) + builder.MatrixColumns = append(builder.MatrixColumns, value) + } + + perSlateMatrix.PopulateMatrixOfColumnSlice(builder) + iComponent.populateScoreMap(req.ModelId, iConfig, builder.Schema, builder, errLoggingPercent) + + // Take the first output row as the slate-level score + if len(builder.Scores) > 0 { + slateScores[s] = builder.Scores[0] + } + } + + // Write accumulated slate scores to SlateData + slateMatrix.PopulateByteData(iConfig.ScoreColumn, slateScores) +} + func initializeNumerixComponentBuilder(numerixComponentBuilder *models.NumerixComponentBuilder, iConfig *config.NumerixComponentConfig, matrixUtil *matrix.ComponentMatrix) { numRow := len(matrixUtil.Rows) numCols := len(iConfig.ScoreMapping) + 1 diff --git a/inferflow/handlers/components/predator_component.go b/inferflow/handlers/components/predator_component.go index 0355d8a8..503c98a2 100644 --- a/inferflow/handlers/components/predator_component.go +++ b/inferflow/handlers/components/predator_component.go @@ -2,6 +2,7 @@ package components import ( "fmt" + "strconv" "strings" "sync" "sync/atomic" @@ -63,6 +64,13 @@ func (pComponent *PredatorComponent) Run(request interface{}) { logger.Error(fmt.Sprintf("Invalid component request for model-id %s and component %s ", 
// runSlate handles the slate-aware predator path.
//
// For each slate:
//  1. Parse slate_target_indices → pick target rows from the filled target matrix.
//  2. Build a per-slate matrix view with those rows.
//  3. Same predator flow: init builder → gather features → call predator.
//  4. Model returns 1 score for N inputs → write that score directly
//     into SlateData row [s], score column.
//
// Concurrency: one goroutine per slate. Each goroutine writes only its own
// slateMatrix.Rows[s], so the writes do not race with each other; targetMatrix
// and the column maps are only read. NOTE(review): this assumes no other
// component mutates the target matrix concurrently — confirm with the DAG
// executor's scheduling guarantees.
func (pComponent *PredatorComponent) runSlate(
	req ComponentRequest,
	pConfig config.PredatorComponentConfig,
	targetMatrix *matrix.ComponentMatrix,
	errLoggingPercent int,
	metricTags []string,
) {
	slateMatrix := req.SlateData
	numSlates := len(slateMatrix.Rows)
	if numSlates == 0 {
		return
	}

	// The indices column is mandatory: without it no slate can be mapped to
	// its target rows, so this is treated as a config error.
	indicesCol, ok := slateMatrix.StringColumnIndexMap[SlateTargetIndicesColumn]
	if !ok {
		logger.Error(fmt.Sprintf("slate_target_indices column not found in SlateData for component %s", pComponent.GetComponentName()), nil)
		metrics.Count("inferflow.component.execution.error", 1, append(metricTags, errorType, compConfigErr))
		return
	}

	// --- Process each slate (parallel) ---
	var wg sync.WaitGroup
	for s := 0; s < numSlates; s++ {
		wg.Add(1)
		// s is passed as an argument so each goroutine captures its own value
		// (required before Go 1.22 per-iteration loop variables).
		go func(s int) {
			defer wg.Done()
			// A panic in one slate must not take down the whole request; log
			// and let the remaining slates finish.
			defer func() {
				if rec := recover(); rec != nil {
					logger.Error(fmt.Sprintf("panic in slate %d for model-id %s component %s: %v",
						s, req.ModelId, pComponent.GetComponentName(), rec), nil)
				}
			}()

			// 1. Parse target indices for this slate
			targetRows := parseSlateTargetRows(slateMatrix.Rows[s], indicesCol.Index, targetMatrix)
			if len(targetRows) == 0 {
				// Empty or fully-invalid index list: leave this slate unscored.
				return
			}

			// 2. Build per-slate matrix view (same columns, subset of rows)
			perSlateMatrix := &matrix.ComponentMatrix{
				StringColumnIndexMap: targetMatrix.StringColumnIndexMap,
				ByteColumnIndexMap:   targetMatrix.ByteColumnIndexMap,
				Rows:                 targetRows,
			}

			// 3. Same predator flow: init → gather features → call predator
			builder := &models.PredatorComponentBuilder{}
			initializePredatorComponentBuilder(builder, &pConfig, perSlateMatrix)
			GetMatrixOfColumnSliceWithDataType(&pConfig, perSlateMatrix, builder, metricTags)
			pComponent.populateScores(req.ModelId, pConfig, builder, errLoggingPercent)

			// 4. Write the score directly into SlateData row [s], for each score column.
			//    Model returns 1 score (builder.Scores[i][0]) for this slate.
			//    counter walks Outputs × ModelScores in declaration order and is
			//    assumed to line up 1:1 with builder.Scores — TODO confirm that
			//    populateScores preserves this ordering.
			counter := 0

			for _, out := range pConfig.Outputs {
				for _, scoreName := range out.ModelScores {
					if counter < len(builder.Scores) && len(builder.Scores[counter]) > 0 {
						// Only write scores that have a declared slate byte column.
						if col, ok := slateMatrix.ByteColumnIndexMap[scoreName]; ok {
							slateMatrix.Rows[s].ByteData[col.Index] = builder.Scores[counter][0]
						}
					}
					counter++
				}
			}
		}(s)
	}
	wg.Wait()
}
+func parseSlateTargetRows(slateRow matrix.Row, indicesColIdx int, targetMatrix *matrix.ComponentMatrix) []matrix.Row { + idxStr := slateRow.StringData[indicesColIdx] + parts := strings.Split(idxStr, ",") + rows := make([]matrix.Row, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + idx, err := strconv.Atoi(p) + if err != nil || idx < 0 || idx >= len(targetMatrix.Rows) { + continue + } + rows = append(rows, targetMatrix.Rows[idx]) + } + return rows +} + func GetMatrixOfColumnSliceWithDataType( pConfig *config.PredatorComponentConfig, matrixUtil *matrix.ComponentMatrix, diff --git a/inferflow/handlers/config/config_schema.go b/inferflow/handlers/config/config_schema.go new file mode 100644 index 00000000..c0d2ac83 --- /dev/null +++ b/inferflow/handlers/config/config_schema.go @@ -0,0 +1,289 @@ +package config + +import ( + "fmt" + "strings" + "sync" + + "github.com/Meesho/BharatMLStack/inferflow/internal/errors" + "github.com/Meesho/BharatMLStack/inferflow/pkg/logger" + "github.com/Meesho/BharatMLStack/inferflow/pkg/metrics" + "github.com/emirpasic/gods/maps/linkedhashmap" +) + +var ( + featureSchemaCache = make(map[string]*linkedhashmap.Map) + schemaCacheMu sync.RWMutex +) + +// InitAllFeatureSchema builds and caches feature schemas for all model configs. +// Should be called during config initialization and on config reload. 
+func InitAllFeatureSchema(modelConfig *ModelConfig) error { + if modelConfig == nil { + return &errors.ParsingError{ErrorMsg: "modelConfig is nil"} + } + if len(modelConfig.ConfigMap) == 0 { + return &errors.ParsingError{ErrorMsg: "modelConfig.ConfigMap is empty"} + } + + schemaCacheMu.Lock() + defer schemaCacheMu.Unlock() + + newCache := make(map[string]*linkedhashmap.Map) + for modelConfigID, config := range modelConfig.ConfigMap { + logger.Info(fmt.Sprintf("building feature schema for modelConfigID: %s", modelConfigID)) + schema, err := buildFeatureSchema(&config) + if err != nil { + metrics.Count("inferflow_config_schema_parsing_error", 1, []string{"model-id", modelConfigID}) + return &errors.ParsingError{ErrorMsg: "failed to build feature schema for modelConfigID " + modelConfigID + ": " + err.Error()} + } + newCache[modelConfigID] = schema + } + featureSchemaCache = newCache + return nil +} + +// GetFeatureSchema returns the cached feature schema for a given model config ID. +func GetFeatureSchema(modelConfigID string) (*linkedhashmap.Map, error) { + schemaCacheMu.RLock() + defer schemaCacheMu.RUnlock() + + schema, exists := featureSchemaCache[modelConfigID] + if !exists { + return nil, &errors.RequestError{ErrorMsg: "feature schema not found for modelConfigID: " + modelConfigID} + } + return schema, nil +} + +func buildFeatureSchema(config *Config) (*linkedhashmap.Map, error) { + if config == nil { + return nil, &errors.ParsingError{ErrorMsg: "config is nil"} + } + + existingFeatures := make(map[string]bool) + var response []SchemaComponents + + addUniqueComponents := func(components []SchemaComponents) { + for _, component := range components { + if !existingFeatures[component.FeatureName] { + response = append(response, component) + existingFeatures[component.FeatureName] = true + } + } + } + + addOrUpdateComponents := func(components []SchemaComponents) { + for _, component := range components { + if !existingFeatures[component.FeatureName] { + 
component.FeatureType = "String" + response = append(response, component) + existingFeatures[component.FeatureName] = true + } + } + } + + fsComponents := processLinkedHashMap[FeatureComponentConfig](config.ComponentConfig.FeatureComponentConfig) + predatorComponents := processLinkedHashMap[PredatorComponentConfig](config.ComponentConfig.PredatorComponentConfig) + numerixComponents := processLinkedHashMap[NumerixComponentConfig](config.ComponentConfig.NumerixComponentConfig) + + // 1. Feature Store components + addUniqueComponents(processFS(fsComponents)) + + // 2. Numerix Output + addUniqueComponents(processNumerixOutput(numerixComponents)) + + // 3. Predator Output + addUniqueComponents(processPredatorOutput(predatorComponents)) + + // 4. Numerix Input (only add if not already present) + addOrUpdateComponents(processNumerixInput(numerixComponents)) + + // 5. Predator Input (only add if not already present) + addOrUpdateComponents(processPredatorInput(predatorComponents)) + + responseSchemaComponents := processResponseConfig(config.ResponseConfig, response) + + result := linkedhashmap.New() + if config.ResponseConfig.LogFeatures { + for _, component := range responseSchemaComponents { + result.Put(component.FeatureName, component) + } + return result, nil + } + for _, component := range response { + result.Put(component.FeatureName, component) + } + return result, nil +} + +func processLinkedHashMap[T any](linkedHashMap linkedhashmap.Map) []T { + var response []T + for _, item := range linkedHashMap.Values() { + typedItem, ok := item.(T) + if ok { + response = append(response, typedItem) + } + } + return response +} + +func processNumerixOutput(numerixComponents []NumerixComponentConfig) []SchemaComponents { + if len(numerixComponents) == 0 { + return nil + } + + var response []SchemaComponents + for _, comp := range numerixComponents { + response = append(response, SchemaComponents{ + FeatureName: comp.ScoreColumn, + FeatureType: comp.DataType, + FeatureSize: 1, 
// getFeatureName builds a fully-qualified feature name: an optional prefix
// (used verbatim, no separator), optional "entityLabel:" and "fgLabel:"
// segments, then the bare feature name.
func getFeatureName(prefix, entityLabel, fgLabel, feature string) string {
	name := prefix
	if entityLabel != "" {
		name += entityLabel + ":"
	}
	if fgLabel != "" {
		name += fgLabel + ":"
	}
	return name + feature
}
// getPredatorFeatureTypeAndSize maps a tensor shape to a (size, logical type) pair:
//   - [1] or an empty shape keeps the base data type with size 1;
//   - [-1, n] (variable batch) becomes an n-wide "<type>Vector";
//   - any other non-empty shape becomes a "<type>Vector" sized by its first dim.
func getPredatorFeatureTypeAndSize(dataType string, shape []int) (int, string) {
	switch {
	case len(shape) == 1 && shape[0] == 1:
		return 1, dataType
	case len(shape) == 2 && shape[0] == -1:
		return shape[1], dataType + "Vector"
	case len(shape) > 0:
		return shape[0], dataType + "Vector"
	default:
		return 1, dataType
	}
}
a/inferflow/handlers/config/model_config.go b/inferflow/handlers/config/model_config.go index c67fabda..acf4528e 100644 --- a/inferflow/handlers/config/model_config.go +++ b/inferflow/handlers/config/model_config.go @@ -1,11 +1,6 @@ package config -import ( - "fmt" - - "github.com/Meesho/BharatMLStack/inferflow/internal/errors" - "github.com/Meesho/BharatMLStack/inferflow/pkg/logger" -) +import "fmt" var mConfig *ModelConfig @@ -13,24 +8,20 @@ func GetModelConfigMap() *ModelConfig { return mConfig } -func SetModelConfigMap(config *ModelConfig) { - mConfig = config -} - -func GetModelConfig(modelId string) (*Config, error) { - - configMap := GetModelConfigMap() - if len(configMap.ConfigMap) == 0 { - logger.Error("Error while fetching Inferflow config ", nil) - return nil, &errors.RequestError{ErrorMsg: "Error while fetching Inferflow config"} +// GetModelConfig returns the Config for a specific model config ID. +func GetModelConfig(modelConfigId string) (*Config, error) { + if mConfig == nil { + return nil, fmt.Errorf("model config map not initialised") } - config := configMap.ConfigMap[modelId] - isValid := validateModelConfig(&config) - if !isValid { - logger.Error(fmt.Sprintf("Invalid model config for modelId %s ", modelId), nil) - return nil, &errors.RequestError{ErrorMsg: fmt.Sprintf("Invalid model config for modelId %s ", modelId)} + conf, ok := mConfig.ConfigMap[modelConfigId] + if !ok || !validateModelConfig(&conf) { + return nil, fmt.Errorf("model config not found or invalid for id: %s", modelConfigId) } - return &config, nil + return &conf, nil +} + +func SetModelConfigMap(config *ModelConfig) { + mConfig = config } func validateModelConfig(c *Config) bool { diff --git a/inferflow/handlers/config/models.go b/inferflow/handlers/config/models.go index 2885e8d2..fee75abd 100644 --- a/inferflow/handlers/config/models.go +++ b/inferflow/handlers/config/models.go @@ -9,7 +9,19 @@ import ( ) type ModelConfig struct { - ConfigMap map[string]Config 
`json:"model_config_map"` + ConfigMap map[string]Config `json:"model_config_map"` + ServiceConfig ServiceConfig `json:"service-config"` +} + +type ServiceConfig struct { + V2LoggingType string `json:"v2-logging-type"` + CompressionEnabled bool `json:"compression-enabled"` +} + +type SchemaComponents struct { + FeatureName string `json:"feature_name"` + FeatureType string `json:"feature_type"` + FeatureSize any `json:"feature_size"` } type Config struct { @@ -79,6 +91,7 @@ type PredatorComponentConfig struct { Calibration string `json:"calibration"` Inputs []ModelInput `json:"inputs"` Outputs []ModelOutput `json:"outputs"` + SlateComponent bool `json:"slate_component"` // When true, outputs go to SlateData and inputs are gathered per-slate from the target matrix } type ModelEndpoint struct { @@ -101,12 +114,13 @@ type ModelOutput struct { } type NumerixComponentConfig struct { - Component string `json:"component"` - ComponentId string `json:"component_id"` - ScoreColumn string `json:"score_col"` - DataType string `json:"data_type"` - ScoreMapping map[string]string `json:"score_mapping"` - ComputeId string `json:"compute_id"` + Component string `json:"component"` + ComponentId string `json:"component_id"` + ScoreColumn string `json:"score_col"` + DataType string `json:"data_type"` + ScoreMapping map[string]string `json:"score_mapping"` + ComputeId string `json:"compute_id"` + SlateComponent bool `json:"slate_component"` // When true, outputs go to SlateData and inputs are gathered per-slate from the target matrix } func (c *ComponentConfig) UnmarshalJSON(data []byte) error { diff --git a/inferflow/handlers/external/kafka/kafka_logger.go b/inferflow/handlers/external/kafka/kafka_logger.go new file mode 100644 index 00000000..7a0ee4f2 --- /dev/null +++ b/inferflow/handlers/external/kafka/kafka_logger.go @@ -0,0 +1,82 @@ +package kafka + +import ( + "context" + "fmt" + "time" + + "github.com/Meesho/BharatMLStack/inferflow/pkg/configs" + 
"github.com/Meesho/BharatMLStack/inferflow/pkg/logger" + "github.com/Meesho/BharatMLStack/inferflow/pkg/metrics" + kafka "github.com/segmentio/kafka-go" + "google.golang.org/protobuf/proto" +) + +var ( + kafkaWriter *kafka.Writer +) + +func getMetricTags(metricTags []string, errType string) []string { + metricTags = append(metricTags, "error-type:"+errType) + return metricTags +} + +// InitKafkaLogger initializes Kafka writers for inference logging. +func InitKafkaLogger(appConfigs *configs.AppConfigs) { + bootstrapServers := appConfigs.Configs.KafkaBootstrapServers + if bootstrapServers == "" { + logger.Info("Kafka bootstrap servers not configured, inference logging disabled") + return + } + + if topic := appConfigs.Configs.KafkaLoggingTopic; topic != "" { + kafkaWriter = &kafka.Writer{ + Addr: kafka.TCP(bootstrapServers), + Topic: topic, + Balancer: &kafka.LeastBytes{}, + BatchTimeout: 10 * time.Millisecond, + Async: true, + } + logger.Info(fmt.Sprintf("Kafka V2 writer initialised for topic: %s", topic)) + } + + logger.Info("Kafka inference logger initialised") +} + +// publishInferenceInsightsLog sends inference insights (protobuf) to Kafka. 
// publishInferenceInsightsLog sends inference insights (protobuf) to Kafka.
// It is a no-op when InitKafkaLogger never created a writer (logging disabled).
func publishInferenceInsightsLog(msg proto.Message, modelId string) {
	if kafkaWriter == nil {
		return
	}
	if msg == nil {
		logger.Error("Empty proto message for V2 log", fmt.Errorf("model_id: %s", modelId))
		return
	}

	data, err := proto.Marshal(msg)
	// NOTE(review): tag style is mixed — "model-id", modelId as two list
	// elements here vs a single "error-type:<x>" element from getMetricTags;
	// confirm which format the metrics backend expects.
	metricTags := []string{"model-id", modelId}
	if err != nil {
		logger.Error("Error marshalling proto for V2 log:", err)
		metrics.Count("inferflow.logging.error", 1, getMetricTags(metricTags, PROTO_MARSHAL_ERR))
		return
	}

	// Writer is created with Async=true, so WriteMessages mostly enqueues;
	// NOTE(review): with Async=true delivery errors may not surface here —
	// confirm whether a completion callback is needed for accurate error counts.
	if err := kafkaWriter.WriteMessages(context.Background(), kafka.Message{Value: data}); err != nil {
		logger.Error("Error sending V2 log to Kafka:", err)
		metrics.Count("inferflow.logging.error", 1, getMetricTags(metricTags, KAFKA_V2_ERR))
		return
	}

	metrics.Count("inferflow.logging.kafka_sent", 1, metricTags)
}

// PublishInferenceInsightsLog is kept as an exported alias for cross-package callers.
// Being a var (not a func) also lets tests stub the publish path by reassignment.
var PublishInferenceInsightsLog = publishInferenceInsightsLog

// CloseKafkaLogger closes the Kafka writer, if one was created, flushing
// pending async messages.
func CloseKafkaLogger() {
	if kafkaWriter != nil {
		if err := kafkaWriter.Close(); err != nil {
			logger.Error("Error closing Kafka V2 writer:", err)
		}
	}
}
// Format type constants for the metadata byte (bits 6-7).
const (
	FormatTypeProto   = 0 // 00
	FormatTypeArrow   = 1 // 01
	FormatTypeParquet = 2 // 10
)

// parquetRow represents a single logged row in Parquet format, keyed by the
// feature's position in the schema.
type parquetRow struct {
	Features map[int][]byte
}

// --- Metadata packing ---

// packMetadataByte packs compression flag, version, and format type into a single byte.
//
// Bit layout:
//
//	bit  0   : compression (0=off, 1=on)
//	bits 2-5 : version (0-15)
//	bits 6-7 : format (00=proto, 01=arrow, 10=parquet)
//
// version and formatType are masked to their field widths, so an out-of-range
// value can never spill into (and corrupt) a neighbouring field.
func packMetadataByte(compressionEnabled bool, version int, formatType int) byte {
	var b byte
	if compressionEnabled {
		b |= 1 << 0
	}
	b |= byte(version&0x0F) << 2   // 4-bit version field
	b |= byte(formatType&0x03) << 6 // 2-bit format field
	return b
}
+func buildMPLogBase(ctx context.Context, userId, trackingId string, conf *config.Config, compRequest *components.ComponentRequest, formatType int) *pb.InferflowLog { + svcConfig := config.GetModelConfigMap().ServiceConfig + meta := packMetadataByte(svcConfig.CompressionEnabled, compRequest.ComponentConfig.CacheVersion, formatType) + parent := extractParentEntityFromHeaders(ctx) + + mpLog := &pb.InferflowLog{ + UserId: userId, + TrackingId: trackingId, + ModelConfigId: compRequest.ModelId, + Metadata: []byte{meta}, + } + if parent != "" { + mpLog.ParentEntity = []string{parent} + } + return mpLog +} + +// getBatchSize returns the batch size from config, defaulting to 500. +func getBatchSize(conf *config.Config) int { + if conf == nil || conf.ResponseConfig.LogBatchSize <= 0 { + return 500 + } + return conf.ResponseConfig.LogBatchSize +} + +// --- V2 Kafka transport --- + +// logInferenceBatchInsights wraps InferflowLog in KafkaRequest and sends to Kafka via V2 writer. +func logInferenceBatchInsights(mpLog *pb.InferflowLog, trackingId, modelId string) { + mpLogAny, err := anypb.New(mpLog) + if err != nil { + logger.Error("Error wrapping InferflowLog in Any", err) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", modelId, "error", "wrapping_mp_log"}) + return + } + + now := time.Now() + kafkaReq := &pb.KafkaRequest{ + Value: &pb.KafkaEventValue{ + EventName: "inferflow_inference_logs", + EventId: trackingId, + CreatedAt: now.Unix(), + Properties: mpLogAny, + UserId: mpLog.UserId, + }, + } + + kafkaLogger.PublishInferenceInsightsLog(kafkaReq, modelId) +} + +// logInferenceInsights splits InferflowLog entities into batches and sends each batch. 
+func logInferenceInsights(mpLog *pb.InferflowLog, conf *config.Config, compRequest *components.ComponentRequest, trackingId string) { + batchSize := getBatchSize(conf) + total := len(mpLog.Entities) + + if total == 0 { + return + } + if total <= batchSize { + logInferenceBatchInsights(mpLog, trackingId, compRequest.ModelId) + return + } + + for i := 0; i < total; i += batchSize { + end := i + batchSize + if end > total { + end = total + } + batch := &pb.InferflowLog{ + UserId: mpLog.UserId, + TrackingId: mpLog.TrackingId, + ModelConfigId: mpLog.ModelConfigId, + Metadata: mpLog.Metadata, + ParentEntity: mpLog.ParentEntity, + Entities: mpLog.Entities[i:end], + Features: mpLog.Features[i:end], + } + logInferenceBatchInsights(batch, trackingId, compRequest.ModelId) + } +} + +// --- Proto V2 logging --- + +// logInferflowResponseBytes encodes features using proto format and sends to Kafka. +func logInferflowResponseBytes(ctx context.Context, userId, trackingId string, conf *config.Config, compRequest *components.ComponentRequest) { + defer func() { + if r := recover(); r != nil { + log.Error().Msgf("Recovered from panic in logInferflowResponseBytes: %v", r) + } + }() + + featureSchema, err := config.GetFeatureSchema(compRequest.ModelId) + if err != nil { + log.Error().Msgf("Error getting feature schema for model %s: %v", compRequest.ModelId, err) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "schema_not_found"}) + return + } + + mpLog := buildMPLogBase(ctx, userId, trackingId, conf, compRequest, FormatTypeProto) + + featureKeys := featureSchema.Keys() + idColName := conf.ResponseConfig.Features[0] + col, ok := compRequest.ComponentData.StringColumnIndexMap[idColName] + if !ok { + log.Error().Msgf("Entity column not found for config: %s", compRequest.ModelId) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "id_column_not_found"}) + return + } + + encodingErrors := 0 + for 
_, row := range compRequest.ComponentData.Rows { + valid := true + encoder := utils.NewFeatureEncoder(len(featureKeys) * 10) + + for _, keyIface := range featureKeys { + featureName := keyIface.(string) + + infoIface, ok := featureSchema.Get(featureName) + if !ok { + log.Error().Msgf("Feature info missing in schema: %s", featureName) + valid = false + break + } + + info := infoIface.(config.SchemaComponents) + featureType := info.FeatureType + isBytesType := utils.IsBytesDataType(featureType) + + var dataType types.DataType + if !isBytesType { + dataType, err = utils.StringToDataType(featureType) + if err != nil { + valid = false + break + } + } + + var featureValue []byte + var found bool + + if byteCol, exists := compRequest.ComponentData.ByteColumnIndexMap[featureName]; exists { + featureValue = row.ByteData[byteCol.Index] + found = len(featureValue) > 0 + } else if strCol, exists := compRequest.ComponentData.StringColumnIndexMap[featureName]; exists { + strVal := row.StringData[strCol.Index] + if strVal != "" { + featureValue, err = utils.ConvertStringToType(strVal, dataType) + if err == nil { + found = true + } + } + } + + if !found { + if dataType.IsVector() { + featureValue = []byte{} + } else { + featureValue = utils.GetDefaultValueByType(featureType) + } + encoder.MarkValueAsGenerated() + } + + if isBytesType { + if _, err = encoder.AppendBytesFeature(featureValue); err != nil { + valid = false + break + } + } else if _, err = encoder.AppendFeature(dataType, featureValue); err != nil { + valid = false + break + } + } + + if valid { + mpLog.Entities = append(mpLog.Entities, row.StringData[col.Index]) + mpLog.Features = append(mpLog.Features, &pb.PerEntityFeatures{EncodedFeatures: encoder.Bytes()}) + } else { + encodingErrors++ + } + } + + if encodingErrors > 0 { + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "encoding_error"}) + } else { + logInferenceInsights(mpLog, conf, compRequest, trackingId) + } +} + 
+// --- Arrow V2 logging --- + +// logInferflowResponseArrow encodes features into Arrow IPC format and sends to Kafka. +func logInferflowResponseArrow(ctx context.Context, userId, trackingId string, conf *config.Config, compRequest *components.ComponentRequest) { + defer func() { + if r := recover(); r != nil { + log.Error().Msgf("Recovered from panic in logInferflowResponseArrow: %v", r) + } + }() + + idColName := conf.ResponseConfig.Features[0] + col, ok := compRequest.ComponentData.StringColumnIndexMap[idColName] + if !ok { + log.Error().Msgf("Entity column not found for config: %s", compRequest.ModelId) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "id_column_not_found"}) + return + } + + rec, _, entityIDs, err := buildColumnarRecord(compRequest, col.Index) + if err != nil { + logger.Error("Error building Arrow record", err) + return + } + if rec == nil { + return + } + defer rec.Release() + + batchSize := getBatchSize(conf) + totalRows := int(rec.NumRows()) + + for i := 0; i < totalRows; i += batchSize { + end := i + batchSize + if end > totalRows { + end = totalRows + } + + batchMPLog := buildMPLogBase(ctx, userId, trackingId, conf, compRequest, FormatTypeArrow) + batchRec := rec.NewSlice(int64(i), int64(end)) + defer batchRec.Release() + + var buf bytes.Buffer + w := ipc.NewWriter(&buf, ipc.WithSchema(batchRec.Schema())) + if err := w.Write(batchRec); err != nil { + logger.Error("Error writing Arrow record", err) + _ = w.Close() + continue + } + if err := w.Close(); err != nil { + logger.Error("Error closing Arrow IPC writer", err) + continue + } + + batchMPLog.Entities = entityIDs[i:end] + batchMPLog.Features = []*pb.PerEntityFeatures{{EncodedFeatures: buf.Bytes()}} + logInferenceBatchInsights(batchMPLog, trackingId, compRequest.ModelId) + } +} + +// buildColumnarRecord constructs an Arrow record with one binary column per feature. 
+func buildColumnarRecord(compRequest *components.ComponentRequest, entityColIndex int) (arrow.Record, time.Duration, []string, error) { + start := time.Now() + + featureSchema, err := config.GetFeatureSchema(compRequest.ModelId) + if err != nil { + logger.Error("Error getting feature schema from config", err) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "schema_not_found"}) + return nil, 0, nil, err + } + + rawKeys := featureSchema.Keys() + if len(rawKeys) == 0 { + return nil, time.Since(start), nil, nil + } + + pool := memory.NewGoAllocator() + fields := make([]arrow.Field, 0, len(rawKeys)) + for i := range rawKeys { + fields = append(fields, arrow.Field{Name: fmt.Sprintf("%d", i), Type: arrow.BinaryTypes.Binary}) + } + schema := arrow.NewSchema(fields, nil) + + builders := make([]*array.BinaryBuilder, len(rawKeys)) + for i := range builders { + b := array.NewBinaryBuilder(pool, arrow.BinaryTypes.Binary) + builders[i] = b + defer b.Release() + } + + entityIDs := make([]string, 0, len(compRequest.ComponentData.Rows)) + + for _, row := range compRequest.ComponentData.Rows { + if entityColIndex >= 0 && entityColIndex < len(row.StringData) { + entityIDs = append(entityIDs, row.StringData[entityColIndex]) + } else { + continue + } + + for i, keyIface := range rawKeys { + fname := keyIface.(string) + var val []byte + featureType := "" + isBytesType := false + if infoIface, ok := featureSchema.Get(fname); ok { + if info, ok := infoIface.(config.SchemaComponents); ok { + featureType = info.FeatureType + isBytesType = utils.IsBytesDataType(featureType) + } + } + + if byteCol, exists := compRequest.ComponentData.ByteColumnIndexMap[fname]; exists { + if byteCol.Index >= 0 && byteCol.Index < len(row.ByteData) { + val = row.ByteData[byteCol.Index] + } + } else if strCol, exists := compRequest.ComponentData.StringColumnIndexMap[fname]; exists { + if strCol.Index >= 0 && strCol.Index < len(row.StringData) { + strVal := 
row.StringData[strCol.Index] + if strVal != "" { + if isBytesType { + // BYTES fields are already logical raw bytes; use string bytes directly. + val = []byte(strVal) + } else { + var dataType types.DataType + if featureType != "" { + if dt, err := utils.StringToDataType(featureType); err == nil { + dataType = dt + } + } + if dataType != 0 { + if converted, err := utils.ConvertStringToType(strVal, dataType); err == nil { + val = converted + } + } + } + } + } + } + + if len(val) == 0 { + builders[i].AppendNull() + } else { + builders[i].Append(val) + } + } + } + + columns := make([]arrow.Array, 0, len(rawKeys)) + for _, b := range builders { + columns = append(columns, b.NewArray()) + } + + var numRows int64 + if len(columns) > 0 { + numRows = int64(columns[0].Len()) + } + return array.NewRecord(schema, columns, numRows), time.Since(start), entityIDs, nil +} + +// --- Parquet V2 logging --- + +// logInferflowResponseParquet encodes features into Parquet format and sends to Kafka. +func logInferflowResponseParquet(ctx context.Context, userId, trackingId string, conf *config.Config, compRequest *components.ComponentRequest) { + defer func() { + if r := recover(); r != nil { + log.Error().Msgf("Recovered from panic in logInferflowResponseParquet: %v", r) + } + }() + + idColName := conf.ResponseConfig.Features[0] + col, ok := compRequest.ComponentData.StringColumnIndexMap[idColName] + if !ok { + log.Error().Msgf("Entity column not found for config: %s", compRequest.ModelId) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "id_column_not_found"}) + return + } + + rows, _, _, entityIDs, err := buildParquetRows(compRequest, col.Index) + if err != nil { + logger.Error("Error building Parquet rows", err) + return + } + + batchSize := getBatchSize(conf) + totalRows := len(rows) + + for i := 0; i < totalRows; i += batchSize { + end := i + batchSize + if end > totalRows { + end = totalRows + } + + batchMPLog := buildMPLogBase(ctx, 
userId, trackingId, conf, compRequest, FormatTypeParquet) + + var buf bytes.Buffer + pw := parquet.NewGenericWriter[parquetRow](&buf) + if _, err := pw.Write(rows[i:end]); err != nil { + logger.Error("Error writing Parquet rows", err) + _ = pw.Close() + continue + } + if err := pw.Close(); err != nil { + logger.Error("Error closing Parquet writer", err) + continue + } + + batchMPLog.Entities = entityIDs[i:end] + batchMPLog.Features = []*pb.PerEntityFeatures{{EncodedFeatures: buf.Bytes()}} + logInferenceBatchInsights(batchMPLog, trackingId, compRequest.ModelId) + } +} + +// buildParquetRows constructs rows for parquet-go. +func buildParquetRows(compRequest *components.ComponentRequest, entityColIndex int) ([]parquetRow, int, time.Duration, []string, error) { + start := time.Now() + + featureSchema, err := config.GetFeatureSchema(compRequest.ModelId) + if err != nil { + logger.Error("Error getting feature schema from config", err) + metrics.Count("inferflow.logging.error", 1, []string{"model-id", compRequest.ModelId, "error", "schema_not_found"}) + return nil, 0, 0, nil, err + } + + rawKeys := featureSchema.Keys() + if len(rawKeys) == 0 { + return nil, 0, 0, nil, nil + } + + rows := make([]parquetRow, 0, len(compRequest.ComponentData.Rows)) + entityIDs := make([]string, 0, len(compRequest.ComponentData.Rows)) + + for _, perEntityRow := range compRequest.ComponentData.Rows { + if entityColIndex < 0 || entityColIndex >= len(perEntityRow.StringData) { + continue + } + entityIDs = append(entityIDs, perEntityRow.StringData[entityColIndex]) + + row := parquetRow{Features: make(map[int][]byte, len(rawKeys))} + + for i, keyIface := range rawKeys { + fname := keyIface.(string) + var val []byte + featureType := "" + isBytesType := false + if infoIface, ok := featureSchema.Get(fname); ok { + if info, ok := infoIface.(config.SchemaComponents); ok { + featureType = info.FeatureType + isBytesType = utils.IsBytesDataType(featureType) + } + } + + if byteCol, exists := 
compRequest.ComponentData.ByteColumnIndexMap[fname]; exists { + if byteCol.Index >= 0 && byteCol.Index < len(perEntityRow.ByteData) { + val = perEntityRow.ByteData[byteCol.Index] + } + } else if strCol, exists := compRequest.ComponentData.StringColumnIndexMap[fname]; exists { + if strCol.Index >= 0 && strCol.Index < len(perEntityRow.StringData) { + strVal := perEntityRow.StringData[strCol.Index] + if strVal != "" { + if isBytesType { + // BYTES fields are already logical raw bytes; use string bytes directly. + val = []byte(strVal) + } else if featureType != "" { + if dt, err := utils.StringToDataType(featureType); err == nil && dt != 0 { + if converted, err := utils.ConvertStringToType(strVal, dt); err == nil { + val = converted + } + } + } + } + } + } + + if len(val) == 0 { + row.Features[i] = nil + } else { + row.Features[i] = val + } + } + + rows = append(rows, row) + } + + return rows, len(rawKeys), time.Since(start), entityIDs, nil +} diff --git a/inferflow/handlers/inferflow/inferflow.go b/inferflow/handlers/inferflow/inferflow.go index 550441cc..bbd379e0 100644 --- a/inferflow/handlers/inferflow/inferflow.go +++ b/inferflow/handlers/inferflow/inferflow.go @@ -17,6 +17,7 @@ import ( extCache "github.com/Meesho/BharatMLStack/inferflow/dag-topology-executor/pkg/cache" "github.com/Meesho/BharatMLStack/inferflow/handlers/components" "github.com/Meesho/BharatMLStack/inferflow/handlers/config" + kafkaLogger "github.com/Meesho/BharatMLStack/inferflow/handlers/external/kafka" "github.com/Meesho/BharatMLStack/inferflow/internal/errors" "github.com/Meesho/BharatMLStack/inferflow/pkg/logger" "github.com/Meesho/BharatMLStack/inferflow/pkg/utils" @@ -53,6 +54,11 @@ func InitInferflowHandler(configs *configs.AppConfigs) { InferflowConfig := etcd.Instance().GetConfigInstance().(*config.ModelConfig) config.SetModelConfigMap(InferflowConfig) + // Build and cache feature schemas for all model configs (needed for V2 logging) + if err := 
config.InitAllFeatureSchema(InferflowConfig); err != nil { + logger.Error("Error initializing feature schemas", err) + } + // register components in config map componentProvider.RegisterComponent(config.GetModelConfigMap()) @@ -66,6 +72,10 @@ func InitInferflowHandler(configs *configs.AppConfigs) { }, }, } + + // Initialize Kafka writers for inference logging + kafkaLogger.InitKafkaLogger(configs) + logger.Info("Inferflow handler initialized") } @@ -74,6 +84,11 @@ func ReloadModelConfigMapAndRegisterComponents() error { if ok { config.SetModelConfigMap(updatedConfig) + // Rebuild feature schemas on config reload + if err := config.InitAllFeatureSchema(updatedConfig); err != nil { + logger.Error("Error reinitializing feature schemas on reload", err) + } + // register components in config map componentProvider.RegisterComponent(updatedConfig) @@ -123,6 +138,22 @@ func (s *Inferflow) RetrieveModelScore(ctx context.Context, req *pb.InferflowReq executor.Execute(conf.DAGExecutionConfig.ComponentDependency, componentReq) InferflowRes := prepareInferflowResponse(userId, conf, &componentReq) + // Inference logging: log response features based on config + if !utils.IsNilOrEmpty(conf.ResponseConfig) && + utils.IsEnableForUserForToday(userId, conf.ResponseConfig.LoggingPerc) && + InferflowReq.TrackingId != "" { + modelConfigMap := config.GetModelConfigMap() + // V2 logging: route based on configured format type + switch modelConfigMap.ServiceConfig.V2LoggingType { + case "proto": + go logInferflowResponseBytes(ctx, userId, InferflowReq.TrackingId, conf, &componentReq) + case "arrow": + go logInferflowResponseArrow(ctx, userId, InferflowReq.TrackingId, conf, &componentReq) + case "parquet": + go logInferflowResponseParquet(ctx, userId, InferflowReq.TrackingId, conf, &componentReq) + } + } + res := processInferflowResponse(InferflowRes, nil) responseTime := time.Since(startTime) diff --git a/inferflow/handlers/inferflow/predict_adapter.go 
b/inferflow/handlers/inferflow/predict_adapter.go new file mode 100644 index 00000000..1ada9f03 --- /dev/null +++ b/inferflow/handlers/inferflow/predict_adapter.go @@ -0,0 +1,337 @@ +package inferflow + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Meesho/BharatMLStack/inferflow/handlers/components" + "github.com/Meesho/BharatMLStack/inferflow/handlers/config" + "github.com/Meesho/BharatMLStack/inferflow/pkg/matrix" + pb "github.com/Meesho/BharatMLStack/inferflow/server/grpc/predict" +) + +// adaptPointWiseRequest translates a PointWiseRequest into the existing ComponentRequest. +// Each Target becomes a row in the ComponentMatrix. Context features are broadcast to all rows. +func adaptPointWiseRequest(req *pb.PointWiseRequest, conf *config.Config, headers map[string]string) (*components.ComponentRequest, error) { + numTargets := len(req.Targets) + if numTargets == 0 { + return nil, fmt.Errorf("pointwise request must have at least one target") + } + + firstEntity := firstEntityFromConfig(conf) + entities, entityIds, features := buildEntitiesFromTargets(req.Targets, req.ContextFeatures, req.TargetInputSchema, firstEntity) + + return &components.ComponentRequest{ + ComponentData: &matrix.ComponentMatrix{}, + Entities: &entities, + EntityIds: &entityIds, + Features: &features, + ComponentConfig: &conf.ComponentConfig, + ModelId: req.ModelConfigId, + Headers: headers, + }, nil +} + +// adaptPairWiseRequest translates a PairWiseRequest into a slate-style ComponentRequest. +// A pair is a slate of exactly 2 targets. The adapter converts TargetPair entries into +// TargetSlate entries (target_indices = [first, second]) and delegates to the standard +// slate data builder. The downstream slate pipeline handles scoring without any changes. 
+func adaptPairWiseRequest(req *pb.PairWiseRequest, conf *config.Config, headers map[string]string) (*components.ComponentRequest, error) { + if len(req.Pairs) == 0 { + return nil, fmt.Errorf("pairwise request must have at least one pair") + } + if len(req.Targets) == 0 { + return nil, fmt.Errorf("pairwise request must have at least one target") + } + if err := validatePairIndices(req.Pairs, len(req.Targets)); err != nil { + return nil, err + } + + firstEntity := firstEntityFromConfig(conf) + + // Main matrix: one row per target (for per-target feature fetch) + entities, entityIds, features := buildEntitiesFromTargets( + req.Targets, req.ContextFeatures, req.TargetInputSchema, firstEntity, + ) + + // Convert pairs → slates (each pair is a slate of size 2) + slates := pairsToSlates(req.Pairs) + + // Build SlateData with slate_target_indices and pair-level features + slateData := buildSlateData(req.Targets, slates, req.PairInputSchema, req.ContextFeatures) + + return &components.ComponentRequest{ + ComponentData: &matrix.ComponentMatrix{}, + SlateData: slateData, + Entities: &entities, + EntityIds: &entityIds, + Features: &features, + ComponentConfig: &conf.ComponentConfig, + ModelId: req.ModelConfigId, + Headers: headers, + }, nil +} + +// pairsToSlates converts TargetPair entries to TargetSlate entries. +// Each pair becomes a slate with exactly 2 target indices. +func pairsToSlates(pairs []*pb.TargetPair) []*pb.TargetSlate { + slates := make([]*pb.TargetSlate, len(pairs)) + for i, pair := range pairs { + slates[i] = &pb.TargetSlate{ + TargetIndices: []int32{pair.FirstTargetIndex, pair.SecondTargetIndex}, + FeatureValues: pair.FeatureValues, + } + } + return slates +} + +// adaptSlateWiseRequest translates a SlateWiseRequest into a ComponentRequest. +// The main matrix has one row per target (so feature store can fill per-target features). +// SlateData has one row per slate with slate_target_indices pointing into the main matrix, +// plus any slate-level features. 
Slate components read from main matrix and write to SlateData. +func adaptSlateWiseRequest(req *pb.SlateWiseRequest, conf *config.Config, headers map[string]string) (*components.ComponentRequest, error) { + numSlates := len(req.Slates) + if numSlates == 0 { + return nil, fmt.Errorf("slatewise request must have at least one slate") + } + if len(req.Targets) == 0 { + return nil, fmt.Errorf("slatewise request must have at least one target") + } + if err := validateSlateIndices(req.Slates, len(req.Targets)); err != nil { + return nil, err + } + + firstEntity := firstEntityFromConfig(conf) + + // Main matrix: one row per target (for per-target feature fetch) + entities, entityIds, features := buildEntitiesFromTargets( + req.Targets, req.ContextFeatures, req.TargetInputSchema, firstEntity, + ) + + // Build SlateData with slate_target_indices and slate-level features + slateData := buildSlateData(req.Targets, req.Slates, req.SlateInputSchema, req.ContextFeatures) + + return &components.ComponentRequest{ + ComponentData: &matrix.ComponentMatrix{}, + SlateData: slateData, + Entities: &entities, + EntityIds: &entityIds, + Features: &features, + ComponentConfig: &conf.ComponentConfig, + ModelId: req.ModelConfigId, + Headers: headers, + }, nil +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// firstEntityFromConfig returns the first entity/key column name from the ETCD DAG response config. +// It uses ResponseConfig.Features[0] when set; otherwise the first feature component's first FSKey.Column. +func firstEntityFromConfig(conf *config.Config) string { + if conf != nil && len(conf.ResponseConfig.Features) > 0 { + return conf.ResponseConfig.Features[0] + } + return "id" // fallback for legacy configs +} + +// buildEntitiesFromTargets creates the entity/feature structures for pointwise. 
+// The first entity dimension is the target ID (entity name from config); target and context features are added to the features map. +func buildEntitiesFromTargets(targets []*pb.Target, contextFeatures []*pb.ContextFeature, targetInputSchema []*pb.FeatureSchema, firstEntity string) ([]string, [][]string, map[string][]string) { + entities := []string{firstEntity} + targetIds := make([]string, len(targets)) + for i, t := range targets { + targetIds[i] = t.Id + } + entityIds := [][]string{targetIds} + + features := make(map[string][]string) + + // Target-level features (one column per schema field, values from each target's feature_values) + for schemaIdx, schema := range targetInputSchema { + vals := make([]string, len(targets)) + for i, t := range targets { + if schemaIdx < len(t.FeatureValues) { + vals[i] = string(t.FeatureValues[schemaIdx]) + } + } + features[schema.Name] = vals + } + + // Broadcast context features to all rows + for _, cf := range contextFeatures { + vals := make([]string, len(targets)) + strVal := string(cf.Value) + for i := range vals { + vals[i] = strVal + } + features[cf.Name] = vals + } + + return entities, entityIds, features +} + +// buildSlateData creates a pre-populated SlateData ComponentMatrix for slatewise requests. +// Each row represents one slate. The slate_target_indices string column holds comma-separated +// row indices into the main target matrix. Slate-level features are stored as string columns. +// feature_init will later call InitComponentMatrix to set up byte columns for slate outputs. 
func buildSlateData(
	targets []*pb.Target,
	slates []*pb.TargetSlate,
	slateSchema []*pb.FeatureSchema,
	contextFeatures []*pb.ContextFeature,
) *matrix.ComponentMatrix {
	numSlates := len(slates)

	// Build slate_target_indices: comma-separated target row indices
	// (e.g. "0,3,5"), one string per slate.
	slateTargetIndices := make([]string, numSlates)
	for i, slate := range slates {
		parts := make([]string, len(slate.TargetIndices))
		for j, idx := range slate.TargetIndices {
			parts[j] = strconv.Itoa(int(idx))
		}
		slateTargetIndices[i] = strings.Join(parts, ",")
	}

	// Build typed string/byte column schemas and value buffers.
	// NOTE: column Index values are assigned in registration order by the
	// closures below, so registration order fixes the physical layout.
	stringCols := make(map[string]matrix.Column)
	byteCols := make(map[string]matrix.Column)
	stringValues := make(map[string][]string)
	byteValues := make(map[string][][]byte)
	stringColIdx := 0
	byteColIdx := 0

	// addStringCol registers a string column exactly once (idempotent per name).
	addStringCol := func(name string) {
		if _, exists := stringCols[name]; exists {
			return
		}
		stringCols[name] = matrix.Column{
			Name:     name,
			DataType: "string",
			Index:    stringColIdx,
		}
		stringColIdx++
	}
	// addByteCol registers a typed byte column exactly once (idempotent per name).
	addByteCol := func(name string, dt pb.DataType) {
		if _, exists := byteCols[name]; exists {
			return
		}
		byteCols[name] = matrix.Column{
			Name:     name,
			DataType: dt.String(),
			Index:    byteColIdx,
		}
		byteColIdx++
	}

	// Always keep slate_target_indices as string column.
	addStringCol(components.SlateTargetIndicesColumn)
	stringValues[components.SlateTargetIndicesColumn] = slateTargetIndices

	// Slate-level features from TargetSlate.feature_values using slate_input_schema data types.
	// String-typed features stay in the string matrix; everything else goes to byte columns.
	for schemaIdx, schema := range slateSchema {
		if schema.GetDataType() == pb.DataType_DataTypeString {
			addStringCol(schema.Name)
			vals := make([]string, numSlates)
			for i, slate := range slates {
				if schemaIdx < len(slate.FeatureValues) {
					vals[i] = string(slate.FeatureValues[schemaIdx])
				}
			}
			stringValues[schema.Name] = vals
			continue
		}

		addByteCol(schema.Name, schema.GetDataType())
		vals := make([][]byte, numSlates)
		for i, slate := range slates {
			// Slates with fewer feature_values than the schema leave nil cells.
			if schemaIdx < len(slate.FeatureValues) {
				vals[i] = slate.FeatureValues[schemaIdx]
			}
		}
		byteValues[schema.Name] = vals
	}

	// Context features broadcast to all slate rows, typed by context feature schema.
	for _, cf := range contextFeatures {
		if cf.GetDataType() == pb.DataType_DataTypeString {
			addStringCol(cf.Name)
			vals := make([]string, numSlates)
			strVal := string(cf.Value)
			for i := range vals {
				vals[i] = strVal
			}
			stringValues[cf.Name] = vals
			continue
		}

		addByteCol(cf.Name, cf.GetDataType())
		vals := make([][]byte, numSlates)
		for i := range vals {
			// NOTE(review): all rows share the same backing slice here —
			// assumed read-only downstream; confirm no component mutates it.
			vals[i] = cf.Value
		}
		byteValues[cf.Name] = vals
	}

	// Allocate rows with both string and byte schemas.
	slateMatrix := &matrix.ComponentMatrix{}
	slateMatrix.InitComponentMatrix(numSlates, stringCols, byteCols)

	// Populate preserved string features.
	for colName, vals := range stringValues {
		slateMatrix.PopulateStringData(colName, vals)
	}

	// Populate typed byte features.
	for colName, vals := range byteValues {
		slateMatrix.PopulateByteData(colName, vals)
	}

	return slateMatrix
}

// validatePairIndices checks that every pair is non-nil and that both of its
// target indices fall inside [0, targetCount).
func validatePairIndices(pairs []*pb.TargetPair, targetCount int) error {
	for i, pair := range pairs {
		if pair == nil {
			return fmt.Errorf("pairwise request has nil pair at index %d", i)
		}

		firstIndex := int(pair.FirstTargetIndex)
		if firstIndex < 0 || firstIndex >= targetCount {
			return fmt.Errorf(
				"pairwise request has out-of-range first_target_index=%d at pair index %d (target count=%d)",
				pair.FirstTargetIndex, i, targetCount,
			)
		}

		secondIndex := int(pair.SecondTargetIndex)
		if secondIndex < 0 || secondIndex >= targetCount {
			return fmt.Errorf(
				"pairwise request has out-of-range second_target_index=%d at pair index %d (target count=%d)",
				pair.SecondTargetIndex, i, targetCount,
			)
		}
	}
	return nil
}

// validateSlateIndices checks that every slate is non-nil and that each of its
// target indices falls inside [0, targetCount).
func validateSlateIndices(slates []*pb.TargetSlate, targetCount int) error {
	for slateIdx, slate := range slates {
		if slate == nil {
			return fmt.Errorf("slatewise request has nil slate at index %d", slateIdx)
		}

		for targetIdxPos, targetIdx := range slate.TargetIndices {
			index := int(targetIdx)
			if index < 0 || index >= targetCount {
				return fmt.Errorf(
					"slatewise request has out-of-range target index=%d at slates[%d].target_indices[%d] (target count=%d)",
					targetIdx, slateIdx, targetIdxPos, targetCount,
				)
			}
		}
	}
	return nil
}
diff --git a/inferflow/handlers/inferflow/predict_handler.go b/inferflow/handlers/inferflow/predict_handler.go
new file mode 100644
index 00000000..0579d344
--- /dev/null
+++ b/inferflow/handlers/inferflow/predict_handler.go
@@ -0,0 +1,127 @@
package inferflow

import (
	"context"
	"fmt"
	"math/rand"
	"time"

	"github.com/Meesho/BharatMLStack/inferflow/handlers/components"
	"github.com/Meesho/BharatMLStack/inferflow/handlers/config"
	"github.com/Meesho/BharatMLStack/inferflow/pkg/logger"
	"github.com/Meesho/BharatMLStack/inferflow/pkg/metrics"
	pb
"github.com/Meesho/BharatMLStack/inferflow/server/grpc/predict" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// PredictService implements the Predict gRPC service with PointWise, PairWise, and SlateWise RPCs. +// It reuses the existing DAG executor and component pipeline via thin adapters. +type PredictService struct { + pb.UnimplementedPredictServer +} + +func (s *PredictService) InferPointWise(ctx context.Context, req *pb.PointWiseRequest) (*pb.PointWiseResponse, error) { + startTime := time.Now() + tags := []string{"model-id", req.ModelConfigId, "inference-mode", "pointwise"} + metrics.Count("predict.infer.request.total", 1, tags) + + conf, err := config.GetModelConfig(req.ModelConfigId) + if err != nil { + logger.Error(fmt.Sprintf("InferPointWise: config not found for %s", req.ModelConfigId), err) + return &pb.PointWiseResponse{RequestError: err.Error()}, status.Error(codes.InvalidArgument, err.Error()) + } + + headers := getAllHeaders(ctx) + componentReq, err := adaptPointWiseRequest(req, conf, headers) + if err != nil { + logger.Error(fmt.Sprintf("InferPointWise: adapter error for %s", req.ModelConfigId), err) + return &pb.PointWiseResponse{RequestError: err.Error()}, status.Error(codes.InvalidArgument, err.Error()) + } + + executor.Execute(conf.DAGExecutionConfig.ComponentDependency, *componentReq) + + resp := buildPointWiseResponse(componentReq.ComponentData, conf) + maybeLogInferenceResponse(ctx, req.TrackingId, conf, componentReq) + + metrics.Timing("predict.infer.request.latency", time.Since(startTime), tags) + metrics.Count("predict.infer.request.batch.size", int64(len(req.Targets)), tags) + + return resp, nil +} + +func (s *PredictService) InferPairWise(ctx context.Context, req *pb.PairWiseRequest) (*pb.PairWiseResponse, error) { + startTime := time.Now() + tags := []string{"model-id", req.ModelConfigId, "inference-mode", "pairwise"} + metrics.Count("predict.infer.request.total", 1, tags) + + conf, err := 
config.GetModelConfig(req.ModelConfigId) + if err != nil { + logger.Error(fmt.Sprintf("InferPairWise: config not found for %s", req.ModelConfigId), err) + return &pb.PairWiseResponse{RequestError: err.Error()}, status.Error(codes.InvalidArgument, err.Error()) + } + + headers := getAllHeaders(ctx) + componentReq, err := adaptPairWiseRequest(req, conf, headers) + if err != nil { + logger.Error(fmt.Sprintf("InferPairWise: adapter error for %s", req.ModelConfigId), err) + return &pb.PairWiseResponse{RequestError: err.Error()}, status.Error(codes.InvalidArgument, err.Error()) + } + + executor.Execute(conf.DAGExecutionConfig.ComponentDependency, *componentReq) + + resp := buildPairWiseResponse(componentReq.ComponentData, componentReq.SlateData, conf) + maybeLogInferenceResponse(ctx, req.TrackingId, conf, componentReq) + + metrics.Timing("predict.infer.request.latency", time.Since(startTime), tags) + metrics.Count("predict.infer.request.batch.size", int64(len(req.Pairs)), tags) + + return resp, nil +} + +func (s *PredictService) InferSlateWise(ctx context.Context, req *pb.SlateWiseRequest) (*pb.SlateWiseResponse, error) { + startTime := time.Now() + tags := []string{"model-id", req.ModelConfigId, "inference-mode", "slatewise"} + metrics.Count("predict.infer.request.total", 1, tags) + + conf, err := config.GetModelConfig(req.ModelConfigId) + if err != nil { + logger.Error(fmt.Sprintf("InferSlateWise: config not found for %s", req.ModelConfigId), err) + return &pb.SlateWiseResponse{RequestError: err.Error()}, status.Error(codes.InvalidArgument, err.Error()) + } + + headers := getAllHeaders(ctx) + componentReq, err := adaptSlateWiseRequest(req, conf, headers) + if err != nil { + logger.Error(fmt.Sprintf("InferSlateWise: adapter error for %s", req.ModelConfigId), err) + return &pb.SlateWiseResponse{RequestError: err.Error()}, status.Error(codes.InvalidArgument, err.Error()) + } + + executor.Execute(conf.DAGExecutionConfig.ComponentDependency, *componentReq) + + resp := 
buildSlateWiseResponse(componentReq.ComponentData, componentReq.SlateData, conf) + maybeLogInferenceResponse(ctx, req.TrackingId, conf, componentReq) + + metrics.Timing("predict.infer.request.latency", time.Since(startTime), tags) + metrics.Count("predict.infer.request.batch.size", int64(len(req.Slates)), tags) + + return resp, nil +} + +func maybeLogInferenceResponse(ctx context.Context, trackingID string, conf *config.Config, componentReq *components.ComponentRequest) { + // Inference logging: log response features based on config + if conf.ResponseConfig.LoggingPerc > 0 && + rand.Intn(100)+1 <= conf.ResponseConfig.LoggingPerc && + trackingID != "" { + modelConfigMap := config.GetModelConfigMap() + // V2 logging: route based on configured format type + switch modelConfigMap.ServiceConfig.V2LoggingType { + case "proto": + go logInferflowResponseBytes(ctx, "", trackingID, conf, componentReq) + case "arrow": + go logInferflowResponseArrow(ctx, "", trackingID, conf, componentReq) + case "parquet": + go logInferflowResponseParquet(ctx, "", trackingID, conf, componentReq) + } + } +} diff --git a/inferflow/handlers/inferflow/predict_response.go b/inferflow/handlers/inferflow/predict_response.go new file mode 100644 index 00000000..590cb968 --- /dev/null +++ b/inferflow/handlers/inferflow/predict_response.go @@ -0,0 +1,301 @@ +package inferflow + +import ( + "github.com/Meesho/BharatMLStack/inferflow/handlers/config" + "github.com/Meesho/BharatMLStack/inferflow/pkg/datatypeconverter/typeconverter" + "github.com/Meesho/BharatMLStack/inferflow/pkg/matrix" + pb "github.com/Meesho/BharatMLStack/inferflow/server/grpc/predict" +) + +// buildPointWiseResponse extracts scores from the ComponentMatrix into a PointWiseResponse. +// Each matrix row maps to one TargetScore. Output columns are determined by ResponseConfig. 
func buildPointWiseResponse(componentMatrix *matrix.ComponentMatrix, conf *config.Config) *pb.PointWiseResponse {
	// All configured response features are emitted for every target row.
	outputColumns := conf.ResponseConfig.Features
	outputSchema := buildOutputSchema(outputColumns, componentMatrix)

	scores := make([]*pb.TargetScore, len(componentMatrix.Rows))
	for i, row := range componentMatrix.Rows {
		scores[i] = &pb.TargetScore{
			OutputValues: extractRowOutputBytes(row, outputColumns, outputSchema, componentMatrix),
		}
	}

	return &pb.PointWiseResponse{
		TargetOutputSchema: outputSchema,
		TargetScores:       scores,
	}
}

// buildPairWiseResponse extracts target scores from ComponentData and pair scores from SlateData.
// A pair is a slate of size 2, so pair-level outputs come from SlateData (slate component outputs)
// and target-level outputs come from the main ComponentMatrix.
func buildPairWiseResponse(componentMatrix *matrix.ComponentMatrix, slateData *matrix.ComponentMatrix, conf *config.Config) *pb.PairWiseResponse {
	resp := &pb.PairWiseResponse{}

	// Target-level scores from the main matrix (one per target row).
	// Slate-output columns are removed from the target view so each column is
	// reported at exactly one level.
	targetOutputColumns := conf.ResponseConfig.Features
	slateCols := slateOutputColumnNames(conf)
	if len(slateCols) > 0 {
		targetOutputColumns = excludeColumns(targetOutputColumns, slateCols)
	}
	if len(targetOutputColumns) > 0 && len(componentMatrix.Rows) > 0 {
		targetOutputSchema := buildOutputSchema(targetOutputColumns, componentMatrix)
		targetScores := make([]*pb.TargetScore, len(componentMatrix.Rows))
		for i, row := range componentMatrix.Rows {
			targetScores[i] = &pb.TargetScore{
				OutputValues: extractRowOutputBytes(row, targetOutputColumns, targetOutputSchema, componentMatrix),
			}
		}
		resp.TargetScores = targetScores
		resp.TargetOutputSchema = targetOutputSchema
	}

	// Pair-level scores from SlateData (one per pair/slate row). Only slate
	// output columns that are also configured response features are emitted.
	if slateData != nil && len(slateData.Rows) > 0 {
		pairOutputColumns := includeColumns(slateCols, conf.ResponseConfig.Features)
		if len(pairOutputColumns) > 0 {
			pairOutputSchema := buildOutputSchema(pairOutputColumns, slateData)
			pairScores := make([]*pb.PairScore, len(slateData.Rows))
			for i, row := range slateData.Rows {
				pairScores[i] = &pb.PairScore{
					OutputValues: extractRowOutputBytes(row, pairOutputColumns, pairOutputSchema, slateData),
				}
			}
			resp.PairScores = pairScores
			resp.PairOutputSchema = pairOutputSchema
		}
	}

	return resp
}

// buildSlateWiseResponse extracts slate scores from SlateData and target scores from the main
// ComponentData. Slate output columns come from slate predator/numerix configs; target output
// columns come from ResponseConfig.Features.
func buildSlateWiseResponse(componentMatrix *matrix.ComponentMatrix, slateData *matrix.ComponentMatrix, conf *config.Config) *pb.SlateWiseResponse {
	resp := &pb.SlateWiseResponse{}

	// Target-level scores from the main matrix (one per target row).
	// Slate-output columns are removed from the target view so each column is
	// reported at exactly one level.
	targetOutputColumns := conf.ResponseConfig.Features
	slateCols := slateOutputColumnNames(conf)
	if len(slateCols) > 0 {
		targetOutputColumns = excludeColumns(targetOutputColumns, slateCols)
	}
	if len(targetOutputColumns) > 0 && len(componentMatrix.Rows) > 0 {
		targetOutputSchema := buildOutputSchema(targetOutputColumns, componentMatrix)
		targetScores := make([]*pb.TargetScore, len(componentMatrix.Rows))
		for i, row := range componentMatrix.Rows {
			targetScores[i] = &pb.TargetScore{
				OutputValues: extractRowOutputBytes(row, targetOutputColumns, targetOutputSchema, componentMatrix),
			}
		}
		resp.TargetScores = targetScores
		resp.TargetOutputSchema = targetOutputSchema
	}

	// Slate-level scores from SlateData (one per slate row). Only slate
	// output columns that are also configured response features are emitted.
	if slateData != nil && len(slateData.Rows) > 0 {
		slateOutputColumns := includeColumns(slateCols, conf.ResponseConfig.Features)
		if len(slateOutputColumns) > 0 {
			slateOutputSchema := buildOutputSchema(slateOutputColumns, slateData)
			slateScores := make([]*pb.SlateScore, len(slateData.Rows))
			for i, row := range slateData.Rows {
				slateScores[i] = &pb.SlateScore{
					OutputValues: extractRowOutputBytes(row, slateOutputColumns, slateOutputSchema, slateData),
				}
			}
			resp.SlateScores = slateScores
			resp.SlateOutputSchema = slateOutputSchema
		}
	}

	return resp
}

// slateOutputColumnNames collects output column names from slate predator and numerix configs.
// (Previous comment said "iris"; the code reads NumerixComponentConfig.)
func slateOutputColumnNames(conf *config.Config) []string {
	var cols []string
	// Predator components flagged as slate components contribute their
	// model-score output columns.
	for _, comp := range conf.ComponentConfig.PredatorComponentConfig.Values() {
		pComp, ok := comp.(config.PredatorComponentConfig)
		if !ok || !pComp.SlateComponent {
			continue
		}
		for _, out := range pComp.Outputs {
			cols = append(cols, out.ModelScores...)
		}
	}
	// Numerix components flagged as slate components contribute their single
	// score column.
	for _, comp := range conf.ComponentConfig.NumerixComponentConfig.Values() {
		iComp, ok := comp.(config.NumerixComponentConfig)
		if !ok || !iComp.SlateComponent {
			continue
		}
		cols = append(cols, iComp.ScoreColumn)
	}
	return cols
}

// excludeColumns removes excluded columns from a source list while preserving order.
func excludeColumns(source []string, excluded []string) []string {
	if len(source) == 0 || len(excluded) == 0 {
		return source
	}
	excludedSet := make(map[string]struct{}, len(excluded))
	for _, col := range excluded {
		excludedSet[col] = struct{}{}
	}

	filtered := make([]string, 0, len(source))
	for _, col := range source {
		if _, ok := excludedSet[col]; ok {
			continue
		}
		filtered = append(filtered, col)
	}
	return filtered
}

// includeColumns keeps only columns that are present in the allowed list while preserving order.
+func includeColumns(source []string, allowed []string) []string { + if len(source) == 0 || len(allowed) == 0 { + return []string{} + } + allowedSet := make(map[string]struct{}, len(allowed)) + for _, col := range allowed { + allowedSet[col] = struct{}{} + } + + filtered := make([]string, 0, len(source)) + for _, col := range source { + if _, ok := allowedSet[col]; !ok { + continue + } + filtered = append(filtered, col) + } + return filtered +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// buildOutputSchema creates FeatureSchema entries for the output columns. +// It looks up each column in the byte column map first (typed), then string column map. +func buildOutputSchema(columns []string, m *matrix.ComponentMatrix) []*pb.FeatureSchema { + schema := make([]*pb.FeatureSchema, len(columns)) + for i, col := range columns { + schema[i] = &pb.FeatureSchema{ + Name: col, + DataType: resolveDataType(col, m), + } + } + return schema +} + +// resolveDataType determines the proto DataType for a column by checking the matrix column maps. +func resolveDataType(colName string, m *matrix.ComponentMatrix) pb.DataType { + if byteCol, ok := m.ByteColumnIndexMap[colName]; ok { + return mapDataTypeString(byteCol.DataType) + } + if stringCol, ok := m.StringColumnIndexMap[colName]; ok { + return mapDataTypeString(stringCol.DataType) + } + return pb.DataType_DataTypeString // default fallback +} + +// mapDataTypeString converts the internal DataType string to the proto enum. 
+func mapDataTypeString(dt string) pb.DataType { + switch dt { + case "DataTypeFP32", "fp32": + return pb.DataType_DataTypeFP32 + case "DataTypeFP64", "fp64": + return pb.DataType_DataTypeFP64 + case "DataTypeFP16", "fp16": + return pb.DataType_DataTypeFP16 + case "DataTypeInt8", "int8": + return pb.DataType_DataTypeInt8 + case "DataTypeInt16", "int16": + return pb.DataType_DataTypeInt16 + case "DataTypeInt32", "int32": + return pb.DataType_DataTypeInt32 + case "DataTypeInt64", "int64": + return pb.DataType_DataTypeInt64 + case "DataTypeUint8", "uint8": + return pb.DataType_DataTypeUint8 + case "DataTypeUint16", "uint16": + return pb.DataType_DataTypeUint16 + case "DataTypeUint32", "uint32": + return pb.DataType_DataTypeUint32 + case "DataTypeUint64", "uint64": + return pb.DataType_DataTypeUint64 + case "DataTypeString", "string": + return pb.DataType_DataTypeString + case "DataTypeBool", "bool": + return pb.DataType_DataTypeBool + case "DataTypeFP8E5M2", "fp8e5m2": + return pb.DataType_DataTypeFP8E5M2 + case "DataTypeFP8E4M3", "fp8e4m3": + return pb.DataType_DataTypeFP8E4M3 + case "DataTypeFP8E5M2Vector", "fp8e5m2vector": + return pb.DataType_DataTypeFP8E5M2Vector + case "DataTypeFP8E4M3Vector", "fp8e4m3vector": + return pb.DataType_DataTypeFP8E4M3Vector + case "DataTypeFP32Vector", "fp32vector": + return pb.DataType_DataTypeFP32Vector + case "DataTypeFP16Vector", "fp16vector": + return pb.DataType_DataTypeFP16Vector + case "DataTypeFP64Vector", "fp64vector": + return pb.DataType_DataTypeFP64Vector + case "DataTypeInt8Vector", "int8vector": + return pb.DataType_DataTypeInt8Vector + case "DataTypeInt16Vector", "int16vector": + return pb.DataType_DataTypeInt16Vector + case "DataTypeInt32Vector", "int32vector": + return pb.DataType_DataTypeInt32Vector + case "DataTypeInt64Vector", "int64vector": + return pb.DataType_DataTypeInt64Vector + case "DataTypeUint8Vector", "uint8vector": + return pb.DataType_DataTypeUint8Vector + case "DataTypeUint16Vector", "uint16vector": 
+ return pb.DataType_DataTypeUint16Vector + case "DataTypeUint32Vector", "uint32vector": + return pb.DataType_DataTypeUint32Vector + case "DataTypeUint64Vector", "uint64vector": + return pb.DataType_DataTypeUint64Vector + case "DataTypeStringVector", "stringvector": + return pb.DataType_DataTypeStringVector + case "DataTypeBoolVector", "boolvector": + return pb.DataType_DataTypeBoolVector + default: + return pb.DataType_DataTypeString + } +} + +// extractRowOutputBytes extracts the output column values from a single matrix row as [][]byte. +// It prefers byte columns (typed, zero-copy). For string columns, it converts the string value +// to typed bytes using the output schema's data type so the response carries actual typed values. +func extractRowOutputBytes(row matrix.Row, columns []string, outputSchema []*pb.FeatureSchema, m *matrix.ComponentMatrix) [][]byte { + values := make([][]byte, len(columns)) + for i, col := range columns { + if byteCol, ok := m.ByteColumnIndexMap[col]; ok { + // Byte column: already typed, use directly + if byteCol.Index < len(row.ByteData) { + values[i] = row.ByteData[byteCol.Index] + } + } else if stringCol, ok := m.StringColumnIndexMap[col]; ok { + if stringCol.Index < len(row.StringData) { + strVal := row.StringData[stringCol.Index] + // Convert string value to typed bytes using the output schema's data type + if i < len(outputSchema) { + targetDt := outputSchema[i].DataType.String() + if converted, err := typeconverter.StringToBytes(strVal, targetDt); err == nil { + values[i] = converted + continue + } + } + // Fallback: raw string bytes if conversion fails or no schema + values[i] = []byte(strVal) + } + } + } + return values +} diff --git a/inferflow/internal/server/server.go b/inferflow/internal/server/server.go index 3c2f0171..375ec0a1 100644 --- a/inferflow/internal/server/server.go +++ b/inferflow/internal/server/server.go @@ -11,6 +11,7 @@ import ( "github.com/Meesho/BharatMLStack/inferflow/pkg/logger" 
"github.com/Meesho/BharatMLStack/inferflow/pkg/middleware" pb "github.com/Meesho/BharatMLStack/inferflow/server/grpc" + predict "github.com/Meesho/BharatMLStack/inferflow/server/grpc/predict" "github.com/cockroachdb/cmux" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -38,6 +39,7 @@ func InitServer(configs *configs.AppConfigs) { ) reflection.Register(grpcServer) pb.RegisterInferflowServer(grpcServer, &inferflow.Inferflow{}) + predict.RegisterPredictServer(grpcServer, &inferflow.PredictService{}) // HTTP Server : h := http.NewServeMux() diff --git a/inferflow/pkg/configs/configs.go b/inferflow/pkg/configs/configs.go index 1b58d815..29795b7a 100644 --- a/inferflow/pkg/configs/configs.go +++ b/inferflow/pkg/configs/configs.go @@ -48,6 +48,10 @@ type Configs struct { NumerixClientV1_GrpcPlainText bool `mapstructure:"numerixClientV1_plainText"` NumerixClientV1_AuthToken string `mapstructure:"numerixClientV1_authToken"` NumerixClientV1_BatchSize int `mapstructure:"numerixClientV1_batchSize"` + + // Kafka config for inference logging + KafkaBootstrapServers string `mapstructure:"kafka_bootstrapServers"` + KafkaLoggingTopic string `mapstructure:"kafka_loggingTopic"` } type DynamicConfigs struct { diff --git a/inferflow/pkg/configs/configs_init.go b/inferflow/pkg/configs/configs_init.go index b3ce7902..4af38862 100644 --- a/inferflow/pkg/configs/configs_init.go +++ b/inferflow/pkg/configs/configs_init.go @@ -75,6 +75,12 @@ func bindEnvVars() { viper.BindEnv("externalServicePredator_CallerToken", "EXTERNAL_SERVICE_PREDATOR_CALLER_TOKEN") viper.BindEnv("externalServicePredator_Deadline", "EXTERNAL_SERVICE_PREDATOR_DEADLINE") - // Metrics config + // Metrics / Telegraf config viper.BindEnv("metrics_sampling_rate", "METRIC_SAMPLING_RATE") + viper.BindEnv("telegraf_host", "TELEGRAF_HOST") + viper.BindEnv("telegraf_port", "TELEGRAF_PORT") + + // Kafka inference logging config + viper.BindEnv("kafka_bootstrapServers", "KAFKA_BOOTSTRAP_SERVERS") + 
viper.BindEnv("kafka_v2LogTopic", "KAFKA_V2_LOG_TOPIC") } diff --git a/inferflow/pkg/matrix/matrix.go b/inferflow/pkg/matrix/matrix.go index 0abe3422..a39abfee 100644 --- a/inferflow/pkg/matrix/matrix.go +++ b/inferflow/pkg/matrix/matrix.go @@ -111,6 +111,39 @@ func (m *ComponentMatrix) PopulateByteData(columnToPopulate string, score [][]by } } +// PopulateByteAndStringData writes byte data to the byte column and, if a matching +// string column exists, converts and writes the string representation as well. +func (m *ComponentMatrix) PopulateByteAndStringData(columnToPopulate string, score [][]byte) { + col, ok := m.ByteColumnIndexMap[columnToPopulate] + if !ok { + return + } + stringCol, hasStringCol := m.StringColumnIndexMap[columnToPopulate] + var strVal string + if len(score) == 1 { + if hasStringCol { + strVal, _ = typeconverter.BytesToString(score[0], stringCol.DataType) + } + for i := 0; i < len(m.Rows); i++ { + m.Rows[i].ByteData[col.Index] = score[0] + if hasStringCol { + m.Rows[i].StringData[stringCol.Index] = strVal + } + } + return + } + + for i := 0; i < len(m.Rows) && i < len(score); i++ { + m.Rows[i].ByteData[col.Index] = score[i] + if hasStringCol { + strVal, err := typeconverter.BytesToString(score[i], stringCol.DataType) + if err == nil { + m.Rows[i].StringData[stringCol.Index] = strVal + } + } + } +} + func (m *ComponentMatrix) PopulateStringDataFromSingleValue(columnToPopulate string, value string) { col, ok := m.StringColumnIndexMap[columnToPopulate] if !ok { diff --git a/inferflow/pkg/metrics/metrics.go b/inferflow/pkg/metrics/metrics.go index d1eaa731..506bd9c7 100644 --- a/inferflow/pkg/metrics/metrics.go +++ b/inferflow/pkg/metrics/metrics.go @@ -33,7 +33,11 @@ func InitMetrics(configs *configs.AppConfigs) { statsd.WithTags(globalTags), ) if err != nil { - logger.Panic("StatsD client initialization failed!", err) + // In local/dev environments Telegraf may not be running; log and continue + // with the default no-op-safe client instead of 
crashing the service. + logger.Error("StatsD client initialization failed, metrics will be unavailable", err) + statsDClient = getDefaultClient() + return } //go initJMXServer() logger.Info(fmt.Sprintf("Metrics client initialized with telegraf address - %s, global tags - %v, and sampling rate - %f", @@ -41,7 +45,11 @@ func InitMetrics(configs *configs.AppConfigs) { } func getDefaultClient() *statsd.Client { - client, _ := statsd.New("localhost:8125") + client, err := statsd.New("localhost:8125") + if err != nil { + // Return a no-op client so callers never hit nil-pointer panics. + client, _ = statsd.New("localhost:8125", statsd.WithoutTelemetry()) + } return client } diff --git a/inferflow/pkg/utils/converter.go b/inferflow/pkg/utils/converter.go new file mode 100644 index 00000000..fccffcc5 --- /dev/null +++ b/inferflow/pkg/utils/converter.go @@ -0,0 +1,196 @@ +package utils + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + + "github.com/Meesho/BharatMLStack/inferflow/pkg/datatypeconverter/types" +) + +// StringToDataType converts a string representation of a data type to types.DataType. +// It adds "DataType" prefix if not already present. +func StringToDataType(dataTypeStr string) (types.DataType, error) { + s := strings.TrimSpace(dataTypeStr) + if !strings.HasPrefix(s, "DataType") { + s = "DataType" + s + } + return types.ParseDataType(s) +} + +// IsBytesDataType checks if the given data type string represents a BYTES type. +func IsBytesDataType(dataTypeStr string) bool { + s := strings.TrimSpace(dataTypeStr) + if !strings.HasPrefix(s, "DataType") { + s = "DataType" + s + } + return strings.EqualFold(s, "DataTypeBYTES") +} + +// ConvertStringToType converts a string value to its byte representation +// based on the specified data type. 
+func ConvertStringToType(value string, dataType types.DataType) ([]byte, error) { + if value == "" { + return nil, fmt.Errorf("empty string provided for conversion") + } + + switch dataType { + case types.DataTypeUnknown: + return nil, nil + + case types.DataTypeString: + return []byte(value), nil + + case types.DataTypeInt8: + val, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse INT8: %w", err) + } + return []byte{byte(int8(val))}, nil + + case types.DataTypeInt16: + val, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse INT16: %w", err) + } + buf := make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(val)) + return buf, nil + + case types.DataTypeInt32: + val, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse INT32: %w", err) + } + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, uint32(val)) + return buf, nil + + case types.DataTypeInt64: + val, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse INT64: %w", err) + } + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(val)) + return buf, nil + + case types.DataTypeUint8: + val, err := strconv.ParseUint(value, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse UINT8: %w", err) + } + return []byte{byte(val)}, nil + + case types.DataTypeUint16: + val, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse UINT16: %w", err) + } + buf := make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(val)) + return buf, nil + + case types.DataTypeUint32: + val, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse UINT32: %w", err) + } + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, uint32(val)) + return buf, nil + + case types.DataTypeUint64: + val, err := 
strconv.ParseUint(value, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse UINT64: %w", err) + } + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, val) + return buf, nil + + case types.DataTypeFP16: + val, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse FP16: %w", err) + } + bits := math.Float32bits(float32(val)) + buf := make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(bits>>16)) + return buf, nil + + case types.DataTypeFP32: + val, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse FP32: %w", err) + } + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(float32(val))) + return buf, nil + + case types.DataTypeFP64: + val, err := strconv.ParseFloat(value, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse FP64: %w", err) + } + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, math.Float64bits(val)) + return buf, nil + + case types.DataTypeBool: + val, err := strconv.ParseBool(value) + if err != nil { + return nil, fmt.Errorf("failed to parse BOOL: %w", err) + } + if val { + return []byte{1}, nil + } + return []byte{0}, nil + + case types.DataTypeStringVector: + var result []string + err := json.Unmarshal([]byte(value), &result) + if err != nil { + return nil, fmt.Errorf("failed to parse StringVector: %w", err) + } + return json.Marshal(result) + + case types.DataTypeFP32Vector: + var result []float32 + err := json.Unmarshal([]byte(value), &result) + if err != nil { + return nil, fmt.Errorf("failed to parse FP32Vector: %w", err) + } + return json.Marshal(result) + + case types.DataTypeFP64Vector: + var result []float64 + err := json.Unmarshal([]byte(value), &result) + if err != nil { + return nil, fmt.Errorf("failed to parse FP64Vector: %w", err) + } + return json.Marshal(result) + + case types.DataTypeInt32Vector: + var result []int32 + err := json.Unmarshal([]byte(value), &result) + 
if err != nil { + return nil, fmt.Errorf("failed to parse Int32Vector: %w", err) + } + return json.Marshal(result) + + case types.DataTypeInt64Vector: + var result []int64 + err := json.Unmarshal([]byte(value), &result) + if err != nil { + return nil, fmt.Errorf("failed to parse Int64Vector: %w", err) + } + return json.Marshal(result) + + default: + return nil, fmt.Errorf("unsupported data type: %v", dataType) + } +} diff --git a/inferflow/pkg/utils/default_values.go b/inferflow/pkg/utils/default_values.go index 22b9f71e..40a01f2c 100644 --- a/inferflow/pkg/utils/default_values.go +++ b/inferflow/pkg/utils/default_values.go @@ -57,6 +57,20 @@ func GetDefaultValuesInBytes(dataType string) []byte { return []byte("0.0") } +// GetDefaultValueByType returns default byte value for scalar types. +// Vectors are handled separately with empty bytes (2-byte size = 0). +func GetDefaultValueByType(dataType string) []byte { + dt := strings.ToLower(dataType) + dt = strings.TrimPrefix(dt, "datatype") + dt = strings.TrimPrefix(dt, "_") + dt = strings.TrimSpace(dt) + + if val, ok := defaultBytes[dt]; ok { + return val + } + return []byte("0.0") +} + func GetDefaultValuesInBytesForVector(dataType, featureStoreDataType string, shape []int) []byte { dt := strings.ToLower(dataType) @@ -82,4 +96,4 @@ func GetDefaultValuesInBytesForVector(dataType, featureStoreDataType string, sha return val } return []byte("0.0") -} \ No newline at end of file +} diff --git a/inferflow/pkg/utils/feature_encoder.go b/inferflow/pkg/utils/feature_encoder.go new file mode 100644 index 00000000..46246646 --- /dev/null +++ b/inferflow/pkg/utils/feature_encoder.go @@ -0,0 +1,130 @@ +package utils + +import ( + "encoding/binary" + "fmt" + + "github.com/Meesho/BharatMLStack/inferflow/pkg/datatypeconverter/types" +) + +// FeatureEncoder encodes features into a compact binary format. +// The first byte indicates if any values were generated (1 = no generated, 0 = some generated). 
+type FeatureEncoder struct { + buffer []byte +} + +// NewFeatureEncoder creates a new FeatureEncoder with pre-allocated capacity. +func NewFeatureEncoder(estimatedSize int) *FeatureEncoder { + encoder := &FeatureEncoder{ + buffer: make([]byte, 1, estimatedSize+1), + } + encoder.buffer[0] = 1 // Initialize flag to 1 (no generated values) + return encoder +} + +// MarkValueAsGenerated sets the flag indicating some values were auto-generated. +func (e *FeatureEncoder) MarkValueAsGenerated() { + e.buffer[0] = 0 +} + +// AppendFeature appends a typed feature value to the buffer. +// For strings and vectors, includes a 2-byte size prefix. +func (e *FeatureEncoder) AppendFeature(dataType types.DataType, value []byte) (tag string, err error) { + switch dataType { + case types.DataTypeInt8: + return e.appendScalar(value, 1) + case types.DataTypeInt16: + return e.appendScalar(value, 2) + case types.DataTypeInt32: + return e.appendScalar(value, 4) + case types.DataTypeInt64: + return e.appendScalar(value, 8) + case types.DataTypeFP8E5M2, types.DataTypeFP8E4M3: + return e.appendScalar(value, 1) + case types.DataTypeFP16: + return e.appendFP16Scalar(value) + case types.DataTypeFP32: + return e.appendScalar(value, 4) + case types.DataTypeFP64: + return e.appendScalar(value, 8) + case types.DataTypeUint8: + return e.appendScalar(value, 1) + case types.DataTypeUint16: + return e.appendScalar(value, 2) + case types.DataTypeUint32: + return e.appendScalar(value, 4) + case types.DataTypeUint64: + return e.appendScalar(value, 8) + case types.DataTypeBool: + return e.appendScalar(value, 1) + case types.DataTypeString: + return e.appendSizedValue(value) + case types.DataTypeFP8E5M2Vector, + types.DataTypeFP8E4M3Vector, + types.DataTypeFP16Vector, + types.DataTypeFP32Vector, + types.DataTypeFP64Vector, + types.DataTypeInt8Vector, + types.DataTypeInt16Vector, + types.DataTypeInt32Vector, + types.DataTypeInt64Vector, + types.DataTypeUint8Vector, + types.DataTypeUint16Vector, + 
types.DataTypeUint32Vector, + types.DataTypeUint64Vector, + types.DataTypeStringVector, + types.DataTypeBoolVector: + return e.appendSizedValue(value) + default: + return "invalid_data_type", fmt.Errorf("unsupported data type: %v", dataType) + } +} + +// AppendBytesFeature appends a raw bytes feature with size prefix. +func (e *FeatureEncoder) AppendBytesFeature(value []byte) (string, error) { + return e.appendSizedValue(value) +} + +func (e *FeatureEncoder) appendScalar(value []byte, expectedSize int) (string, error) { + if len(value) != expectedSize { + return "invalid_value_size", fmt.Errorf("invalid value size: expected %d, got %d", expectedSize, len(value)) + } + e.buffer = append(e.buffer, value...) + return "", nil +} + +func (e *FeatureEncoder) appendFP16Scalar(value []byte) (string, error) { + switch len(value) { + case 2: + e.buffer = append(e.buffer, value...) + case 4: + e.buffer = append(e.buffer, value[:2]...) + default: + return "invalid_value_size", fmt.Errorf("invalid FP16 value size: expected 2 or 4 bytes, got %d", len(value)) + } + return "", nil +} + +func (e *FeatureEncoder) appendSizedValue(value []byte) (string, error) { + valueLen := len(value) + if valueLen > 65535 { + return "invalid_value_size", fmt.Errorf("value size %d exceeds maximum of 65535 bytes", valueLen) + } + + sizeBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(sizeBytes, uint16(valueLen)) + e.buffer = append(e.buffer, sizeBytes...) + e.buffer = append(e.buffer, value...) + return "", nil +} + +// Bytes returns the encoded byte buffer. +func (e *FeatureEncoder) Bytes() []byte { + return e.buffer +} + +// Reset clears the buffer for reuse. 
+func (e *FeatureEncoder) Reset() { + e.buffer = e.buffer[:1] + e.buffer[0] = 0 +} diff --git a/inferflow/server/grpc/inferflow.pb.go b/inferflow/server/grpc/inferflow.pb.go index b79741b1..ecd303ba 100644 --- a/inferflow/server/grpc/inferflow.pb.go +++ b/inferflow/server/grpc/inferflow.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 -// protoc v5.29.3 +// protoc-gen-go v1.36.11 +// protoc v6.33.2 // source: inferflow.proto package grpc diff --git a/inferflow/server/grpc/inferflow_grpc.pb.go b/inferflow/server/grpc/inferflow_grpc.pb.go index be129164..2eea196c 100644 --- a/inferflow/server/grpc/inferflow_grpc.pb.go +++ b/inferflow/server/grpc/inferflow_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.3 +// - protoc-gen-go-grpc v1.6.1 +// - protoc v6.33.2 // source: inferflow.proto package grpc @@ -63,7 +63,7 @@ type InferflowServer interface { type UnimplementedInferflowServer struct{} func (UnimplementedInferflowServer) RetrieveModelScore(context.Context, *InferflowRequestProto) (*InferflowResponseProto, error) { - return nil, status.Errorf(codes.Unimplemented, "method RetrieveModelScore not implemented") + return nil, status.Error(codes.Unimplemented, "method RetrieveModelScore not implemented") } func (UnimplementedInferflowServer) mustEmbedUnimplementedInferflowServer() {} func (UnimplementedInferflowServer) testEmbeddedByValue() {} @@ -76,7 +76,7 @@ type UnsafeInferflowServer interface { } func RegisterInferflowServer(s grpc.ServiceRegistrar, srv InferflowServer) { - // If the following call pancis, it indicates UnimplementedInferflowServer was + // If the following call panics, it indicates UnimplementedInferflowServer was // embedded by pointer and is nil. 
This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/inferflow/server/grpc/inferflow_logging.pb.go b/inferflow/server/grpc/inferflow_logging.pb.go new file mode 100644 index 00000000..1091972e --- /dev/null +++ b/inferflow/server/grpc/inferflow_logging.pb.go @@ -0,0 +1,226 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v5.28.3 +// source: inferflow_logging.proto + +package grpc + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// InferflowLog is the message structure for inference logging +type InferflowLog struct { + state protoimpl.MessageState `protogen:"open.v1"` + UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + TrackingId string `protobuf:"bytes,2,opt,name=tracking_id,json=trackingId,proto3" json:"tracking_id,omitempty"` + ModelConfigId string `protobuf:"bytes,3,opt,name=model_config_id,json=modelConfigId,proto3" json:"model_config_id,omitempty"` + Entities []string `protobuf:"bytes,4,rep,name=entities,proto3" json:"entities,omitempty"` + Features []*PerEntityFeatures `protobuf:"bytes,5,rep,name=features,proto3" json:"features,omitempty"` + Metadata []byte `protobuf:"bytes,6,opt,name=metadata,proto3" json:"metadata,omitempty"` + ParentEntity []string `protobuf:"bytes,7,rep,name=parent_entity,json=parentEntity,proto3" json:"parent_entity,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x 
*InferflowLog) Reset() { + *x = InferflowLog{} + mi := &file_inferflow_logging_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InferflowLog) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InferflowLog) ProtoMessage() {} + +func (x *InferflowLog) ProtoReflect() protoreflect.Message { + mi := &file_inferflow_logging_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InferflowLog.ProtoReflect.Descriptor instead. +func (*InferflowLog) Descriptor() ([]byte, []int) { + return file_inferflow_logging_proto_rawDescGZIP(), []int{0} +} + +func (x *InferflowLog) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *InferflowLog) GetTrackingId() string { + if x != nil { + return x.TrackingId + } + return "" +} + +func (x *InferflowLog) GetModelConfigId() string { + if x != nil { + return x.ModelConfigId + } + return "" +} + +func (x *InferflowLog) GetEntities() []string { + if x != nil { + return x.Entities + } + return nil +} + +func (x *InferflowLog) GetFeatures() []*PerEntityFeatures { + if x != nil { + return x.Features + } + return nil +} + +func (x *InferflowLog) GetMetadata() []byte { + if x != nil { + return x.Metadata + } + return nil +} + +func (x *InferflowLog) GetParentEntity() []string { + if x != nil { + return x.ParentEntity + } + return nil +} + +type PerEntityFeatures struct { + state protoimpl.MessageState `protogen:"open.v1"` + EncodedFeatures []byte `protobuf:"bytes,1,opt,name=encoded_features,json=encodedFeatures,proto3" json:"encoded_features,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PerEntityFeatures) Reset() { + *x = PerEntityFeatures{} + mi := &file_inferflow_logging_proto_msgTypes[1] + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PerEntityFeatures) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PerEntityFeatures) ProtoMessage() {} + +func (x *PerEntityFeatures) ProtoReflect() protoreflect.Message { + mi := &file_inferflow_logging_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PerEntityFeatures.ProtoReflect.Descriptor instead. +func (*PerEntityFeatures) Descriptor() ([]byte, []int) { + return file_inferflow_logging_proto_rawDescGZIP(), []int{1} +} + +func (x *PerEntityFeatures) GetEncodedFeatures() []byte { + if x != nil { + return x.EncodedFeatures + } + return nil +} + +var File_inferflow_logging_proto protoreflect.FileDescriptor + +const file_inferflow_logging_proto_rawDesc = "" + + "\n" + + "\x17inferflow_logging.proto\"\xfd\x01\n" + + "\fInferflowLog\x12\x17\n" + + "\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1f\n" + + "\vtracking_id\x18\x02 \x01(\tR\n" + + "trackingId\x12&\n" + + "\x0fmodel_config_id\x18\x03 \x01(\tR\rmodelConfigId\x12\x1a\n" + + "\bentities\x18\x04 \x03(\tR\bentities\x12.\n" + + "\bfeatures\x18\x05 \x03(\v2\x12.PerEntityFeaturesR\bfeatures\x12\x1a\n" + + "\bmetadata\x18\x06 \x01(\fR\bmetadata\x12#\n" + + "\rparent_entity\x18\a \x03(\tR\fparentEntity\">\n" + + "\x11PerEntityFeatures\x12)\n" + + "\x10encoded_features\x18\x01 \x01(\fR\x0fencodedFeaturesB\tZ\a../grpcb\x06proto3" + +var ( + file_inferflow_logging_proto_rawDescOnce sync.Once + file_inferflow_logging_proto_rawDescData []byte +) + +func file_inferflow_logging_proto_rawDescGZIP() []byte { + file_inferflow_logging_proto_rawDescOnce.Do(func() { + file_inferflow_logging_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_inferflow_logging_proto_rawDesc), len(file_inferflow_logging_proto_rawDesc))) + }) 
+ return file_inferflow_logging_proto_rawDescData +} + +var file_inferflow_logging_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_inferflow_logging_proto_goTypes = []any{ + (*InferflowLog)(nil), // 0: InferflowLog + (*PerEntityFeatures)(nil), // 1: PerEntityFeatures +} +var file_inferflow_logging_proto_depIdxs = []int32{ + 1, // 0: InferflowLog.features:type_name -> PerEntityFeatures + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_inferflow_logging_proto_init() } +func file_inferflow_logging_proto_init() { + if File_inferflow_logging_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_inferflow_logging_proto_rawDesc), len(file_inferflow_logging_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_inferflow_logging_proto_goTypes, + DependencyIndexes: file_inferflow_logging_proto_depIdxs, + MessageInfos: file_inferflow_logging_proto_msgTypes, + }.Build() + File_inferflow_logging_proto = out.File + file_inferflow_logging_proto_goTypes = nil + file_inferflow_logging_proto_depIdxs = nil +} diff --git a/inferflow/server/grpc/kafka_request.pb.go b/inferflow/server/grpc/kafka_request.pb.go new file mode 100644 index 00000000..7b52b47e --- /dev/null +++ b/inferflow/server/grpc/kafka_request.pb.go @@ -0,0 +1,223 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.11 +// protoc v5.28.3 +// source: kafka_request.proto + +package grpc + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + anypb "google.golang.org/protobuf/types/known/anypb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type KafkaRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Value *KafkaEventValue `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` + Headers *anypb.Any `protobuf:"bytes,2,opt,name=headers,proto3" json:"headers,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *KafkaRequest) Reset() { + *x = KafkaRequest{} + mi := &file_kafka_request_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *KafkaRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*KafkaRequest) ProtoMessage() {} + +func (x *KafkaRequest) ProtoReflect() protoreflect.Message { + mi := &file_kafka_request_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use KafkaRequest.ProtoReflect.Descriptor instead. 
+func (*KafkaRequest) Descriptor() ([]byte, []int) { + return file_kafka_request_proto_rawDescGZIP(), []int{0} +} + +func (x *KafkaRequest) GetValue() *KafkaEventValue { + if x != nil { + return x.Value + } + return nil +} + +func (x *KafkaRequest) GetHeaders() *anypb.Any { + if x != nil { + return x.Headers + } + return nil +} + +type KafkaEventValue struct { + state protoimpl.MessageState `protogen:"open.v1"` + EventName string `protobuf:"bytes,1,opt,name=event_name,json=eventName,proto3" json:"event_name,omitempty"` + EventId string `protobuf:"bytes,2,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"` + CreatedAt int64 `protobuf:"varint,3,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + Properties *anypb.Any `protobuf:"bytes,4,opt,name=properties,proto3" json:"properties,omitempty"` + UserId string `protobuf:"bytes,5,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *KafkaEventValue) Reset() { + *x = KafkaEventValue{} + mi := &file_kafka_request_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *KafkaEventValue) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*KafkaEventValue) ProtoMessage() {} + +func (x *KafkaEventValue) ProtoReflect() protoreflect.Message { + mi := &file_kafka_request_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use KafkaEventValue.ProtoReflect.Descriptor instead. 
+func (*KafkaEventValue) Descriptor() ([]byte, []int) { + return file_kafka_request_proto_rawDescGZIP(), []int{1} +} + +func (x *KafkaEventValue) GetEventName() string { + if x != nil { + return x.EventName + } + return "" +} + +func (x *KafkaEventValue) GetEventId() string { + if x != nil { + return x.EventId + } + return "" +} + +func (x *KafkaEventValue) GetCreatedAt() int64 { + if x != nil { + return x.CreatedAt + } + return 0 +} + +func (x *KafkaEventValue) GetProperties() *anypb.Any { + if x != nil { + return x.Properties + } + return nil +} + +func (x *KafkaEventValue) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +var File_kafka_request_proto protoreflect.FileDescriptor + +const file_kafka_request_proto_rawDesc = "" + + "\n" + + "\x13kafka_request.proto\x1a\x19google/protobuf/any.proto\"f\n" + + "\fKafkaRequest\x12&\n" + + "\x05value\x18\x01 \x01(\v2\x10.KafkaEventValueR\x05value\x12.\n" + + "\aheaders\x18\x02 \x01(\v2\x14.google.protobuf.AnyR\aheaders\"\xb9\x01\n" + + "\x0fKafkaEventValue\x12\x1d\n" + + "\n" + + "event_name\x18\x01 \x01(\tR\teventName\x12\x19\n" + + "\bevent_id\x18\x02 \x01(\tR\aeventId\x12\x1d\n" + + "\n" + + "created_at\x18\x03 \x01(\x03R\tcreatedAt\x124\n" + + "\n" + + "properties\x18\x04 \x01(\v2\x14.google.protobuf.AnyR\n" + + "properties\x12\x17\n" + + "\auser_id\x18\x05 \x01(\tR\x06userIdB\tZ\a../grpcb\x06proto3" + +var ( + file_kafka_request_proto_rawDescOnce sync.Once + file_kafka_request_proto_rawDescData []byte +) + +func file_kafka_request_proto_rawDescGZIP() []byte { + file_kafka_request_proto_rawDescOnce.Do(func() { + file_kafka_request_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_kafka_request_proto_rawDesc), len(file_kafka_request_proto_rawDesc))) + }) + return file_kafka_request_proto_rawDescData +} + +var file_kafka_request_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_kafka_request_proto_goTypes = []any{ + (*KafkaRequest)(nil), // 0: 
KafkaRequest + (*KafkaEventValue)(nil), // 1: KafkaEventValue + (*anypb.Any)(nil), // 2: google.protobuf.Any +} +var file_kafka_request_proto_depIdxs = []int32{ + 1, // 0: KafkaRequest.value:type_name -> KafkaEventValue + 2, // 1: KafkaRequest.headers:type_name -> google.protobuf.Any + 2, // 2: KafkaEventValue.properties:type_name -> google.protobuf.Any + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_kafka_request_proto_init() } +func file_kafka_request_proto_init() { + if File_kafka_request_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_kafka_request_proto_rawDesc), len(file_kafka_request_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_kafka_request_proto_goTypes, + DependencyIndexes: file_kafka_request_proto_depIdxs, + MessageInfos: file_kafka_request_proto_msgTypes, + }.Build() + File_kafka_request_proto = out.File + file_kafka_request_proto_goTypes = nil + file_kafka_request_proto_depIdxs = nil +} diff --git a/inferflow/server/grpc/predict/predict.pb.go b/inferflow/server/grpc/predict/predict.pb.go new file mode 100644 index 00000000..8b3accca --- /dev/null +++ b/inferflow/server/grpc/predict/predict.pb.go @@ -0,0 +1,1318 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.11 +// protoc v5.28.3 +// source: predict.proto + +package predict + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type DataType int32 + +const ( + DataType_DataTypeUnknown DataType = 0 + DataType_DataTypeFP8E5M2 DataType = 1 + DataType_DataTypeFP8E4M3 DataType = 2 + DataType_DataTypeFP16 DataType = 3 + DataType_DataTypeFP32 DataType = 4 + DataType_DataTypeFP64 DataType = 5 + DataType_DataTypeInt8 DataType = 6 + DataType_DataTypeInt16 DataType = 7 + DataType_DataTypeInt32 DataType = 8 + DataType_DataTypeInt64 DataType = 9 + DataType_DataTypeUint8 DataType = 10 + DataType_DataTypeUint16 DataType = 11 + DataType_DataTypeUint32 DataType = 12 + DataType_DataTypeUint64 DataType = 13 + DataType_DataTypeString DataType = 14 + DataType_DataTypeBool DataType = 15 + DataType_DataTypeFP8E5M2Vector DataType = 16 + DataType_DataTypeFP8E4M3Vector DataType = 17 + DataType_DataTypeFP16Vector DataType = 18 + DataType_DataTypeFP32Vector DataType = 19 + DataType_DataTypeFP64Vector DataType = 20 + DataType_DataTypeInt8Vector DataType = 21 + DataType_DataTypeInt16Vector DataType = 22 + DataType_DataTypeInt32Vector DataType = 23 + DataType_DataTypeInt64Vector DataType = 24 + DataType_DataTypeUint8Vector DataType = 25 + DataType_DataTypeUint16Vector DataType = 26 + DataType_DataTypeUint32Vector DataType = 27 + DataType_DataTypeUint64Vector DataType = 28 + DataType_DataTypeStringVector DataType = 29 + DataType_DataTypeBoolVector DataType = 30 +) + +// Enum value maps for DataType. 
+var ( + DataType_name = map[int32]string{ + 0: "DataTypeUnknown", + 1: "DataTypeFP8E5M2", + 2: "DataTypeFP8E4M3", + 3: "DataTypeFP16", + 4: "DataTypeFP32", + 5: "DataTypeFP64", + 6: "DataTypeInt8", + 7: "DataTypeInt16", + 8: "DataTypeInt32", + 9: "DataTypeInt64", + 10: "DataTypeUint8", + 11: "DataTypeUint16", + 12: "DataTypeUint32", + 13: "DataTypeUint64", + 14: "DataTypeString", + 15: "DataTypeBool", + 16: "DataTypeFP8E5M2Vector", + 17: "DataTypeFP8E4M3Vector", + 18: "DataTypeFP16Vector", + 19: "DataTypeFP32Vector", + 20: "DataTypeFP64Vector", + 21: "DataTypeInt8Vector", + 22: "DataTypeInt16Vector", + 23: "DataTypeInt32Vector", + 24: "DataTypeInt64Vector", + 25: "DataTypeUint8Vector", + 26: "DataTypeUint16Vector", + 27: "DataTypeUint32Vector", + 28: "DataTypeUint64Vector", + 29: "DataTypeStringVector", + 30: "DataTypeBoolVector", + } + DataType_value = map[string]int32{ + "DataTypeUnknown": 0, + "DataTypeFP8E5M2": 1, + "DataTypeFP8E4M3": 2, + "DataTypeFP16": 3, + "DataTypeFP32": 4, + "DataTypeFP64": 5, + "DataTypeInt8": 6, + "DataTypeInt16": 7, + "DataTypeInt32": 8, + "DataTypeInt64": 9, + "DataTypeUint8": 10, + "DataTypeUint16": 11, + "DataTypeUint32": 12, + "DataTypeUint64": 13, + "DataTypeString": 14, + "DataTypeBool": 15, + "DataTypeFP8E5M2Vector": 16, + "DataTypeFP8E4M3Vector": 17, + "DataTypeFP16Vector": 18, + "DataTypeFP32Vector": 19, + "DataTypeFP64Vector": 20, + "DataTypeInt8Vector": 21, + "DataTypeInt16Vector": 22, + "DataTypeInt32Vector": 23, + "DataTypeInt64Vector": 24, + "DataTypeUint8Vector": 25, + "DataTypeUint16Vector": 26, + "DataTypeUint32Vector": 27, + "DataTypeUint64Vector": 28, + "DataTypeStringVector": 29, + "DataTypeBoolVector": 30, + } +) + +func (x DataType) Enum() *DataType { + p := new(DataType) + *p = x + return p +} + +func (x DataType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DataType) Descriptor() protoreflect.EnumDescriptor { + return 
file_predict_proto_enumTypes[0].Descriptor() +} + +func (DataType) Type() protoreflect.EnumType { + return &file_predict_proto_enumTypes[0] +} + +func (x DataType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DataType.Descriptor instead. +func (DataType) EnumDescriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{0} +} + +// Schema definition for a feature column +type FeatureSchema struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + DataType DataType `protobuf:"varint,2,opt,name=data_type,json=dataType,proto3,enum=DataType" json:"data_type,omitempty"` + VectorDim int32 `protobuf:"varint,3,opt,name=vector_dim,json=vectorDim,proto3" json:"vector_dim,omitempty"` // 0 = scalar, >0 = fixed vector length + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FeatureSchema) Reset() { + *x = FeatureSchema{} + mi := &file_predict_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FeatureSchema) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FeatureSchema) ProtoMessage() {} + +func (x *FeatureSchema) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FeatureSchema.ProtoReflect.Descriptor instead. 
+func (*FeatureSchema) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{0} +} + +func (x *FeatureSchema) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *FeatureSchema) GetDataType() DataType { + if x != nil { + return x.DataType + } + return DataType_DataTypeUnknown +} + +func (x *FeatureSchema) GetVectorDim() int32 { + if x != nil { + return x.VectorDim + } + return 0 +} + +// A request-level context feature (user, session, device, etc.) +type ContextFeature struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + DataType DataType `protobuf:"varint,3,opt,name=data_type,json=dataType,proto3,enum=DataType" json:"data_type,omitempty"` + VectorDim int32 `protobuf:"varint,4,opt,name=vector_dim,json=vectorDim,proto3" json:"vector_dim,omitempty"` // 0 = scalar, >0 = fixed vector length + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ContextFeature) Reset() { + *x = ContextFeature{} + mi := &file_predict_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ContextFeature) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ContextFeature) ProtoMessage() {} + +func (x *ContextFeature) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ContextFeature.ProtoReflect.Descriptor instead. 
+func (*ContextFeature) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{1} +} + +func (x *ContextFeature) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *ContextFeature) GetValue() []byte { + if x != nil { + return x.Value + } + return nil +} + +func (x *ContextFeature) GetDataType() DataType { + if x != nil { + return x.DataType + } + return DataType_DataTypeUnknown +} + +func (x *ContextFeature) GetVectorDim() int32 { + if x != nil { + return x.VectorDim + } + return 0 +} + +// A single entity to be scored/ranked +type Target struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + FeatureValues [][]byte `protobuf:"bytes,2,rep,name=feature_values,json=featureValues,proto3" json:"feature_values,omitempty"` // aligned with target_input_schema + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Target) Reset() { + *x = Target{} + mi := &file_predict_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Target) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Target) ProtoMessage() {} + +func (x *Target) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Target.ProtoReflect.Descriptor instead. 
+func (*Target) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{2} +} + +func (x *Target) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *Target) GetFeatureValues() [][]byte { + if x != nil { + return x.FeatureValues + } + return nil +} + +type TargetScore struct { + state protoimpl.MessageState `protogen:"open.v1"` + Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` + OutputValues [][]byte `protobuf:"bytes,2,rep,name=output_values,json=outputValues,proto3" json:"output_values,omitempty"` // aligned with target_output_schema + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TargetScore) Reset() { + *x = TargetScore{} + mi := &file_predict_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TargetScore) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TargetScore) ProtoMessage() {} + +func (x *TargetScore) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TargetScore.ProtoReflect.Descriptor instead. 
+func (*TargetScore) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{3} +} + +func (x *TargetScore) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *TargetScore) GetOutputValues() [][]byte { + if x != nil { + return x.OutputValues + } + return nil +} + +type PointWiseRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelConfigId string `protobuf:"bytes,1,opt,name=model_config_id,json=modelConfigId,proto3" json:"model_config_id,omitempty"` + TrackingId string `protobuf:"bytes,2,opt,name=tracking_id,json=trackingId,proto3" json:"tracking_id,omitempty"` + ContextFeatures []*ContextFeature `protobuf:"bytes,3,rep,name=context_features,json=contextFeatures,proto3" json:"context_features,omitempty"` + TargetInputSchema []*FeatureSchema `protobuf:"bytes,4,rep,name=target_input_schema,json=targetInputSchema,proto3" json:"target_input_schema,omitempty"` + Targets []*Target `protobuf:"bytes,5,rep,name=targets,proto3" json:"targets,omitempty"` + TenantId string `protobuf:"bytes,6,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PointWiseRequest) Reset() { + *x = PointWiseRequest{} + mi := &file_predict_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PointWiseRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PointWiseRequest) ProtoMessage() {} + +func (x *PointWiseRequest) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PointWiseRequest.ProtoReflect.Descriptor instead. 
+func (*PointWiseRequest) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{4} +} + +func (x *PointWiseRequest) GetModelConfigId() string { + if x != nil { + return x.ModelConfigId + } + return "" +} + +func (x *PointWiseRequest) GetTrackingId() string { + if x != nil { + return x.TrackingId + } + return "" +} + +func (x *PointWiseRequest) GetContextFeatures() []*ContextFeature { + if x != nil { + return x.ContextFeatures + } + return nil +} + +func (x *PointWiseRequest) GetTargetInputSchema() []*FeatureSchema { + if x != nil { + return x.TargetInputSchema + } + return nil +} + +func (x *PointWiseRequest) GetTargets() []*Target { + if x != nil { + return x.Targets + } + return nil +} + +func (x *PointWiseRequest) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +type PointWiseResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + TargetOutputSchema []*FeatureSchema `protobuf:"bytes,1,rep,name=target_output_schema,json=targetOutputSchema,proto3" json:"target_output_schema,omitempty"` + TargetScores []*TargetScore `protobuf:"bytes,2,rep,name=target_scores,json=targetScores,proto3" json:"target_scores,omitempty"` + RequestError string `protobuf:"bytes,3,opt,name=request_error,json=requestError,proto3" json:"request_error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PointWiseResponse) Reset() { + *x = PointWiseResponse{} + mi := &file_predict_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PointWiseResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PointWiseResponse) ProtoMessage() {} + +func (x *PointWiseResponse) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } 
+ return mi.MessageOf(x) +} + +// Deprecated: Use PointWiseResponse.ProtoReflect.Descriptor instead. +func (*PointWiseResponse) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{5} +} + +func (x *PointWiseResponse) GetTargetOutputSchema() []*FeatureSchema { + if x != nil { + return x.TargetOutputSchema + } + return nil +} + +func (x *PointWiseResponse) GetTargetScores() []*TargetScore { + if x != nil { + return x.TargetScores + } + return nil +} + +func (x *PointWiseResponse) GetRequestError() string { + if x != nil { + return x.RequestError + } + return "" +} + +type TargetPair struct { + state protoimpl.MessageState `protogen:"open.v1"` + FirstTargetIndex int32 `protobuf:"varint,1,opt,name=first_target_index,json=firstTargetIndex,proto3" json:"first_target_index,omitempty"` + SecondTargetIndex int32 `protobuf:"varint,2,opt,name=second_target_index,json=secondTargetIndex,proto3" json:"second_target_index,omitempty"` + FeatureValues [][]byte `protobuf:"bytes,3,rep,name=feature_values,json=featureValues,proto3" json:"feature_values,omitempty"` // aligned with pair_input_schema + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TargetPair) Reset() { + *x = TargetPair{} + mi := &file_predict_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TargetPair) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TargetPair) ProtoMessage() {} + +func (x *TargetPair) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TargetPair.ProtoReflect.Descriptor instead. 
+func (*TargetPair) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{6} +} + +func (x *TargetPair) GetFirstTargetIndex() int32 { + if x != nil { + return x.FirstTargetIndex + } + return 0 +} + +func (x *TargetPair) GetSecondTargetIndex() int32 { + if x != nil { + return x.SecondTargetIndex + } + return 0 +} + +func (x *TargetPair) GetFeatureValues() [][]byte { + if x != nil { + return x.FeatureValues + } + return nil +} + +type PairWiseRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelConfigId string `protobuf:"bytes,1,opt,name=model_config_id,json=modelConfigId,proto3" json:"model_config_id,omitempty"` + TrackingId string `protobuf:"bytes,2,opt,name=tracking_id,json=trackingId,proto3" json:"tracking_id,omitempty"` + ContextFeatures []*ContextFeature `protobuf:"bytes,3,rep,name=context_features,json=contextFeatures,proto3" json:"context_features,omitempty"` + TargetInputSchema []*FeatureSchema `protobuf:"bytes,4,rep,name=target_input_schema,json=targetInputSchema,proto3" json:"target_input_schema,omitempty"` + PairInputSchema []*FeatureSchema `protobuf:"bytes,5,rep,name=pair_input_schema,json=pairInputSchema,proto3" json:"pair_input_schema,omitempty"` + Pairs []*TargetPair `protobuf:"bytes,6,rep,name=pairs,proto3" json:"pairs,omitempty"` + Targets []*Target `protobuf:"bytes,7,rep,name=targets,proto3" json:"targets,omitempty"` + TenantId string `protobuf:"bytes,8,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PairWiseRequest) Reset() { + *x = PairWiseRequest{} + mi := &file_predict_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PairWiseRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PairWiseRequest) ProtoMessage() {} + +func (x *PairWiseRequest) ProtoReflect() protoreflect.Message { + mi := 
&file_predict_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PairWiseRequest.ProtoReflect.Descriptor instead. +func (*PairWiseRequest) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{7} +} + +func (x *PairWiseRequest) GetModelConfigId() string { + if x != nil { + return x.ModelConfigId + } + return "" +} + +func (x *PairWiseRequest) GetTrackingId() string { + if x != nil { + return x.TrackingId + } + return "" +} + +func (x *PairWiseRequest) GetContextFeatures() []*ContextFeature { + if x != nil { + return x.ContextFeatures + } + return nil +} + +func (x *PairWiseRequest) GetTargetInputSchema() []*FeatureSchema { + if x != nil { + return x.TargetInputSchema + } + return nil +} + +func (x *PairWiseRequest) GetPairInputSchema() []*FeatureSchema { + if x != nil { + return x.PairInputSchema + } + return nil +} + +func (x *PairWiseRequest) GetPairs() []*TargetPair { + if x != nil { + return x.Pairs + } + return nil +} + +func (x *PairWiseRequest) GetTargets() []*Target { + if x != nil { + return x.Targets + } + return nil +} + +func (x *PairWiseRequest) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +type PairScore struct { + state protoimpl.MessageState `protogen:"open.v1"` + Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` + OutputValues [][]byte `protobuf:"bytes,2,rep,name=output_values,json=outputValues,proto3" json:"output_values,omitempty"` // aligned with pair_output_schema + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PairScore) Reset() { + *x = PairScore{} + mi := &file_predict_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PairScore) String() string { + return 
protoimpl.X.MessageStringOf(x) +} + +func (*PairScore) ProtoMessage() {} + +func (x *PairScore) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PairScore.ProtoReflect.Descriptor instead. +func (*PairScore) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{8} +} + +func (x *PairScore) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *PairScore) GetOutputValues() [][]byte { + if x != nil { + return x.OutputValues + } + return nil +} + +type PairWiseResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + PairScores []*PairScore `protobuf:"bytes,1,rep,name=pair_scores,json=pairScores,proto3" json:"pair_scores,omitempty"` + TargetScores []*TargetScore `protobuf:"bytes,2,rep,name=target_scores,json=targetScores,proto3" json:"target_scores,omitempty"` + TargetOutputSchema []*FeatureSchema `protobuf:"bytes,3,rep,name=target_output_schema,json=targetOutputSchema,proto3" json:"target_output_schema,omitempty"` + PairOutputSchema []*FeatureSchema `protobuf:"bytes,4,rep,name=pair_output_schema,json=pairOutputSchema,proto3" json:"pair_output_schema,omitempty"` + RequestError string `protobuf:"bytes,5,opt,name=request_error,json=requestError,proto3" json:"request_error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PairWiseResponse) Reset() { + *x = PairWiseResponse{} + mi := &file_predict_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PairWiseResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PairWiseResponse) ProtoMessage() {} + +func (x *PairWiseResponse) ProtoReflect() protoreflect.Message { + mi := 
&file_predict_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PairWiseResponse.ProtoReflect.Descriptor instead. +func (*PairWiseResponse) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{9} +} + +func (x *PairWiseResponse) GetPairScores() []*PairScore { + if x != nil { + return x.PairScores + } + return nil +} + +func (x *PairWiseResponse) GetTargetScores() []*TargetScore { + if x != nil { + return x.TargetScores + } + return nil +} + +func (x *PairWiseResponse) GetTargetOutputSchema() []*FeatureSchema { + if x != nil { + return x.TargetOutputSchema + } + return nil +} + +func (x *PairWiseResponse) GetPairOutputSchema() []*FeatureSchema { + if x != nil { + return x.PairOutputSchema + } + return nil +} + +func (x *PairWiseResponse) GetRequestError() string { + if x != nil { + return x.RequestError + } + return "" +} + +type TargetSlate struct { + state protoimpl.MessageState `protogen:"open.v1"` + TargetIndices []int32 `protobuf:"varint,1,rep,packed,name=target_indices,json=targetIndices,proto3" json:"target_indices,omitempty"` + FeatureValues [][]byte `protobuf:"bytes,2,rep,name=feature_values,json=featureValues,proto3" json:"feature_values,omitempty"` // aligned with slate_input_schema + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TargetSlate) Reset() { + *x = TargetSlate{} + mi := &file_predict_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TargetSlate) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TargetSlate) ProtoMessage() {} + +func (x *TargetSlate) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() 
== nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TargetSlate.ProtoReflect.Descriptor instead. +func (*TargetSlate) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{10} +} + +func (x *TargetSlate) GetTargetIndices() []int32 { + if x != nil { + return x.TargetIndices + } + return nil +} + +func (x *TargetSlate) GetFeatureValues() [][]byte { + if x != nil { + return x.FeatureValues + } + return nil +} + +type SlateScore struct { + state protoimpl.MessageState `protogen:"open.v1"` + Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` + OutputValues [][]byte `protobuf:"bytes,2,rep,name=output_values,json=outputValues,proto3" json:"output_values,omitempty"` // aligned with slate_output_schema + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SlateScore) Reset() { + *x = SlateScore{} + mi := &file_predict_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SlateScore) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SlateScore) ProtoMessage() {} + +func (x *SlateScore) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SlateScore.ProtoReflect.Descriptor instead. 
+func (*SlateScore) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{11} +} + +func (x *SlateScore) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *SlateScore) GetOutputValues() [][]byte { + if x != nil { + return x.OutputValues + } + return nil +} + +type SlateWiseRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelConfigId string `protobuf:"bytes,1,opt,name=model_config_id,json=modelConfigId,proto3" json:"model_config_id,omitempty"` + TrackingId string `protobuf:"bytes,2,opt,name=tracking_id,json=trackingId,proto3" json:"tracking_id,omitempty"` + ContextFeatures []*ContextFeature `protobuf:"bytes,3,rep,name=context_features,json=contextFeatures,proto3" json:"context_features,omitempty"` + TargetInputSchema []*FeatureSchema `protobuf:"bytes,4,rep,name=target_input_schema,json=targetInputSchema,proto3" json:"target_input_schema,omitempty"` + SlateInputSchema []*FeatureSchema `protobuf:"bytes,5,rep,name=slate_input_schema,json=slateInputSchema,proto3" json:"slate_input_schema,omitempty"` + Slates []*TargetSlate `protobuf:"bytes,6,rep,name=slates,proto3" json:"slates,omitempty"` + Targets []*Target `protobuf:"bytes,7,rep,name=targets,proto3" json:"targets,omitempty"` + TenantId string `protobuf:"bytes,8,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SlateWiseRequest) Reset() { + *x = SlateWiseRequest{} + mi := &file_predict_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SlateWiseRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SlateWiseRequest) ProtoMessage() {} + +func (x *SlateWiseRequest) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == 
nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SlateWiseRequest.ProtoReflect.Descriptor instead. +func (*SlateWiseRequest) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{12} +} + +func (x *SlateWiseRequest) GetModelConfigId() string { + if x != nil { + return x.ModelConfigId + } + return "" +} + +func (x *SlateWiseRequest) GetTrackingId() string { + if x != nil { + return x.TrackingId + } + return "" +} + +func (x *SlateWiseRequest) GetContextFeatures() []*ContextFeature { + if x != nil { + return x.ContextFeatures + } + return nil +} + +func (x *SlateWiseRequest) GetTargetInputSchema() []*FeatureSchema { + if x != nil { + return x.TargetInputSchema + } + return nil +} + +func (x *SlateWiseRequest) GetSlateInputSchema() []*FeatureSchema { + if x != nil { + return x.SlateInputSchema + } + return nil +} + +func (x *SlateWiseRequest) GetSlates() []*TargetSlate { + if x != nil { + return x.Slates + } + return nil +} + +func (x *SlateWiseRequest) GetTargets() []*Target { + if x != nil { + return x.Targets + } + return nil +} + +func (x *SlateWiseRequest) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +type SlateWiseResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + SlateScores []*SlateScore `protobuf:"bytes,1,rep,name=slate_scores,json=slateScores,proto3" json:"slate_scores,omitempty"` + TargetScores []*TargetScore `protobuf:"bytes,2,rep,name=target_scores,json=targetScores,proto3" json:"target_scores,omitempty"` + TargetOutputSchema []*FeatureSchema `protobuf:"bytes,3,rep,name=target_output_schema,json=targetOutputSchema,proto3" json:"target_output_schema,omitempty"` + SlateOutputSchema []*FeatureSchema `protobuf:"bytes,4,rep,name=slate_output_schema,json=slateOutputSchema,proto3" json:"slate_output_schema,omitempty"` + RequestError string `protobuf:"bytes,5,opt,name=request_error,json=requestError,proto3" 
json:"request_error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SlateWiseResponse) Reset() { + *x = SlateWiseResponse{} + mi := &file_predict_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SlateWiseResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SlateWiseResponse) ProtoMessage() {} + +func (x *SlateWiseResponse) ProtoReflect() protoreflect.Message { + mi := &file_predict_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SlateWiseResponse.ProtoReflect.Descriptor instead. +func (*SlateWiseResponse) Descriptor() ([]byte, []int) { + return file_predict_proto_rawDescGZIP(), []int{13} +} + +func (x *SlateWiseResponse) GetSlateScores() []*SlateScore { + if x != nil { + return x.SlateScores + } + return nil +} + +func (x *SlateWiseResponse) GetTargetScores() []*TargetScore { + if x != nil { + return x.TargetScores + } + return nil +} + +func (x *SlateWiseResponse) GetTargetOutputSchema() []*FeatureSchema { + if x != nil { + return x.TargetOutputSchema + } + return nil +} + +func (x *SlateWiseResponse) GetSlateOutputSchema() []*FeatureSchema { + if x != nil { + return x.SlateOutputSchema + } + return nil +} + +func (x *SlateWiseResponse) GetRequestError() string { + if x != nil { + return x.RequestError + } + return "" +} + +var File_predict_proto protoreflect.FileDescriptor + +const file_predict_proto_rawDesc = "" + + "\n" + + "\rpredict.proto\"j\n" + + "\rFeatureSchema\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12&\n" + + "\tdata_type\x18\x02 \x01(\x0e2\t.DataTypeR\bdataType\x12\x1d\n" + + "\n" + + "vector_dim\x18\x03 \x01(\x05R\tvectorDim\"\x81\x01\n" + + "\x0eContextFeature\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" + + 
"\x05value\x18\x02 \x01(\fR\x05value\x12&\n" + + "\tdata_type\x18\x03 \x01(\x0e2\t.DataTypeR\bdataType\x12\x1d\n" + + "\n" + + "vector_dim\x18\x04 \x01(\x05R\tvectorDim\"?\n" + + "\x06Target\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12%\n" + + "\x0efeature_values\x18\x02 \x03(\fR\rfeatureValues\"H\n" + + "\vTargetScore\x12\x14\n" + + "\x05error\x18\x01 \x01(\tR\x05error\x12#\n" + + "\routput_values\x18\x02 \x03(\fR\foutputValues\"\x97\x02\n" + + "\x10PointWiseRequest\x12&\n" + + "\x0fmodel_config_id\x18\x01 \x01(\tR\rmodelConfigId\x12\x1f\n" + + "\vtracking_id\x18\x02 \x01(\tR\n" + + "trackingId\x12:\n" + + "\x10context_features\x18\x03 \x03(\v2\x0f.ContextFeatureR\x0fcontextFeatures\x12>\n" + + "\x13target_input_schema\x18\x04 \x03(\v2\x0e.FeatureSchemaR\x11targetInputSchema\x12!\n" + + "\atargets\x18\x05 \x03(\v2\a.TargetR\atargets\x12\x1b\n" + + "\ttenant_id\x18\x06 \x01(\tR\btenantId\"\xad\x01\n" + + "\x11PointWiseResponse\x12@\n" + + "\x14target_output_schema\x18\x01 \x03(\v2\x0e.FeatureSchemaR\x12targetOutputSchema\x121\n" + + "\rtarget_scores\x18\x02 \x03(\v2\f.TargetScoreR\ftargetScores\x12#\n" + + "\rrequest_error\x18\x03 \x01(\tR\frequestError\"\x91\x01\n" + + "\n" + + "TargetPair\x12,\n" + + "\x12first_target_index\x18\x01 \x01(\x05R\x10firstTargetIndex\x12.\n" + + "\x13second_target_index\x18\x02 \x01(\x05R\x11secondTargetIndex\x12%\n" + + "\x0efeature_values\x18\x03 \x03(\fR\rfeatureValues\"\xf5\x02\n" + + "\x0fPairWiseRequest\x12&\n" + + "\x0fmodel_config_id\x18\x01 \x01(\tR\rmodelConfigId\x12\x1f\n" + + "\vtracking_id\x18\x02 \x01(\tR\n" + + "trackingId\x12:\n" + + "\x10context_features\x18\x03 \x03(\v2\x0f.ContextFeatureR\x0fcontextFeatures\x12>\n" + + "\x13target_input_schema\x18\x04 \x03(\v2\x0e.FeatureSchemaR\x11targetInputSchema\x12:\n" + + "\x11pair_input_schema\x18\x05 \x03(\v2\x0e.FeatureSchemaR\x0fpairInputSchema\x12!\n" + + "\x05pairs\x18\x06 \x03(\v2\v.TargetPairR\x05pairs\x12!\n" + + "\atargets\x18\a 
\x03(\v2\a.TargetR\atargets\x12\x1b\n" + + "\ttenant_id\x18\b \x01(\tR\btenantId\"F\n" + + "\tPairScore\x12\x14\n" + + "\x05error\x18\x01 \x01(\tR\x05error\x12#\n" + + "\routput_values\x18\x02 \x03(\fR\foutputValues\"\x97\x02\n" + + "\x10PairWiseResponse\x12+\n" + + "\vpair_scores\x18\x01 \x03(\v2\n" + + ".PairScoreR\n" + + "pairScores\x121\n" + + "\rtarget_scores\x18\x02 \x03(\v2\f.TargetScoreR\ftargetScores\x12@\n" + + "\x14target_output_schema\x18\x03 \x03(\v2\x0e.FeatureSchemaR\x12targetOutputSchema\x12<\n" + + "\x12pair_output_schema\x18\x04 \x03(\v2\x0e.FeatureSchemaR\x10pairOutputSchema\x12#\n" + + "\rrequest_error\x18\x05 \x01(\tR\frequestError\"[\n" + + "\vTargetSlate\x12%\n" + + "\x0etarget_indices\x18\x01 \x03(\x05R\rtargetIndices\x12%\n" + + "\x0efeature_values\x18\x02 \x03(\fR\rfeatureValues\"G\n" + + "\n" + + "SlateScore\x12\x14\n" + + "\x05error\x18\x01 \x01(\tR\x05error\x12#\n" + + "\routput_values\x18\x02 \x03(\fR\foutputValues\"\xfb\x02\n" + + "\x10SlateWiseRequest\x12&\n" + + "\x0fmodel_config_id\x18\x01 \x01(\tR\rmodelConfigId\x12\x1f\n" + + "\vtracking_id\x18\x02 \x01(\tR\n" + + "trackingId\x12:\n" + + "\x10context_features\x18\x03 \x03(\v2\x0f.ContextFeatureR\x0fcontextFeatures\x12>\n" + + "\x13target_input_schema\x18\x04 \x03(\v2\x0e.FeatureSchemaR\x11targetInputSchema\x12<\n" + + "\x12slate_input_schema\x18\x05 \x03(\v2\x0e.FeatureSchemaR\x10slateInputSchema\x12$\n" + + "\x06slates\x18\x06 \x03(\v2\f.TargetSlateR\x06slates\x12!\n" + + "\atargets\x18\a \x03(\v2\a.TargetR\atargets\x12\x1b\n" + + "\ttenant_id\x18\b \x01(\tR\btenantId\"\x9d\x02\n" + + "\x11SlateWiseResponse\x12.\n" + + "\fslate_scores\x18\x01 \x03(\v2\v.SlateScoreR\vslateScores\x121\n" + + "\rtarget_scores\x18\x02 \x03(\v2\f.TargetScoreR\ftargetScores\x12@\n" + + "\x14target_output_schema\x18\x03 \x03(\v2\x0e.FeatureSchemaR\x12targetOutputSchema\x12>\n" + + "\x13slate_output_schema\x18\x04 \x03(\v2\x0e.FeatureSchemaR\x11slateOutputSchema\x12#\n" + + "\rrequest_error\x18\x05 
\x01(\tR\frequestError*\xb9\x05\n" + + "\bDataType\x12\x13\n" + + "\x0fDataTypeUnknown\x10\x00\x12\x13\n" + + "\x0fDataTypeFP8E5M2\x10\x01\x12\x13\n" + + "\x0fDataTypeFP8E4M3\x10\x02\x12\x10\n" + + "\fDataTypeFP16\x10\x03\x12\x10\n" + + "\fDataTypeFP32\x10\x04\x12\x10\n" + + "\fDataTypeFP64\x10\x05\x12\x10\n" + + "\fDataTypeInt8\x10\x06\x12\x11\n" + + "\rDataTypeInt16\x10\a\x12\x11\n" + + "\rDataTypeInt32\x10\b\x12\x11\n" + + "\rDataTypeInt64\x10\t\x12\x11\n" + + "\rDataTypeUint8\x10\n" + + "\x12\x12\n" + + "\x0eDataTypeUint16\x10\v\x12\x12\n" + + "\x0eDataTypeUint32\x10\f\x12\x12\n" + + "\x0eDataTypeUint64\x10\r\x12\x12\n" + + "\x0eDataTypeString\x10\x0e\x12\x10\n" + + "\fDataTypeBool\x10\x0f\x12\x19\n" + + "\x15DataTypeFP8E5M2Vector\x10\x10\x12\x19\n" + + "\x15DataTypeFP8E4M3Vector\x10\x11\x12\x16\n" + + "\x12DataTypeFP16Vector\x10\x12\x12\x16\n" + + "\x12DataTypeFP32Vector\x10\x13\x12\x16\n" + + "\x12DataTypeFP64Vector\x10\x14\x12\x16\n" + + "\x12DataTypeInt8Vector\x10\x15\x12\x17\n" + + "\x13DataTypeInt16Vector\x10\x16\x12\x17\n" + + "\x13DataTypeInt32Vector\x10\x17\x12\x17\n" + + "\x13DataTypeInt64Vector\x10\x18\x12\x17\n" + + "\x13DataTypeUint8Vector\x10\x19\x12\x18\n" + + "\x14DataTypeUint16Vector\x10\x1a\x12\x18\n" + + "\x14DataTypeUint32Vector\x10\x1b\x12\x18\n" + + "\x14DataTypeUint64Vector\x10\x1c\x12\x18\n" + + "\x14DataTypeStringVector\x10\x1d\x12\x16\n" + + "\x12DataTypeBoolVector\x10\x1e2\xb7\x01\n" + + "\aPredict\x129\n" + + "\x0eInferPointWise\x12\x11.PointWiseRequest\x1a\x12.PointWiseResponse\"\x00\x126\n" + + "\rInferPairWise\x12\x10.PairWiseRequest\x1a\x11.PairWiseResponse\"\x00\x129\n" + + "\x0eInferSlateWise\x12\x11.SlateWiseRequest\x1a\x12.SlateWiseResponse\"\x00B\x11Z\x0f../grpc/predictb\x06proto3" + +var ( + file_predict_proto_rawDescOnce sync.Once + file_predict_proto_rawDescData []byte +) + +func file_predict_proto_rawDescGZIP() []byte { + file_predict_proto_rawDescOnce.Do(func() { + file_predict_proto_rawDescData = 
protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_predict_proto_rawDesc), len(file_predict_proto_rawDesc))) + }) + return file_predict_proto_rawDescData +} + +var file_predict_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_predict_proto_msgTypes = make([]protoimpl.MessageInfo, 14) +var file_predict_proto_goTypes = []any{ + (DataType)(0), // 0: DataType + (*FeatureSchema)(nil), // 1: FeatureSchema + (*ContextFeature)(nil), // 2: ContextFeature + (*Target)(nil), // 3: Target + (*TargetScore)(nil), // 4: TargetScore + (*PointWiseRequest)(nil), // 5: PointWiseRequest + (*PointWiseResponse)(nil), // 6: PointWiseResponse + (*TargetPair)(nil), // 7: TargetPair + (*PairWiseRequest)(nil), // 8: PairWiseRequest + (*PairScore)(nil), // 9: PairScore + (*PairWiseResponse)(nil), // 10: PairWiseResponse + (*TargetSlate)(nil), // 11: TargetSlate + (*SlateScore)(nil), // 12: SlateScore + (*SlateWiseRequest)(nil), // 13: SlateWiseRequest + (*SlateWiseResponse)(nil), // 14: SlateWiseResponse +} +var file_predict_proto_depIdxs = []int32{ + 0, // 0: FeatureSchema.data_type:type_name -> DataType + 0, // 1: ContextFeature.data_type:type_name -> DataType + 2, // 2: PointWiseRequest.context_features:type_name -> ContextFeature + 1, // 3: PointWiseRequest.target_input_schema:type_name -> FeatureSchema + 3, // 4: PointWiseRequest.targets:type_name -> Target + 1, // 5: PointWiseResponse.target_output_schema:type_name -> FeatureSchema + 4, // 6: PointWiseResponse.target_scores:type_name -> TargetScore + 2, // 7: PairWiseRequest.context_features:type_name -> ContextFeature + 1, // 8: PairWiseRequest.target_input_schema:type_name -> FeatureSchema + 1, // 9: PairWiseRequest.pair_input_schema:type_name -> FeatureSchema + 7, // 10: PairWiseRequest.pairs:type_name -> TargetPair + 3, // 11: PairWiseRequest.targets:type_name -> Target + 9, // 12: PairWiseResponse.pair_scores:type_name -> PairScore + 4, // 13: PairWiseResponse.target_scores:type_name -> TargetScore + 1, // 14: 
PairWiseResponse.target_output_schema:type_name -> FeatureSchema + 1, // 15: PairWiseResponse.pair_output_schema:type_name -> FeatureSchema + 2, // 16: SlateWiseRequest.context_features:type_name -> ContextFeature + 1, // 17: SlateWiseRequest.target_input_schema:type_name -> FeatureSchema + 1, // 18: SlateWiseRequest.slate_input_schema:type_name -> FeatureSchema + 11, // 19: SlateWiseRequest.slates:type_name -> TargetSlate + 3, // 20: SlateWiseRequest.targets:type_name -> Target + 12, // 21: SlateWiseResponse.slate_scores:type_name -> SlateScore + 4, // 22: SlateWiseResponse.target_scores:type_name -> TargetScore + 1, // 23: SlateWiseResponse.target_output_schema:type_name -> FeatureSchema + 1, // 24: SlateWiseResponse.slate_output_schema:type_name -> FeatureSchema + 5, // 25: Predict.InferPointWise:input_type -> PointWiseRequest + 8, // 26: Predict.InferPairWise:input_type -> PairWiseRequest + 13, // 27: Predict.InferSlateWise:input_type -> SlateWiseRequest + 6, // 28: Predict.InferPointWise:output_type -> PointWiseResponse + 10, // 29: Predict.InferPairWise:output_type -> PairWiseResponse + 14, // 30: Predict.InferSlateWise:output_type -> SlateWiseResponse + 28, // [28:31] is the sub-list for method output_type + 25, // [25:28] is the sub-list for method input_type + 25, // [25:25] is the sub-list for extension type_name + 25, // [25:25] is the sub-list for extension extendee + 0, // [0:25] is the sub-list for field type_name +} + +func init() { file_predict_proto_init() } +func file_predict_proto_init() { + if File_predict_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_predict_proto_rawDesc), len(file_predict_proto_rawDesc)), + NumEnums: 1, + NumMessages: 14, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_predict_proto_goTypes, + DependencyIndexes: file_predict_proto_depIdxs, + EnumInfos: 
file_predict_proto_enumTypes, + MessageInfos: file_predict_proto_msgTypes, + }.Build() + File_predict_proto = out.File + file_predict_proto_goTypes = nil + file_predict_proto_depIdxs = nil +} diff --git a/inferflow/server/grpc/predict/predict_grpc.pb.go b/inferflow/server/grpc/predict/predict_grpc.pb.go new file mode 100644 index 00000000..cc393b5b --- /dev/null +++ b/inferflow/server/grpc/predict/predict_grpc.pb.go @@ -0,0 +1,197 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc v5.28.3 +// source: predict.proto + +package predict + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Predict_InferPointWise_FullMethodName = "/Predict/InferPointWise" + Predict_InferPairWise_FullMethodName = "/Predict/InferPairWise" + Predict_InferSlateWise_FullMethodName = "/Predict/InferSlateWise" +) + +// PredictClient is the client API for Predict service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type PredictClient interface { + InferPointWise(ctx context.Context, in *PointWiseRequest, opts ...grpc.CallOption) (*PointWiseResponse, error) + InferPairWise(ctx context.Context, in *PairWiseRequest, opts ...grpc.CallOption) (*PairWiseResponse, error) + InferSlateWise(ctx context.Context, in *SlateWiseRequest, opts ...grpc.CallOption) (*SlateWiseResponse, error) +} + +type predictClient struct { + cc grpc.ClientConnInterface +} + +func NewPredictClient(cc grpc.ClientConnInterface) PredictClient { + return &predictClient{cc} +} + +func (c *predictClient) InferPointWise(ctx context.Context, in *PointWiseRequest, opts ...grpc.CallOption) (*PointWiseResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(PointWiseResponse) + err := c.cc.Invoke(ctx, Predict_InferPointWise_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *predictClient) InferPairWise(ctx context.Context, in *PairWiseRequest, opts ...grpc.CallOption) (*PairWiseResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(PairWiseResponse) + err := c.cc.Invoke(ctx, Predict_InferPairWise_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *predictClient) InferSlateWise(ctx context.Context, in *SlateWiseRequest, opts ...grpc.CallOption) (*SlateWiseResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(SlateWiseResponse) + err := c.cc.Invoke(ctx, Predict_InferSlateWise_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// PredictServer is the server API for Predict service. +// All implementations must embed UnimplementedPredictServer +// for forward compatibility. 
+type PredictServer interface { + InferPointWise(context.Context, *PointWiseRequest) (*PointWiseResponse, error) + InferPairWise(context.Context, *PairWiseRequest) (*PairWiseResponse, error) + InferSlateWise(context.Context, *SlateWiseRequest) (*SlateWiseResponse, error) + mustEmbedUnimplementedPredictServer() +} + +// UnimplementedPredictServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedPredictServer struct{} + +func (UnimplementedPredictServer) InferPointWise(context.Context, *PointWiseRequest) (*PointWiseResponse, error) { + return nil, status.Error(codes.Unimplemented, "method InferPointWise not implemented") +} +func (UnimplementedPredictServer) InferPairWise(context.Context, *PairWiseRequest) (*PairWiseResponse, error) { + return nil, status.Error(codes.Unimplemented, "method InferPairWise not implemented") +} +func (UnimplementedPredictServer) InferSlateWise(context.Context, *SlateWiseRequest) (*SlateWiseResponse, error) { + return nil, status.Error(codes.Unimplemented, "method InferSlateWise not implemented") +} +func (UnimplementedPredictServer) mustEmbedUnimplementedPredictServer() {} +func (UnimplementedPredictServer) testEmbeddedByValue() {} + +// UnsafePredictServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to PredictServer will +// result in compilation errors. +type UnsafePredictServer interface { + mustEmbedUnimplementedPredictServer() +} + +func RegisterPredictServer(s grpc.ServiceRegistrar, srv PredictServer) { + // If the following call panics, it indicates UnimplementedPredictServer was + // embedded by pointer and is nil. 
This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Predict_ServiceDesc, srv) +} + +func _Predict_InferPointWise_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PointWiseRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PredictServer).InferPointWise(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Predict_InferPointWise_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PredictServer).InferPointWise(ctx, req.(*PointWiseRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Predict_InferPairWise_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PairWiseRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PredictServer).InferPairWise(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Predict_InferPairWise_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PredictServer).InferPairWise(ctx, req.(*PairWiseRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Predict_InferSlateWise_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SlateWiseRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PredictServer).InferSlateWise(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + 
FullMethod: Predict_InferSlateWise_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PredictServer).InferSlateWise(ctx, req.(*SlateWiseRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Predict_ServiceDesc is the grpc.ServiceDesc for Predict service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Predict_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "Predict", + HandlerType: (*PredictServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "InferPointWise", + Handler: _Predict_InferPointWise_Handler, + }, + { + MethodName: "InferPairWise", + Handler: _Predict_InferPairWise_Handler, + }, + { + MethodName: "InferSlateWise", + Handler: _Predict_InferSlateWise_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "predict.proto", +} diff --git a/inferflow/server/proto/inferflow_logging.proto b/inferflow/server/proto/inferflow_logging.proto new file mode 100644 index 00000000..65f536a1 --- /dev/null +++ b/inferflow/server/proto/inferflow_logging.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +option go_package = "../grpc"; + +// InferflowLog is the message structure for inference logging +message InferflowLog { + string user_id = 1; + string tracking_id = 2; + string model_config_id = 3; + repeated string entities = 4; + repeated PerEntityFeatures features = 5; + bytes metadata = 6; + repeated string parent_entity = 7; +} + +message PerEntityFeatures { + bytes encoded_features = 1; +} diff --git a/inferflow/server/proto/kafka_request.proto b/inferflow/server/proto/kafka_request.proto new file mode 100644 index 00000000..e51e1fe0 --- /dev/null +++ b/inferflow/server/proto/kafka_request.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; + +option go_package = "../grpc"; + +message KafkaRequest { + KafkaEventValue value = 1; + google.protobuf.Any headers = 2; +} + 
+message KafkaEventValue { + string event_name = 1; + string event_id = 2; + int64 created_at = 3; + google.protobuf.Any properties = 4; + string user_id = 5; +} diff --git a/inferflow/server/proto/predict.proto b/inferflow/server/proto/predict.proto new file mode 100644 index 00000000..c6597183 --- /dev/null +++ b/inferflow/server/proto/predict.proto @@ -0,0 +1,151 @@ +syntax = "proto3"; + +option go_package = "../grpc/predict"; + +enum DataType { + DataTypeUnknown = 0; + DataTypeFP8E5M2 = 1; + DataTypeFP8E4M3 = 2; + DataTypeFP16 = 3; + DataTypeFP32 = 4; + DataTypeFP64 = 5; + DataTypeInt8 = 6; + DataTypeInt16 = 7; + DataTypeInt32 = 8; + DataTypeInt64 = 9; + DataTypeUint8 = 10; + DataTypeUint16 = 11; + DataTypeUint32 = 12; + DataTypeUint64 = 13; + DataTypeString = 14; + DataTypeBool = 15; + DataTypeFP8E5M2Vector = 16; + DataTypeFP8E4M3Vector = 17; + DataTypeFP16Vector = 18; + DataTypeFP32Vector = 19; + DataTypeFP64Vector = 20; + DataTypeInt8Vector = 21; + DataTypeInt16Vector = 22; + DataTypeInt32Vector = 23; + DataTypeInt64Vector = 24; + DataTypeUint8Vector = 25; + DataTypeUint16Vector = 26; + DataTypeUint32Vector = 27; + DataTypeUint64Vector = 28; + DataTypeStringVector = 29; + DataTypeBoolVector = 30; +} + +// Schema definition for a feature column +message FeatureSchema { + string name = 1; + DataType data_type = 2; + int32 vector_dim = 3; // 0 = scalar, >0 = fixed vector length +} + +// A request-level context feature (user, session, device, etc.) 
+message ContextFeature { + string name = 1; + bytes value = 2; + DataType data_type = 3; + int32 vector_dim = 4; // 0 = scalar, >0 = fixed vector length +} + +// A single entity to be scored/ranked +message Target { + string id = 1; + repeated bytes feature_values = 2; // aligned with target_input_schema +} + +// --- PointWise --- + +message TargetScore { + string error = 1; + repeated bytes output_values = 2; // aligned with target_output_schema +} + +message PointWiseRequest { + string model_config_id = 1; + string tracking_id = 2; + repeated ContextFeature context_features = 3; + repeated FeatureSchema target_input_schema = 4; + repeated Target targets = 5; + string tenant_id = 6; +} + +message PointWiseResponse { + repeated FeatureSchema target_output_schema = 1; + repeated TargetScore target_scores = 2; + string request_error = 3; +} + +// --- PairWise --- + +message TargetPair { + int32 first_target_index = 1; + int32 second_target_index = 2; + repeated bytes feature_values = 3; // aligned with pair_input_schema +} + +message PairWiseRequest { + string model_config_id = 1; + string tracking_id = 2; + repeated ContextFeature context_features = 3; + repeated FeatureSchema target_input_schema = 4; + repeated FeatureSchema pair_input_schema = 5; + repeated TargetPair pairs = 6; + repeated Target targets = 7; + string tenant_id = 8; +} + +message PairScore { + string error = 1; + repeated bytes output_values = 2; // aligned with pair_output_schema +} + +message PairWiseResponse { + repeated PairScore pair_scores = 1; + repeated TargetScore target_scores = 2; + repeated FeatureSchema target_output_schema = 3; + repeated FeatureSchema pair_output_schema = 4; + string request_error = 5; +} + +// --- SlateWise --- + +message TargetSlate { + repeated int32 target_indices = 1; + repeated bytes feature_values = 2; // aligned with slate_input_schema +} + +message SlateScore { + string error = 1; + repeated bytes output_values = 2; // aligned with slate_output_schema +} + 
+message SlateWiseRequest { + string model_config_id = 1; + string tracking_id = 2; + repeated ContextFeature context_features = 3; + repeated FeatureSchema target_input_schema = 4; + repeated FeatureSchema slate_input_schema = 5; + repeated TargetSlate slates = 6; + repeated Target targets = 7; + string tenant_id = 8; +} + +message SlateWiseResponse { + repeated SlateScore slate_scores = 1; + repeated TargetScore target_scores = 2; + repeated FeatureSchema target_output_schema = 3; + repeated FeatureSchema slate_output_schema = 4; + string request_error = 5; +} + +// --- Service --- + +service Predict { + rpc InferPointWise(PointWiseRequest) returns (PointWiseResponse) {}; + rpc InferPairWise(PairWiseRequest) returns (PairWiseResponse) {}; + rpc InferSlateWise(SlateWiseRequest) returns (SlateWiseResponse) {}; +} diff --git a/quick-start/docker-compose.yml b/quick-start/docker-compose.yml index 6cf1f4cc..d40b5856 100644 --- a/quick-start/docker-compose.yml +++ b/quick-start/docker-compose.yml @@ -136,6 +136,7 @@ services: # Create ONFS topic /opt/kafka/bin/kafka-topics.sh --bootstrap-server broker:29092 --create --if-not-exists --topic online-feature-store.feature_ingestion --partitions 1 --replication-factor 1 + /opt/kafka/bin/kafka-topics.sh --bootstrap-server broker:29092 --create --if-not-exists --topic inferflow_inference_logs --partitions 1 --replication-factor 1 # Create Skye topics /opt/kafka/bin/kafka-topics.sh --bootstrap-server broker:29092 --create --if-not-exists --topic skye.model-state --partitions 1 --replication-factor 1 @@ -584,6 +585,8 @@ services: - EXTERNAL_SERVICE_PREDATOR_CALLER_TOKEN=inferflow - EXTERNAL_SERVICE_PREDATOR_DEADLINE=200 - METRIC_SAMPLING_RATE=1 + - KAFKA_BOOTSTRAP_SERVERS=broker:29092 + - KAFKA_LOGGING_TOPIC=inferflow_inference_logs networks: onfs-network: ipv4_address: 172.18.0.18