From 363041b73e4067bf62de0a116aeabf5982e7f1fe Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Fri, 30 Jan 2026 18:02:58 +0530 Subject: [PATCH 01/24] internal complete code migration --- .pre-commit-config.yaml | 2 +- .../inferflow/controller/controller.go | 13 + horizon/internal/inferflow/etcd/config.go | 3 + horizon/internal/inferflow/etcd/etcd.go | 37 +- horizon/internal/inferflow/etcd/models.go | 30 +- horizon/internal/inferflow/handler/adaptor.go | 184 +++- horizon/internal/inferflow/handler/config.go | 1 + .../inferflow/handler/config_builder.go | 80 +- .../internal/inferflow/handler/inferflow.go | 847 +++++++++++++++--- horizon/internal/inferflow/handler/models.go | 40 +- .../inferflow/handler/schema_adapter.go | 267 ++++++ horizon/internal/inferflow/route/router.go | 1 + .../repositories/sql/discoveryconfig/sql.go | 32 + .../sql/inferflow/config/repository.go | 57 +- .../sql/inferflow/config/table.go | 1 + .../repositories/sql/inferflow/models.go | 35 +- .../sql/inferflow/request/repository.go | 30 +- .../sql/servicedeployableconfig/sql.go | 10 + .../sql/servicedeployableconfig/table.go | 2 + horizon/pkg/configschemaclient/client.go | 241 +++++ horizon/pkg/configschemaclient/types.go | 119 +++ inferflow/handlers/inferflow/inferflow.go | 2 +- trufflehog/trufflehog-hook.sh | 45 - 23 files changed, 1831 insertions(+), 248 deletions(-) create mode 100644 horizon/internal/inferflow/handler/schema_adapter.go create mode 100644 horizon/pkg/configschemaclient/client.go create mode 100644 horizon/pkg/configschemaclient/types.go delete mode 100755 trufflehog/trufflehog-hook.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c721100c..e1fccdbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,6 @@ repos: - id: trufflehog name: TruffleHog description: Detect secrets in your data. 
- entry: "trufflehog/trufflehog-hook.sh" + entry: "pre-commit-scripts/runner.sh" language: script stages: ["pre-commit", "pre-push"] diff --git a/horizon/internal/inferflow/controller/controller.go b/horizon/internal/inferflow/controller/controller.go index 0d62160f..3696e95c 100644 --- a/horizon/internal/inferflow/controller/controller.go +++ b/horizon/internal/inferflow/controller/controller.go @@ -26,6 +26,7 @@ type Config interface { ExecuteFuncitonalTestRequest(ctx *gin.Context) GetLatestRequest(ctx *gin.Context) GetLoggingTTL(ctx *gin.Context) + GetFeatureSchema(ctx *gin.Context) } var ( @@ -365,3 +366,15 @@ func (c *V1) GetLoggingTTL(ctx *gin.Context) { } ctx.JSON(200, response) } + +func (c *V1) GetFeatureSchema(ctx *gin.Context) { + response, err := c.Config.GetFeatureSchema(handler.FeatureSchemaRequest{ + ModelConfigId: ctx.Query("model_config_id"), + Version: strings.TrimSpace(ctx.Query("version")), + }) + if err != nil { + ctx.JSON(api.NewBadRequestError(err.Error()).StatusCode, "Error getting feature schema") + return + } + ctx.JSON(200, response) +} diff --git a/horizon/internal/inferflow/etcd/config.go b/horizon/internal/inferflow/etcd/config.go index 40807c87..840533a9 100644 --- a/horizon/internal/inferflow/etcd/config.go +++ b/horizon/internal/inferflow/etcd/config.go @@ -1,8 +1,11 @@ package etcd +import mapset "github.com/deckarep/golang-set/v2" + type Manager interface { GetComponentData(componentName string) *ComponentData CreateConfig(serviceName string, ConfigId string, InferflowConfig InferflowConfig) error UpdateConfig(serviceName string, ConfigId string, InferflowConfig InferflowConfig) error DeleteConfig(serviceName string, ConfigId string) error + GetConfiguredEndpoints(serviceDeployableName string) mapset.Set[string] } diff --git a/horizon/internal/inferflow/etcd/etcd.go b/horizon/internal/inferflow/etcd/etcd.go index 65fe166b..026f7c6c 100644 --- a/horizon/internal/inferflow/etcd/etcd.go +++ b/horizon/internal/inferflow/etcd/etcd.go 
@@ -3,8 +3,10 @@ package etcd import ( "encoding/json" "fmt" + "strings" "github.com/Meesho/BharatMLStack/horizon/internal/inferflow" + mapset "github.com/deckarep/golang-set/v2" "github.com/Meesho/BharatMLStack/horizon/pkg/etcd" "github.com/rs/zerolog/log" @@ -17,6 +19,10 @@ type Etcd struct { env string } +const ( + commaDelimiter = "," +) + func NewEtcdInstance() *Etcd { return &Etcd{ inferflowInstance: etcd.Instance()[inferflow.InferflowAppName], @@ -37,7 +43,7 @@ func (e *Etcd) GetInferflowEtcdInstance() *ModelConfigRegistery { func (e *Etcd) GetHorizonEtcdInstance() *HorizonRegistry { instance, ok := e.horizonInstance.GetConfigInstance().(*HorizonRegistry) if !ok { - log.Panic().Msg("invalid etcd instanc e") + log.Panic().Msg("invalid etcd instance") } return instance } @@ -80,3 +86,32 @@ func (e *Etcd) UpdateConfig(serviceName string, ConfigId string, InferflowConfig func (e *Etcd) DeleteConfig(serviceName string, ConfigId string) error { return e.inferflowInstance.DeleteNode(fmt.Sprintf("/config/%s/services/%s/model-config/config-map/%s", e.appName, serviceName, ConfigId)) } + +func (e *Etcd) GetConfiguredEndpoints(serviceDeployableName string) mapset.Set[string] { + validEndpoints := mapset.NewSet[string]() + instance := e.GetInferflowEtcdInstance() + if instance == nil { + return validEndpoints + } + + inferflowConfig, exists := instance.InferflowConfig[serviceDeployableName] + if !exists { + log.Warn().Msgf("service '%s' not found in etcd registry", serviceDeployableName) + return validEndpoints + } + + predatorHosts := inferflowConfig.ModelConfig.ServiceConfig.PredatorHosts + if predatorHosts == "" { + return validEndpoints + } + + endpoints := strings.Split(predatorHosts, commaDelimiter) + for i := range len(endpoints) { + cleanedEndpoint := strings.TrimSpace(endpoints[i]) + if cleanedEndpoint == "" { + continue + } + validEndpoints.Add(cleanedEndpoint) + } + return validEndpoints +} diff --git a/horizon/internal/inferflow/etcd/models.go 
b/horizon/internal/inferflow/etcd/models.go index b72d96fb..0c2c5e50 100644 --- a/horizon/internal/inferflow/etcd/models.go +++ b/horizon/internal/inferflow/etcd/models.go @@ -17,7 +17,12 @@ type InferflowConfigs struct { } type ModelConfigData struct { - ConfigMap map[string]InferflowConfig `json:"config-map"` + ConfigMap map[string]InferflowConfig `json:"config-map"` + ServiceConfig ServiceConfigData `json:"service-config"` +} + +type ServiceConfigData struct { + PredatorHosts string `json:"predator-hosts"` } type NumerixComponent struct { @@ -73,6 +78,14 @@ type RTPComponent struct { CompCacheEnabled bool `json:"comp_cache_enabled"` } +type SeenScoreComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id,omitempty"` + ColNamePrefix string `json:"col_name_prefix,omitempty"` + FSKeys []FSKey `json:"fs_keys"` + FSRequest *FSRequest `json:"fs_request"` +} + type FinalResponseConfig struct { LoggingPerc int `json:"logging_perc"` ModelSchemaPerc int `json:"model_schema_features_perc"` @@ -110,13 +123,14 @@ type FeatureComponent struct { } type ComponentConfig struct { - CacheEnabled bool `json:"cache_enabled"` - CacheTTL int `json:"cache_ttl"` - CacheVersion int `json:"cache_version"` - FeatureComponents []FeatureComponent `json:"feature_components"` - RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` - PredatorComponents []PredatorComponent `json:"predator_components"` - NumerixComponents []NumerixComponent `json:"numerix_components"` + CacheEnabled bool `json:"cache_enabled"` + CacheTTL int `json:"cache_ttl"` + CacheVersion int `json:"cache_version"` + FeatureComponents []FeatureComponent `json:"feature_components"` + RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` + PredatorComponents []PredatorComponent `json:"predator_components"` + NumerixComponents []NumerixComponent `json:"numerix_components"` + SeenScoreComponents []SeenScoreComponent 
`json:"seen_score_components"` } type DagExecutionConfig struct { diff --git a/horizon/internal/inferflow/handler/adaptor.go b/horizon/internal/inferflow/handler/adaptor.go index 52665062..2246ad8d 100644 --- a/horizon/internal/inferflow/handler/adaptor.go +++ b/horizon/internal/inferflow/handler/adaptor.go @@ -23,9 +23,11 @@ func AdaptOnboardRequestToDBPayload(req interface{}, inferflowConfig InferflowCo dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + featureComponents := AdaptToDBFeatureComponent(inferflowConfig) - dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents) + dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents, dbSeenScoreComponents) dbDagExecutionConfig := AdaptToDBDagExecutionConfig(inferflowConfig) @@ -48,15 +50,16 @@ func AdaptEditRequestToDBPayload(req interface{}, inferflowConfig InferflowConfi dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + featureComponents := AdaptToDBFeatureComponent(inferflowConfig) - dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents) + dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents, dbSeenScoreComponents) dbDagExecutionConfig := AdaptToDBDagExecutionConfig(inferflowConfig) payload.ConfigValue = AdaptToDBConfigValue(dbDagExecutionConfig, dbComponentConfig, dbResponseConfig) - payload.ConfigValue.ComponentConfig.CacheVersion = payload.ConfigValue.ComponentConfig.CacheVersion + 1 payload.RequestPayload = AdaptToDBOnboardPayload(onboardPayload) return payload, nil @@ -75,9 +78,11 @@ func 
AdaptCloneConfigRequestToDBPayload(req interface{}, inferflowConfig Inferfl dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + featureComponents := AdaptToDBFeatureComponent(inferflowConfig) - dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents) + dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents, dbSeenScoreComponents) dbDagExecutionConfig := AdaptToDBDagExecutionConfig(inferflowConfig) @@ -102,9 +107,11 @@ func AdaptPromoteRequestToDBPayload(req interface{}, requestPayload RequestConfi dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + featureComponents := AdaptToDBFeatureComponent(inferflowConfig) - dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents) + dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents, dbSeenScoreComponents) dbDagExecutionConfig := AdaptToDBDagExecutionConfig(inferflowConfig) @@ -115,7 +122,7 @@ func AdaptPromoteRequestToDBPayload(req interface{}, requestPayload RequestConfi return payload, nil } -func AdaptScaleUpRequestToDBPayload(req interface{}) (dbModel.Payload, error) { +func AdaptScaleUpRequestToDBPayload(req interface{}, requestPayload RequestConfig) (dbModel.Payload, error) { var payload dbModel.Payload payload.ConfigMapping = AdaptToDBConfig(req) @@ -130,13 +137,16 @@ func AdaptScaleUpRequestToDBPayload(req interface{}) (dbModel.Payload, error) { dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + featureComponents := 
AdaptToDBFeatureComponent(inferflowConfig) - dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents) + dbComponentConfig := AdaptToDBComponentConfig(inferflowConfig, featureComponents, dbNumerixComponents, dbPredatorComponents, dbRTPComponents, dbSeenScoreComponents) dbDagExecutionConfig := AdaptToDBDagExecutionConfig(inferflowConfig) payload.ConfigValue = AdaptToDBConfigValue(dbDagExecutionConfig, dbComponentConfig, dbResponseConfig) + payload.RequestPayload = AdaptToDBOnboardPayload(requestPayload.Payload.RequestPayload) return payload, nil } @@ -158,6 +168,7 @@ func AdaptToDBResponseConfig(inferflowConfig InferflowConfig) dbModel.ResponseCo Features: inferflowConfig.ResponseConfig.Features, LogSelectiveFeatures: inferflowConfig.ResponseConfig.LogSelectiveFeatures, LogBatchSize: inferflowConfig.ResponseConfig.LogBatchSize, + LoggingTTL: inferflowConfig.ResponseConfig.LoggingTTL, } } @@ -273,6 +284,43 @@ func AdaptToDBRTPComponent(inferflowConfig InferflowConfig) []dbModel.RTPCompone return rtpComponents } +func AdaptToDBSeenScoreComponent(inferflowConfig InferflowConfig) []dbModel.SeenScoreComponent { + var seenScoreComponents []dbModel.SeenScoreComponent + for _, seenScoreComponent := range inferflowConfig.ComponentConfig.SeenScoreComponents { + fsKeys := make([]dbModel.FSKey, len(seenScoreComponent.FSKeys)) + for i, key := range seenScoreComponent.FSKeys { + fsKeys[i] = dbModel.FSKey{ + Schema: key.Schema, + Col: key.Col, + } + } + var fsRequest *dbModel.FSRequest + if seenScoreComponent.FSRequest != nil { + fsFeatureGroups := make([]dbModel.FSFeatureGroup, len(seenScoreComponent.FSRequest.FeatureGroups)) + for i, grp := range seenScoreComponent.FSRequest.FeatureGroups { + fsFeatureGroups[i] = dbModel.FSFeatureGroup{ + Label: grp.Label, + Features: grp.Features, + DataType: grp.DataType, + } + } + fsRequest = &dbModel.FSRequest{ + Label: seenScoreComponent.FSRequest.Label, + 
FeatureGroups: fsFeatureGroups, + } + } + dbSeenScoreComponent := dbModel.SeenScoreComponent{ + Component: seenScoreComponent.Component, + ComponentID: seenScoreComponent.ComponentID, + ColNamePrefix: seenScoreComponent.ColNamePrefix, + FSKeys: fsKeys, + FSRequest: fsRequest, + } + seenScoreComponents = append(seenScoreComponents, dbSeenScoreComponent) + } + return seenScoreComponents +} + func AdaptToDBFeatureComponent(inferflowConfig InferflowConfig) []dbModel.FeatureComponent { var featureComponents []dbModel.FeatureComponent @@ -315,15 +363,16 @@ func AdaptToDBFeatureComponent(inferflowConfig InferflowConfig) []dbModel.Featur return featureComponents } -func AdaptToDBComponentConfig(inferflowConfig InferflowConfig, featureComponents []dbModel.FeatureComponent, NumerixComponents []dbModel.NumerixComponent, predatorComponents []dbModel.PredatorComponent, rtpComponents []dbModel.RTPComponent) dbModel.ComponentConfig { +func AdaptToDBComponentConfig(inferflowConfig InferflowConfig, featureComponents []dbModel.FeatureComponent, NumerixComponents []dbModel.NumerixComponent, predatorComponents []dbModel.PredatorComponent, rtpComponents []dbModel.RTPComponent, seenScoreComponents []dbModel.SeenScoreComponent) dbModel.ComponentConfig { return dbModel.ComponentConfig{ - CacheEnabled: inferflowConfig.ComponentConfig.CacheEnabled, - CacheTTL: inferflowConfig.ComponentConfig.CacheTTL, - CacheVersion: inferflowConfig.ComponentConfig.CacheVersion, - FeatureComponents: featureComponents, - PredatorComponents: predatorComponents, - NumerixComponents: NumerixComponents, - RTPComponents: rtpComponents, + CacheEnabled: inferflowConfig.ComponentConfig.CacheEnabled, + CacheTTL: inferflowConfig.ComponentConfig.CacheTTL, + CacheVersion: inferflowConfig.ComponentConfig.CacheVersion, + FeatureComponents: featureComponents, + PredatorComponents: predatorComponents, + NumerixComponents: NumerixComponents, + RTPComponents: rtpComponents, + SeenScoreComponents: seenScoreComponents, } } @@ 
-354,6 +403,7 @@ func AdaptToDBOnboardPayload(onboardPayload OnboardPayload) dbModel.OnboardPaylo Features: onboardPayload.Response.ResponseFeatures, LogSelectiveFeatures: onboardPayload.Response.LogSelectiveFeatures, LogBatchSize: onboardPayload.Response.LogBatchSize, + LoggingTTL: onboardPayload.Response.LoggingTTL, }, ConfigMapping: dbModel.ConfigMapping{ AppToken: onboardPayload.ConfigMapping.AppToken, @@ -433,6 +483,7 @@ func AdaptFromDbToOnboardPayload(dbOnboardPayload dbModel.OnboardPayload) Onboar ResponseFeatures: dbOnboardPayload.Response.Features, LogSelectiveFeatures: dbOnboardPayload.Response.LogSelectiveFeatures, LogBatchSize: dbOnboardPayload.Response.LogBatchSize, + LoggingTTL: dbOnboardPayload.Response.LoggingTTL, }, ConfigMapping: ConfigMapping{ AppToken: dbOnboardPayload.ConfigMapping.AppToken, @@ -495,13 +546,14 @@ func AdaptFromDbToOnboardPayload(dbOnboardPayload dbModel.OnboardPayload) Onboar func AdaptFromDbToComponentConfig(dbComponentConfig dbModel.ComponentConfig) *ComponentConfig { return &ComponentConfig{ - CacheEnabled: dbComponentConfig.CacheEnabled, - CacheTTL: dbComponentConfig.CacheTTL, - CacheVersion: dbComponentConfig.CacheVersion, - FeatureComponents: AdaptFromDbToFeatureComponent(dbComponentConfig.FeatureComponents), - PredatorComponents: AdaptFromDbToPredatorComponent(dbComponentConfig.PredatorComponents), - NumerixComponents: AdaptFromDbToNumerixComponent(dbComponentConfig.NumerixComponents), - RTPComponents: AdaptFromDbToRTPComponent(dbComponentConfig.RTPComponents), + CacheEnabled: dbComponentConfig.CacheEnabled, + CacheTTL: dbComponentConfig.CacheTTL, + CacheVersion: dbComponentConfig.CacheVersion, + FeatureComponents: AdaptFromDbToFeatureComponent(dbComponentConfig.FeatureComponents), + PredatorComponents: AdaptFromDbToPredatorComponent(dbComponentConfig.PredatorComponents), + NumerixComponents: AdaptFromDbToNumerixComponent(dbComponentConfig.NumerixComponents), + RTPComponents: 
AdaptFromDbToRTPComponent(dbComponentConfig.RTPComponents), + SeenScoreComponents: AdaptFromDbToSeenScoreComponent(dbComponentConfig.SeenScoreComponents), } } @@ -512,6 +564,7 @@ func AdaptFromDbToResponseConfig(dbResponseConfig dbModel.ResponseConfig) *Final Features: dbResponseConfig.Features, LogSelectiveFeatures: dbResponseConfig.LogSelectiveFeatures, LogBatchSize: dbResponseConfig.LogBatchSize, + LoggingTTL: dbResponseConfig.LoggingTTL, } } @@ -626,6 +679,42 @@ func AdaptFromDbToRTPComponent(dbRTPComponents []dbModel.RTPComponent) []RTPComp return rtpComponents } +func AdaptFromDbToSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []SeenScoreComponent { + var seenScoreComponents []SeenScoreComponent + for _, seenScoreComponent := range dbSeenScoreComponents { + fsKeys := make([]FSKey, len(seenScoreComponent.FSKeys)) + for i, key := range seenScoreComponent.FSKeys { + fsKeys[i] = FSKey{ + Schema: key.Schema, + Col: key.Col, + } + } + var fsRequest *FSRequest + if seenScoreComponent.FSRequest != nil { + fsFeatureGroups := make([]FSFeatureGroup, len(seenScoreComponent.FSRequest.FeatureGroups)) + for i, grp := range seenScoreComponent.FSRequest.FeatureGroups { + fsFeatureGroups[i] = FSFeatureGroup{ + Label: grp.Label, + Features: grp.Features, + DataType: grp.DataType, + } + } + fsRequest = &FSRequest{ + Label: seenScoreComponent.FSRequest.Label, + FeatureGroups: fsFeatureGroups, + } + } + seenScoreComponents = append(seenScoreComponents, SeenScoreComponent{ + Component: seenScoreComponent.Component, + ComponentID: seenScoreComponent.ComponentID, + ColNamePrefix: seenScoreComponent.ColNamePrefix, + FSKeys: fsKeys, + FSRequest: fsRequest, + }) + } + return seenScoreComponents +} + func AdaptFromDbToFeatureComponent(dbFeatureComponents []dbModel.FeatureComponent) []FeatureComponent { var featureComponents []FeatureComponent for _, fc := range dbFeatureComponents { @@ -685,13 +774,14 @@ func AdaptToEtcdInferFlowConfig(dpConfig 
dbModel.InferflowConfig) etcdModel.Infe func AdaptToEtcdComponentConfig(dbComponentConfig dbModel.ComponentConfig) etcdModel.ComponentConfig { return etcdModel.ComponentConfig{ - CacheEnabled: dbComponentConfig.CacheEnabled, - CacheTTL: dbComponentConfig.CacheTTL, - CacheVersion: dbComponentConfig.CacheVersion, - FeatureComponents: AdaptToEtcdFeatureComponent(dbComponentConfig.FeatureComponents), - PredatorComponents: AdaptToEtcdPredatorComponent(dbComponentConfig.PredatorComponents), - NumerixComponents: AdaptToEtcdNumerixComponent(dbComponentConfig.NumerixComponents), - RTPComponents: AdaptToEtcdRTPComponent(dbComponentConfig.RTPComponents), + CacheEnabled: dbComponentConfig.CacheEnabled, + CacheTTL: dbComponentConfig.CacheTTL, + CacheVersion: dbComponentConfig.CacheVersion, + FeatureComponents: AdaptToEtcdFeatureComponent(dbComponentConfig.FeatureComponents), + PredatorComponents: AdaptToEtcdPredatorComponent(dbComponentConfig.PredatorComponents), + NumerixComponents: AdaptToEtcdNumerixComponent(dbComponentConfig.NumerixComponents), + RTPComponents: AdaptToEtcdRTPComponent(dbComponentConfig.RTPComponents), + SeenScoreComponents: AdaptToEtcdSeenScoreComponent(dbComponentConfig.SeenScoreComponents), } } @@ -815,6 +905,42 @@ func AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcdModel return rtpComponents } +func AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcdModel.SeenScoreComponent { + var seenScoreComponents []etcdModel.SeenScoreComponent + for _, seenScoreComponent := range dbSeenScoreComponents { + fsKeys := make([]etcdModel.FSKey, len(seenScoreComponent.FSKeys)) + for i, key := range seenScoreComponent.FSKeys { + fsKeys[i] = etcdModel.FSKey{ + Schema: key.Schema, + Col: key.Col, + } + } + var fsRequest *etcdModel.FSRequest + if seenScoreComponent.FSRequest != nil { + fsFeatureGroups := make([]etcdModel.FSFeatureGroup, len(seenScoreComponent.FSRequest.FeatureGroups)) + for i, grp := range 
seenScoreComponent.FSRequest.FeatureGroups { + fsFeatureGroups[i] = etcdModel.FSFeatureGroup{ + Label: grp.Label, + Features: grp.Features, + DataType: grp.DataType, + } + } + fsRequest = &etcdModel.FSRequest{ + Label: seenScoreComponent.FSRequest.Label, + FeatureGroups: fsFeatureGroups, + } + } + seenScoreComponents = append(seenScoreComponents, etcdModel.SeenScoreComponent{ + Component: seenScoreComponent.Component, + ComponentID: seenScoreComponent.ComponentID, + ColNamePrefix: seenScoreComponent.ColNamePrefix, + FSKeys: fsKeys, + FSRequest: fsRequest, + }) + } + return seenScoreComponents +} + func AdaptToEtcdFeatureComponent(dbFeatureComponents []dbModel.FeatureComponent) []etcdModel.FeatureComponent { var featureComponents []etcdModel.FeatureComponent for _, fc := range dbFeatureComponents { diff --git a/horizon/internal/inferflow/handler/config.go b/horizon/internal/inferflow/handler/config.go index 829ae74e..2c543893 100644 --- a/horizon/internal/inferflow/handler/config.go +++ b/horizon/internal/inferflow/handler/config.go @@ -16,4 +16,5 @@ type Config interface { ExecuteFuncitonalTestRequest(request ExecuteRequestFunctionalTestingRequest) (ExecuteRequestFunctionalTestingResponse, error) GetLatestRequest(requestID string) (GetLatestRequestResponse, error) GetLoggingTTL() (GetLoggingTTLResponse, error) + GetFeatureSchema(FeatureSchemaRequest) (FeatureSchemaResponse, error) } diff --git a/horizon/internal/inferflow/handler/config_builder.go b/horizon/internal/inferflow/handler/config_builder.go index 8ae9a95d..2dd64496 100644 --- a/horizon/internal/inferflow/handler/config_builder.go +++ b/horizon/internal/inferflow/handler/config_builder.go @@ -24,6 +24,7 @@ const ( RTP_FEATURE = "RTP" PCTR_CALIBRATION = "PCTR_CALIBRATION" PCVR_CALIBRATION = "PCVR_CALIBRATION" + SEEN_SCORE = "SEEN_SCORE_FEATURE" PIPE_DELIMITER = "|" UNDERSCORE_DELIMITER = "_" COLON_DELIMITER = ":" @@ -31,14 +32,56 @@ const ( featureClassOffline = "offline" featureClassOnline = "online" 
featureClassDefault = "default" + featureClassModel = "model" featureClassRtp = "rtp" featureClassPCVRCalibration = "pcvr_calibration" featureClassPCTRCalibration = "pctr_calibration" + featureClassSeenScore = "seen_score" featureClassInvalid = "invalid" COMPONENT_NAME_PREFIX = "composite_key_gen_" FEATURE_INITIALIZER = "feature_initializer" + SeenScoreComponentName = "product_seen_score" + SeenScoreDefaultDataType = "DataTypeString" ) +type FeatureLists struct { + allFeatureList mapset.Set[string] + + rtpFeatures, pcvrCalibrationFeatures, pctrCalibrationFeatures, seenScoreFeatures mapset.Set[string] + + featureToDataType, predatorAndIrisOutputsToDataType, offlineToOnlineMapping map[string]string +} + +type ClassifiedFeatures struct { + OfflineFeatures mapset.Set[string] + + OnlineFeatures mapset.Set[string] + + DefaultFeatures mapset.Set[string] + + RTPFeatures mapset.Set[string] + + PCTRCalibrationFeatures mapset.Set[string] + + PCVRCalibrationFeatures mapset.Set[string] + + SeenScoreFeatures mapset.Set[string] + + FeatureToDataType map[string]string +} + +type AllComponents struct { + FeatureComponents []FeatureComponent + + RTPComponents []RTPComponent + + IrisComponents []NumerixComponent + + PredatorComponents []PredatorComponent + + SeenScoreComponents []SeenScoreComponent +} + func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token string) (InferflowConfig, error) { // RTP client is initialized in externalcall.Init() entityIDs := extractEntityIDs(request) @@ -63,12 +106,12 @@ func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token st return InferflowConfig{}, err } - rtpComponents, err := GetRTPComponents(request, rtpFeatures, featureToDataType, m.EtcdConfig, token) + rtpComponents, err := GetRTPComponents(request, rtpFeatures, m.EtcdConfig, token) if err != nil { return InferflowConfig{}, err } - featureComponents, err := GetFeatureComponents(request, featureList, featureToDataType, pcvrCalibrationFeatures, 
pctrCalibrationFeatures, m.EtcdConfig, token, entityIDs) + featureComponents, err := GetFeatureComponents(request, featureList, pcvrCalibrationFeatures, pctrCalibrationFeatures, m.EtcdConfig, token, entityIDs) if err != nil { return InferflowConfig{}, err } @@ -98,7 +141,7 @@ func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token st func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, token string, entityIDs map[string]bool) (mapset.Set[string], map[string]string, mapset.Set[string], mapset.Set[string], mapset.Set[string], map[string]string, map[string]string, error) { initialFeatures, featureToDataType, predatorAndIrisOutputsToDataType := extractFeatures(request, entityIDs) - offlineFeatures, onlineFeatures, _, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, err := classifyFeatures(initialFeatures, featureToDataType) + offlineFeatures, onlineFeatures, defaultFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, err := classifyFeatures(initialFeatures, featureToDataType) if err != nil { return nil, nil, nil, nil, nil, nil, nil, err } @@ -172,6 +215,15 @@ func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, to // features.Add(f) // } + for _, f := range defaultFeatures.ToSlice() { + if _, exists := predatorAndIrisOutputsToDataType[f]; exists { + continue + } + if featureToDataType[f] == "" { + featureToDataType[f] = "String" + } + } + if err := fetchMissingDatatypes(featureToDataType, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, onlineFeatures, token); err != nil { return nil, nil, nil, nil, nil, nil, nil, err } @@ -224,7 +276,7 @@ func extractFeatures(request InferflowOnboardRequest, entityIDs map[string]bool) for _, ranker := range request.Payload.Rankers { for _, input := range ranker.Inputs { for _, feature := range input.Features { - addFeature(feature, input.DataType) + addFeature(feature, "") 
} } @@ -386,6 +438,7 @@ func classifyFeatures( featureDataTypes map[string]string, ) (mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], map[string]string, error) { defaultFeatures := mapset.NewSet[string]() + modelFeatures := mapset.NewSet[string]() onlineFeatures := mapset.NewSet[string]() offlineFeatures := mapset.NewSet[string]() rtpFeatures := mapset.NewSet[string]() @@ -394,7 +447,7 @@ func classifyFeatures( newFeatureToDataType := make(map[string]string) add := func(name, originalFeature string, featureType string) error { - if err := AddFeatureToSet(&defaultFeatures, &onlineFeatures, &offlineFeatures, &rtpFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, name, featureType); err != nil { + if err := AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &rtpFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, name, featureType); err != nil { return fmt.Errorf("error classifying feature: %w", err) } newFeatureToDataType[name] = featureDataTypes[originalFeature] @@ -415,9 +468,10 @@ func classifyFeatures( return offlineFeatures, onlineFeatures, defaultFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, nil } -func AddFeatureToSet(defaultFeatures, onlineFeatures, offlineFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures *mapset.Set[string], feature string, featureType string) error { +func AddFeatureToSet(defaultFeatures, modelFeatures, onlineFeatures, offlineFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures *mapset.Set[string], feature string, featureType string) error { allSets := map[string]*mapset.Set[string]{ featureClassDefault: defaultFeatures, + featureClassModel: modelFeatures, featureClassOnline: onlineFeatures, featureClassOffline: offlineFeatures, featureClassRtp: rtpFeatures, @@ -474,6 +528,8 @@ func transformFeature(feature string) (string, 
string, error) { switch featureTypes[1] { case DEFAULT_FEATURE: return newFeature, featureClassDefault, nil + case MODEL_FEATURE: + return newFeature, featureClassModel, nil case ONLINE_FEATURE, CALIBRATION: return newFeature, featureClassOnline, nil case OFFLINE_FEATURE: @@ -491,6 +547,8 @@ func transformFeature(feature string) (string, string, error) { switch featureTypes[0] { case DEFAULT_FEATURE: return featureName, featureClassDefault, nil + case MODEL_FEATURE: + return featureName, featureClassModel, nil case ONLINE_FEATURE, CALIBRATION: return featureName, featureClassOnline, nil case OFFLINE_FEATURE: @@ -689,7 +747,7 @@ func GetOnlineFeatureMapping(offlineFeatureList mapset.Set[string], token string return response.Data, nil } -func GetFeatureComponents(request InferflowOnboardRequest, featureList mapset.Set[string], featureToDataType map[string]string, pcvrCalibrationFeatures mapset.Set[string], pctrCalibrationFeatures mapset.Set[string], etcdConfig etcd.Manager, token string, entityIDs map[string]bool) ([]FeatureComponent, error) { +func GetFeatureComponents(request InferflowOnboardRequest, featureList mapset.Set[string], pcvrCalibrationFeatures mapset.Set[string], pctrCalibrationFeatures mapset.Set[string], etcdConfig etcd.Manager, token string, entityIDs map[string]bool) ([]FeatureComponent, error) { featureComponents := make([]FeatureComponent, 0, featureList.Cardinality()+pcvrCalibrationFeatures.Cardinality()+pctrCalibrationFeatures.Cardinality()) featureComponentsMap := GetFeatureLabelToPrefixToFeatureGroupToFeatureMap(featureList.ToSlice()) @@ -937,7 +995,7 @@ func GetFeatureGroupDataTypeMap(label string, token string) (map[string]string, return featureGroupDataTypeMap, nil } -func GetRTPComponents(request InferflowOnboardRequest, rtpFeatures mapset.Set[string], featureToDataTypeMap map[string]string, etcdConfig etcd.Manager, token string) ([]RTPComponent, error) { +func GetRTPComponents(request InferflowOnboardRequest, rtpFeatures 
mapset.Set[string], etcdConfig etcd.Manager, token string) ([]RTPComponent, error) { rtpComponents := make([]RTPComponent, 0) if rtpFeatures.Cardinality() == 0 { @@ -946,7 +1004,7 @@ func GetRTPComponents(request InferflowOnboardRequest, rtpFeatures mapset.Set[st featureDataTypeMap, err := GetRTPFeatureGroupDataTypeMap() if err != nil && inferflow.IsMeeshoEnabled { - return rtpComponents, nil + return rtpComponents, fmt.Errorf("RTP Components: failed to fetch RTP feature data-type map: %w", err) } rtpFeatureComponentsMap := GetRTPFeatureLabelToPrefixToFeatureGroupToFeatureMap(rtpFeatures.ToSlice()) for label, prefixToFeatureGroupToFeatureMap := range rtpFeatureComponentsMap { @@ -1223,6 +1281,9 @@ func getNumerixScoreMapping(eqVariables map[string]string, offlineToOnlineMappin if keyDataType == "" { keyDataType = predatorAndNumerixOutputsToDataType[transformedFeature] } + if keyDataType == "" { + return nil, fmt.Errorf("numerix Score Mapping: key data type for '%s' not found", transformedFeature) + } if !strings.Contains(keyDataType, "DataType") { key = key + "@DataType" + keyDataType } else { @@ -1253,6 +1314,7 @@ func GetResponseConfigs(request *InferflowOnboardRequest) (*FinalResponseConfig, Features: request.Payload.Response.ResponseFeatures, LogSelectiveFeatures: request.Payload.Response.LogSelectiveFeatures, LogBatchSize: request.Payload.Response.LogBatchSize, + LoggingTTL: request.Payload.Response.LoggingTTL, } return responseConfigs, nil diff --git a/horizon/internal/inferflow/handler/inferflow.go b/horizon/internal/inferflow/handler/inferflow.go index 24cc5dc3..8df7a05b 100644 --- a/horizon/internal/inferflow/handler/inferflow.go +++ b/horizon/internal/inferflow/handler/inferflow.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" "strings" + "sync" "time" mainHandler "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" @@ -21,6 +22,7 @@ import ( "github.com/Meesho/BharatMLStack/horizon/pkg/grpc" "github.com/Meesho/BharatMLStack/horizon/pkg/infra" 
"github.com/Meesho/BharatMLStack/horizon/pkg/random" + mapset "github.com/deckarep/golang-set/v2" "github.com/rs/zerolog/log" "google.golang.org/grpc/metadata" "gorm.io/gorm" @@ -53,6 +55,12 @@ const ( activeFalse = false inferFlowRetrieveModelScoreMethod = "/Inferflow/RetrieveModelScore" setFunctionalTest = "FunctionalTest" + defaultLoggingTTL = 30 + maxConfigVersion = 15 + defaultModelSchemaPerc = 0 + deployableTagDelimiter = "_" + scaleupTag = "scaleup" + defaultVersion = 1 ) func InitV1ConfigHandler() Config { @@ -109,13 +117,12 @@ func (m *InferFlow) Onboard(request InferflowOnboardRequest, token string) (Resp return Response{}, errors.New("failed to validate onboard request: " + err.Error()) } if response.Error != emptyResponse { - return Response{}, errors.New("model proxy config is invalid: " + response.Error) + return Response{}, errors.New("inferflow config is invalid: " + response.Error) } - exists := false - configs, err := m.InferFlowConfigRepo.GetByID(configId) - if err == nil { - exists = configs.Active + exists, err := m.InferFlowConfigRepo.DoesConfigIDExist(configId) + if err != nil { + return Response{}, errors.New("failed to check if config already exists: " + err.Error()) } if exists { return Response{}, errors.New("Config ID already exists") @@ -123,12 +130,12 @@ func (m *InferFlow) Onboard(request InferflowOnboardRequest, token string) (Resp inferFlowConfig, err := m.GetInferflowConfig(request, token) if err != nil { - return Response{}, errors.New("failed to generate model proxy config: " + err.Error()) + return Response{}, errors.New("failed to generate inferflow config: " + err.Error()) } response, err = ValidateInferFlowConfig(inferFlowConfig, token) if err != nil { - return Response{}, errors.New("failed to validate model proxy config: " + err.Error()) + return Response{}, errors.New("failed to validate inferflow config: " + err.Error()) } if response.Error != emptyResponse { return Response{}, errors.New("infer flow config is invalid: " 
+ response.Error) @@ -178,6 +185,39 @@ func (m *InferFlow) Promote(request PromoteConfigRequest) (Response, error) { modelName := request.Payload.ConfigValue.ComponentConfig.PredatorComponents[i].ModelName request.Payload.ConfigValue.ComponentConfig.PredatorComponents[i].ModelEndPoint = modelNameToEndPointMap[modelName] } + for i := range request.Payload.LatestRequest.Payload.RequestPayload.Rankers { + modelName := request.Payload.LatestRequest.Payload.RequestPayload.Rankers[i].ModelName + request.Payload.LatestRequest.Payload.RequestPayload.Rankers[i].EndPoint = modelNameToEndPointMap[modelName] + } + request.Payload.ConfigValue.ResponseConfig.LoggingTTL = defaultLoggingTTL + request.Payload.ConfigValue.ResponseConfig.ModelSchemaPerc = defaultModelSchemaPerc + + destinationDeployableID := request.Payload.ConfigMapping.DeployableID + request.Payload.LatestRequest.Payload.ConfigMapping.DeployableID = destinationDeployableID + request.Payload.LatestRequest.Payload.RequestPayload.ConfigMapping.DeployableID = destinationDeployableID + request.Payload.LatestRequest.Payload.RequestPayload.Response.RankerSchemaFeaturesInResponsePerc = defaultModelSchemaPerc + + newVersion := defaultVersion + configIDExists, err := m.InferFlowConfigRepo.DoesConfigIDExist(request.Payload.ConfigID) + if err != nil { + return Response{}, errors.New("failed to check if config id exists in config table " + err.Error()) + } + if configIDExists { + log.Info().Msgf("config already exists, bumping version") + latestRequests, retrieveErr := m.InferFlowRequestRepo.GetApprovedRequestsByConfigID(request.Payload.ConfigID) + if retrieveErr != nil { + return Response{}, errors.New("failed to fetch config from DB") + } + if len(latestRequests) > 0 { + newVersion = latestRequests[0].Version + 1 + } + if newVersion > maxConfigVersion { + return Response{}, errors.New("This inferflow config has reached its version limit. 
Please create a clone to make further updates.") + } + request.Payload.ConfigValue.ComponentConfig.CacheVersion = newVersion + } else { + request.Payload.ConfigValue.ComponentConfig.CacheVersion = newVersion + } payload, err := AdaptPromoteRequestToDBPayload(request, request.Payload.LatestRequest) if err != nil { @@ -191,6 +231,7 @@ func (m *InferFlow) Promote(request PromoteConfigRequest) (Response, error) { RequestType: promoteRequestType, Status: pendingApproval, Active: activeTrue, + Version: newVersion, } err = m.InferFlowRequestRepo.Create(table) @@ -210,7 +251,7 @@ func (m *InferFlow) Edit(request EditConfigOrCloneConfigRequest, token string) ( response, err := m.ValidateOnboardRequest(request.Payload) if err != nil { - return Response{}, errors.New("failed to validate onboard request: " + err.Error()) + return Response{}, errors.New("failed to validate edit request: " + err.Error()) } if response.Error != emptyResponse { return Response{}, errors.New("onboard request is invalid: " + response.Error) @@ -238,9 +279,18 @@ func (m *InferFlow) Edit(request EditConfigOrCloneConfigRequest, token string) ( return Response{}, errors.New("failed to get existing configs: " + err.Error()) } - newVersion := 1 + newVersion := defaultVersion + prevLoggingTTL := defaultLoggingTTL if len(existingConfigs) > 0 { newVersion = existingConfigs[0].Version + 1 + prevLoggingTTL = existingConfigs[0].Payload.ConfigValue.ResponseConfig.LoggingTTL + } + if request.Payload.Response.LoggingTTL == 0 { + request.Payload.Response.LoggingTTL = prevLoggingTTL + } + + if newVersion > maxConfigVersion { + return Response{}, errors.New("This inferflow config has reached its version limit. 
Please create a clone to make further updates.") } onboardRequest := InferflowOnboardRequest(request) @@ -252,12 +302,14 @@ func (m *InferFlow) Edit(request EditConfigOrCloneConfigRequest, token string) ( response, err = ValidateInferFlowConfig(inferFlowConfig, token) if err != nil { - return Response{}, errors.New("failed to validate model proxy config: " + err.Error()) + return Response{}, errors.New("failed to validate inferflow config: " + err.Error()) } if response.Error != emptyResponse { return Response{}, errors.New("infer flow config is invalid: " + response.Error) } + inferFlowConfig.ComponentConfig.CacheVersion = newVersion + payload, err := AdaptEditRequestToDBPayload(request, inferFlowConfig, request.Payload) if err != nil { return Response{}, errors.New("failed to adapt edit request to db payload: " + err.Error()) @@ -310,6 +362,11 @@ func (m *InferFlow) Clone(request EditConfigOrCloneConfigRequest, token string) } } + if request.Payload.Response.LoggingTTL == 0 { + request.Payload.Response.LoggingTTL = defaultLoggingTTL + } + request.Payload.Response.RankerSchemaFeaturesInResponsePerc = defaultModelSchemaPerc + onboardRequest := InferflowOnboardRequest(request) inferFlowConfig, err := m.GetInferflowConfig(onboardRequest, token) @@ -351,6 +408,12 @@ func (m *InferFlow) Clone(request EditConfigOrCloneConfigRequest, token string) } func (m *InferFlow) ScaleUp(request ScaleUpConfigRequest) (Response, error) { + sourceConfigID := request.Payload.ConfigID + derivedConfigID, err := m.GetDerivedConfigID(request.Payload.ConfigID, request.GetConfigMapping().DeployableID) + if err != nil { + return Response{}, errors.New("failed to create derived config ID: " + err.Error()) + } + request.Payload.ConfigID = derivedConfigID exists, err := m.InferFlowRequestRepo.DoesConfigIdExistWithRequestType(request.Payload.ConfigID, scaleUpRequestType) if err != nil { return Response{}, errors.New("failed to check if config id exists in db: " + err.Error()) @@ -359,6 +422,13 
@@ func (m *InferFlow) ScaleUp(request ScaleUpConfigRequest) (Response, error) { return Response{}, errors.New("Config ID already exists with scale up request") } + var latestSourceRequest GetLatestRequestResponse + latestSourceRequest, err = m.GetLatestRequest(sourceConfigID) + if err != nil { + return Response{}, errors.New("failed to get latest request for the source configID: " + sourceConfigID + ": " + err.Error()) + } + request.Payload.ConfigMapping.SourceConfigID = sourceConfigID + modelNameToEndPointMap := make(map[string]ModelNameToEndPointMap) for _, proposedModelEndpoint := range request.Payload.ModelNameToEndPointMap { modelNameToEndPointMap[proposedModelEndpoint.CurrentModelName] = proposedModelEndpoint @@ -370,12 +440,35 @@ func (m *InferFlow) ScaleUp(request ScaleUpConfigRequest) (Response, error) { request.Payload.ConfigValue.ComponentConfig.PredatorComponents[i].ModelName = modelNameToEndPointMap[modelName].NewModelName } + for i := range latestSourceRequest.Data.Payload.RequestPayload.Rankers { + modelName := latestSourceRequest.Data.Payload.RequestPayload.Rankers[i].ModelName + latestSourceRequest.Data.Payload.RequestPayload.Rankers[i].EndPoint = modelNameToEndPointMap[modelName].EndPointID + latestSourceRequest.Data.Payload.RequestPayload.Rankers[i].ModelName = modelNameToEndPointMap[modelName].NewModelName + } + + // Set Request Payload name + parts := strings.Split(request.Payload.ConfigID, "-") + if len(parts) >= 3 { + latestSourceRequest.Data.Payload.RequestPayload.RealEstate = parts[0] + latestSourceRequest.Data.Payload.RequestPayload.Tenant = parts[1] + latestSourceRequest.Data.Payload.RequestPayload.ConfigIdentifier = strings.Join(parts[2:], "-") + + } + + latestSourceRequest.Data.ConfigID = request.Payload.ConfigID + latestSourceRequest.Data.Payload.RequestPayload.ConfigMapping.DeployableID = request.Payload.ConfigMapping.DeployableID + latestSourceRequest.Data.Payload.RequestPayload.Response.RankerSchemaFeaturesInResponsePerc = 
defaultModelSchemaPerc + + request.Payload.ConfigValue.ResponseConfig.ModelSchemaPerc = defaultModelSchemaPerc request.Payload.ConfigValue.ResponseConfig.LoggingPerc = request.Payload.LoggingPerc + request.Payload.ConfigValue.ResponseConfig.LoggingTTL = request.Payload.LoggingTTL + request.Payload.ConfigValue.ComponentConfig.CacheVersion = defaultVersion - payload, err := AdaptScaleUpRequestToDBPayload(request) + payload, err := AdaptScaleUpRequestToDBPayload(request, latestSourceRequest.Data) if err != nil { return Response{}, errors.New("failed to adapt scale up request to db payload: " + err.Error()) } + payload.ConfigMapping.SourceConfigID = request.Payload.ConfigMapping.SourceConfigID table := &inferflow_request.Table{ ConfigID: request.Payload.ConfigID, @@ -403,6 +496,9 @@ func (m *InferFlow) Delete(request DeleteConfigRequest) (Response, error) { if err != nil { return Response{}, errors.New("failed to get infer flow config by id in db: " + err.Error()) } + if InferFlowConfigTable == nil { + return Response{}, errors.New("inferflow config: " + request.ConfigID + " does not exist in db") + } Discoverytable, err := m.DiscoveryConfigRepo.GetById(int(InferFlowConfigTable.DiscoveryID)) if err != nil { @@ -467,107 +563,318 @@ func (m *InferFlow) Cancel(request CancelConfigRequest) (Response, error) { } func (m *InferFlow) Review(request ReviewRequest) (Response, error) { + request.Status = strings.ToUpper(request.Status) - err := m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { + if request.Status != approved && request.Status != rejected { + log.Error().Msgf("invalid status for request id: %d", request.RequestID) + return Response{}, errors.New("invalid status for request") + } - request.Status = strings.ToUpper(request.Status) + if request.Status == rejected && request.RejectReason == emptyResponse { + log.Error().Msgf("request reason not specified for request id: %d", request.RequestID) + return Response{}, errors.New("rejection reason is required") 
+ } - if request.Status != approved && request.Status != rejected { - return errors.New("invalid status") - } + exists, err := m.InferFlowRequestRepo.DoesRequestIDExistWithStatus(request.RequestID, pendingApproval) + if err != nil { + log.Error().Msgf("failed to check if request id: %d exists with status in db: %s", request.RequestID, err) + return Response{}, errors.New("failed to check if request id exists with status in db: " + err.Error()) + } + if !exists { + log.Error().Msgf("request id: %d does not exist or request is not pending approval", request.RequestID) + return Response{}, errors.New("request id does not exist or request is not pending approval") + } - if request.Status == rejected { - if request.RejectReason == emptyResponse { - return errors.New("rejecttion reason is needed") - } - } + if request.Status == rejected { + return m.handleRejectedRequest(request) + } - exists, err := m.InferFlowRequestRepo.DoesRequestIDExistWithStatus(request.RequestID, pendingApproval) - if err != nil { - return errors.New("failed to check if request id exists with status: " + err.Error()) - } - if !exists { - return errors.New("request id does not exist or is not pending approval") - } + return m.handleApprovedRequest(request) +} - fullTable := &inferflow_request.Table{} - if err := tx.First(fullTable, request.RequestID).Error; err != nil { - return errors.New("failed to get infer flow config request by id: " + err.Error()) - } +func (m *InferFlow) handleRejectedRequest(request ReviewRequest) (Response, error) { + requestEntry := &inferflow_request.Table{ + RequestID: request.RequestID, + Status: request.Status, + RejectReason: request.RejectReason, + Reviewer: request.Reviewer, + Active: activeFalse, + } - table := &inferflow_request.Table{ + if err := m.InferFlowRequestRepo.Update(requestEntry); err != nil { + return Response{}, errors.New("failed to update inferflow config request in db: " + err.Error()) + } + + return Response{ + Error: emptyResponse, + Data: 
Message{Message: fmt.Sprintf("inferflow config request rejected successfully for Request Id %d", request.RequestID)}, + }, nil +} + +func (m *InferFlow) handleApprovedRequest(request ReviewRequest) (Response, error) { + var requestEntry *inferflow_request.Table + var discoveryID int + var discoveryConfig *discovery_config.DiscoveryConfig + + tempRequest := inferflow_request.Table{} + tempRequest, err := m.InferFlowRequestRepo.GetRequestByID(request.RequestID) + if err != nil { + return Response{}, fmt.Errorf("failed to fetch latest unapproved request for request id: %d: %w", request.RequestID, err) + } + + var configExistedBeforeTx bool + if tempRequest.RequestType == promoteRequestType { + existingConfig, _ := m.InferFlowConfigRepo.GetByID(tempRequest.ConfigID) + configExistedBeforeTx = existingConfig != nil + } + + err = m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { + requestEntry = &inferflow_request.Table{ RequestID: request.RequestID, Status: request.Status, RejectReason: request.RejectReason, Reviewer: request.Reviewer, } - if request.Status == rejected { - table.Active = activeFalse + if err := tx.First(requestEntry, request.RequestID).Error; err != nil { + return fmt.Errorf("failed to get request: %w", err) } + requestEntry.Reviewer = request.Reviewer + requestEntry.RejectReason = request.RejectReason - err = m.InferFlowRequestRepo.UpdateTx(tx, table) + var err error + discoveryID, discoveryConfig, err = m.createOrUpdateDiscoveryConfig(tx, requestEntry, configExistedBeforeTx) if err != nil { - return errors.New("failed to update infer flow config request in db: " + err.Error()) + return fmt.Errorf("failed to handle discovery config: %w", err) } - if request.Status == approved { - discoveryID, discovery, err := m.createOrUpdateDiscoveryConfig(tx, fullTable) - if err != nil { + if err := m.createOrUpdateInferFlowConfig(tx, requestEntry, discoveryID, configExistedBeforeTx); err != nil { + return fmt.Errorf("failed to handle inferflow config: 
%w", err) + } + + requestEntry.Status = approved + err = m.InferFlowRequestRepo.UpdateTx(tx, requestEntry) + if err != nil { + return errors.New("failed to update inferflow config request in db: " + err.Error()) + } + + return nil + }) + + if err != nil { + return Response{}, fmt.Errorf("failed to review config (DB rolled back): %w", err) + } + + if err := m.createOrUpdateEtcdConfig(requestEntry, discoveryConfig, configExistedBeforeTx); err != nil { + if rollBackErr := m.rollbackApprovedRequest(request, requestEntry, discoveryID, configExistedBeforeTx); rollBackErr != nil { + log.Error().Err(rollBackErr).Msg("Failed to rollback DB changes after ETCD failure") + return Response{}, fmt.Errorf("ETCD sync failed and DB rollback also failed: etcd=%w, rollback=%v", err, rollBackErr) + } + log.Warn().Msgf("Successfully rolled back the request: %d", request.RequestID) + return Response{}, fmt.Errorf("ETCD sync failed: %w", err) + } + + return Response{ + Error: emptyResponse, + Data: Message{Message: "Mp Config reviewed successfully."}, + }, nil +} + +func (m *InferFlow) rollbackApprovedRequest(request ReviewRequest, fullTable *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { + return m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { + table := &inferflow_request.Table{ + RequestID: request.RequestID, + Status: pendingApproval, + Reviewer: emptyResponse, + } + if err := m.InferFlowRequestRepo.UpdateTx(tx, table); err != nil { + return fmt.Errorf("failed to revert request status: %w", err) + } + + switch fullTable.RequestType { + case onboardRequestType, cloneRequestType, scaleUpRequestType: + if err := m.rollbackCreatedConfigs(tx, fullTable.ConfigID, discoveryID); err != nil { return err } - err = m.createOrUpdateInferFlowConfig(tx, fullTable, discoveryID) - if err != nil { + case editRequestType: + if err := m.rollbackEditRequest(tx, fullTable, discoveryID); err != nil { return err } - err = m.createOrUpdateEtcdConfig(fullTable, 
discovery) - if err != nil { - log.Error().Err(err).Msg("Failed to sync config to etcd") - return errors.New("failed to sync config to etcd: " + err.Error()) + case deleteRequestType: + updatedBy := fullTable.UpdatedBy + if updatedBy == "" { + updatedBy = fullTable.CreatedBy + } + if err := m.rollbackDeletedConfigs(tx, fullTable.ConfigID, discoveryID, updatedBy); err != nil { + return err + } + + case promoteRequestType: + if err := m.rollbackPromoteRequest(tx, fullTable, discoveryID, configExistedBeforeTx); err != nil { + return err } } + return nil }) +} + +func (m *InferFlow) rollbackPromoteRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { + if !configExistedBeforeTx { + if err := m.rollbackCreatedConfigs(tx, currentRequest.ConfigID, discoveryID); err != nil { + return err + } + } + return nil +} +func (m *InferFlow) rollbackEditRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int) error { + approvedRequests, err := m.InferFlowRequestRepo.GetApprovedRequestsByConfigID(currentRequest.ConfigID) if err != nil { - return Response{}, errors.New("failed to review config: " + err.Error()) + return fmt.Errorf("Failed to retrieve approved requests: %w", err) } - return Response{ - Error: emptyResponse, - Data: Message{Message: "Mp Config reviewed successfully."}, - }, nil + var previousRequest *inferflow_request.Table + if len(approvedRequests) > 0 { + if approvedRequests[0].RequestID == currentRequest.RequestID { + if len(approvedRequests) > 1 { + previousRequest = &approvedRequests[1] + } else { + return fmt.Errorf("no other request to revert back to: Requires manual intervention") + } + } else { + previousRequest = &approvedRequests[0] + } + } else { + return fmt.Errorf("no other request to revert back to: Requires manual intervention") + } + + existingConfig, err := m.InferFlowConfigRepo.GetByID(currentRequest.ConfigID) + if err != nil { + return fmt.Errorf("failed to get 
inferflow config: %w", err) + } + if existingConfig == nil { + return errors.New("inferflow config not found") + } + + restoredConfig := &inferflow_config.Table{ + ConfigID: currentRequest.ConfigID, + DiscoveryID: discoveryID, + ConfigValue: previousRequest.Payload.ConfigValue, + Active: activeTrue, + UpdatedBy: currentRequest.UpdatedBy, + } + + if err := m.InferFlowConfigRepo.UpdateTx(tx, restoredConfig); err != nil { + return fmt.Errorf("failed to restore inferflow config: %w", err) + } + + restoredDiscovery := &discovery_config.DiscoveryConfig{ + ID: discoveryID, + ServiceDeployableID: previousRequest.Payload.ConfigMapping.DeployableID, + AppToken: previousRequest.Payload.ConfigMapping.AppToken, + ServiceConnectionID: previousRequest.Payload.ConfigMapping.ConnectionConfigID, + Active: activeTrue, + UpdatedBy: currentRequest.UpdatedBy, + } + if err := m.DiscoveryConfigRepo.UpdateTx(tx, restoredDiscovery); err != nil { + return fmt.Errorf("failed to restore discovery config: %w", err) + } + + return nil } -func (m *InferFlow) createOrUpdateDiscoveryConfig(tx *gorm.DB, table *inferflow_request.Table) (int, *discovery_config.DiscoveryConfig, error) { +func (m *InferFlow) rollbackCreatedConfigs(tx *gorm.DB, configID string, discoveryID int) error { + if err := m.InferFlowConfigRepo.DeleteByConfigIDTx(tx, configID); err != nil { + return fmt.Errorf("failed to rollback inferflow config: %w", err) + } + + if err := m.DiscoveryConfigRepo.DeleteByIDTx(tx, discoveryID); err != nil { + return fmt.Errorf("failed to rollback discovery config: %w", err) + } + + return nil +} + +func (m *InferFlow) rollbackDeletedConfigs(tx *gorm.DB, configID string, discoveryID int, updatedby string) error { + latestConfig, err := m.InferFlowConfigRepo.GetLatestInactiveByConfigID(tx, configID) + if err != nil { + return fmt.Errorf("failed to find soft-deleted inferflow config: %w", err) + } + if latestConfig == nil { + return errors.New("no soft-deleted inferflow config found") + } + + if err 
:= m.InferFlowConfigRepo.ReactivateByIDTx(tx, int(latestConfig.ID), updatedby); err != nil { + return fmt.Errorf("failed to reactivate inferflow config: %w", err) + } + + if err := m.DiscoveryConfigRepo.ReactivateByIDTx(tx, discoveryID); err != nil { + return fmt.Errorf("failed to reactivate discovery config: %w", err) + } + + return nil +} + +func (m *InferFlow) createOrUpdateDiscoveryConfig(tx *gorm.DB, requestEntry *inferflow_request.Table, configExistedBeforeTx bool) (int, *discovery_config.DiscoveryConfig, error) { discovery := &discovery_config.DiscoveryConfig{ - ServiceDeployableID: table.Payload.ConfigMapping.DeployableID, - AppToken: table.Payload.ConfigMapping.AppToken, - ServiceConnectionID: table.Payload.ConfigMapping.ConnectionConfigID, + ServiceDeployableID: requestEntry.Payload.ConfigMapping.DeployableID, + AppToken: requestEntry.Payload.ConfigMapping.AppToken, + ServiceConnectionID: requestEntry.Payload.ConfigMapping.ConnectionConfigID, Active: activeTrue, } - switch table.RequestType { - case onboardRequestType, cloneRequestType, promoteRequestType, scaleUpRequestType: - if table.UpdatedBy != "" { - discovery.CreatedBy = table.UpdatedBy + switch requestEntry.RequestType { + case onboardRequestType, cloneRequestType, scaleUpRequestType: + if requestEntry.UpdatedBy != "" { + discovery.CreatedBy = requestEntry.UpdatedBy } else { - discovery.CreatedBy = table.CreatedBy + discovery.CreatedBy = requestEntry.CreatedBy } err := m.DiscoveryConfigRepo.CreateTx(tx, discovery) if err != nil { return 0, nil, errors.New("failed to create discovery config: " + err.Error()) } + case promoteRequestType: + if !configExistedBeforeTx { + if requestEntry.UpdatedBy != "" { + discovery.CreatedBy = requestEntry.UpdatedBy + } else { + discovery.CreatedBy = requestEntry.CreatedBy + } + err := m.DiscoveryConfigRepo.CreateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to create discovery config: " + err.Error()) + } + } else { + existingConfig, err := 
m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return 0, nil, errors.New("failed to query inferflow config repo: " + err.Error()) + } + if requestEntry.UpdatedBy != "" { + discovery.UpdatedBy = requestEntry.UpdatedBy + } else { + discovery.UpdatedBy = requestEntry.CreatedBy + } + discovery.ID = int(existingConfig.DiscoveryID) + err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to update discovery config: " + err.Error()) + } + } case editRequestType: - if table.UpdatedBy != "" { - discovery.UpdatedBy = table.UpdatedBy + if requestEntry.UpdatedBy != "" { + discovery.UpdatedBy = requestEntry.UpdatedBy } else { - discovery.UpdatedBy = table.CreatedBy + discovery.UpdatedBy = requestEntry.CreatedBy } - config, err := m.InferFlowConfigRepo.GetByID(table.ConfigID) + config, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) if err != nil { - return 0, nil, errors.New("failed to get model proxy config by id: " + err.Error()) + return 0, nil, errors.New("failed to get inferflow config by id: " + err.Error()) + } + if config == nil { + return 0, nil, errors.New("failed to get inferflow config by id") } discovery.ID = int(config.DiscoveryID) err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) @@ -575,14 +882,17 @@ func (m *InferFlow) createOrUpdateDiscoveryConfig(tx *gorm.DB, table *inferflow_ return 0, nil, errors.New("failed to update discovery config: " + err.Error()) } case deleteRequestType: - config, err := m.InferFlowConfigRepo.GetByID(table.ConfigID) + config, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) if err != nil { - return 0, nil, errors.New("failed to get model proxy config by id: " + err.Error()) + return 0, nil, errors.New("failed to get inferflow config by id: " + err.Error()) + } + if config == nil { + return 0, nil, errors.New("failed to get inferflow config by id") } - if table.UpdatedBy != "" { - discovery.UpdatedBy = table.UpdatedBy + if 
requestEntry.UpdatedBy != "" { + discovery.UpdatedBy = requestEntry.UpdatedBy } else { - discovery.UpdatedBy = table.CreatedBy + discovery.UpdatedBy = requestEntry.CreatedBy } discovery.ID = int(config.DiscoveryID) discovery.Active = activeFalse @@ -597,53 +907,88 @@ func (m *InferFlow) createOrUpdateDiscoveryConfig(tx *gorm.DB, table *inferflow_ return discovery.ID, discovery, nil } -func (m *InferFlow) createOrUpdateInferFlowConfig(tx *gorm.DB, table *inferflow_request.Table, discoveryID int) error { - newTable := &inferflow_config.Table{ +func (m *InferFlow) createOrUpdateInferFlowConfig(tx *gorm.DB, requestEntry *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { + newConfig := &inferflow_config.Table{ DiscoveryID: discoveryID, - ConfigID: table.ConfigID, + ConfigID: requestEntry.ConfigID, Active: activeTrue, - ConfigValue: table.Payload.ConfigValue, + ConfigValue: requestEntry.Payload.ConfigValue, } - switch table.RequestType { - case onboardRequestType, cloneRequestType, promoteRequestType, scaleUpRequestType: - if table.UpdatedBy != "" { - newTable.CreatedBy = table.UpdatedBy + switch requestEntry.RequestType { + case onboardRequestType, cloneRequestType: + if requestEntry.UpdatedBy != "" { + newConfig.CreatedBy = requestEntry.UpdatedBy } else { - newTable.CreatedBy = table.CreatedBy + newConfig.CreatedBy = requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.CreateTx(tx, newConfig) + case scaleUpRequestType: + if requestEntry.UpdatedBy != "" { + newConfig.CreatedBy = requestEntry.UpdatedBy + } else { + newConfig.CreatedBy = requestEntry.CreatedBy + } + newConfig.SourceConfigID = requestEntry.Payload.ConfigMapping.SourceConfigID + return m.InferFlowConfigRepo.CreateTx(tx, newConfig) + case promoteRequestType: + if !configExistedBeforeTx { + if requestEntry.UpdatedBy != "" { + newConfig.CreatedBy = requestEntry.UpdatedBy + } else { + newConfig.CreatedBy = requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.CreateTx(tx, 
newConfig) + } else { + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return errors.New("failed to query inferflow config repo: " + err.Error()) + } + newConfig.ID = existingConfig.ID + if requestEntry.UpdatedBy != "" { + newConfig.UpdatedBy = requestEntry.UpdatedBy + } else { + newConfig.UpdatedBy = requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) } - return m.InferFlowConfigRepo.CreateTx(tx, newTable) case editRequestType: - existingConfig, err := m.InferFlowConfigRepo.GetByID(table.ConfigID) + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) if err != nil { - return errors.New("failed to get model proxy config by id: " + err.Error()) + return errors.New("failed to get inferflow config by id: " + err.Error()) } - newTable.ID = existingConfig.ID - if table.UpdatedBy != "" { - newTable.UpdatedBy = table.UpdatedBy + if existingConfig == nil { + return errors.New("failed to get inferflow config by id") + } + newConfig.ID = existingConfig.ID + if requestEntry.UpdatedBy != "" { + newConfig.UpdatedBy = requestEntry.UpdatedBy } else { - newTable.UpdatedBy = table.CreatedBy + newConfig.UpdatedBy = requestEntry.CreatedBy } - return m.InferFlowConfigRepo.UpdateTx(tx, newTable) + return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) case deleteRequestType: - existingConfig, err := m.InferFlowConfigRepo.GetByID(table.ConfigID) + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) if err != nil { - return errors.New("failed to get model proxy config by id: " + err.Error()) + return errors.New("failed to get inferflow config by id: " + err.Error()) + } + if existingConfig == nil { + return errors.New("failed to get inferflow config by id") } - newTable.ID = existingConfig.ID - if table.UpdatedBy != "" { - newTable.UpdatedBy = table.UpdatedBy + newConfig.ID = existingConfig.ID + if requestEntry.UpdatedBy != "" { + newConfig.UpdatedBy = 
requestEntry.UpdatedBy } else { - newTable.UpdatedBy = table.CreatedBy + newConfig.UpdatedBy = requestEntry.CreatedBy } - newTable.Active = activeFalse - return m.InferFlowConfigRepo.UpdateTx(tx, newTable) + newConfig.Active = activeFalse + return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) default: return errors.New("invalid request type") } } -func (m *InferFlow) createOrUpdateEtcdConfig(table *inferflow_request.Table, discovery *discovery_config.DiscoveryConfig) error { +func (m *InferFlow) createOrUpdateEtcdConfig(table *inferflow_request.Table, discovery *discovery_config.DiscoveryConfig, configExistedBeforeTx bool) error { serviceDeployableTable, err := m.ServiceDeployableConfigRepo.GetById(int(discovery.ServiceDeployableID)) if err != nil { return errors.New("failed to get service deployable config by id: " + err.Error()) @@ -653,8 +998,13 @@ func (m *InferFlow) createOrUpdateEtcdConfig(table *inferflow_request.Table, dis inferFlowConfig := AdaptToEtcdInferFlowConfig(table.Payload.ConfigValue) switch table.RequestType { - case onboardRequestType, cloneRequestType, promoteRequestType, scaleUpRequestType: + case onboardRequestType, cloneRequestType, scaleUpRequestType: return m.EtcdConfig.CreateConfig(serviceName, configId, inferFlowConfig) + case promoteRequestType: + if !configExistedBeforeTx { + return m.EtcdConfig.CreateConfig(serviceName, configId, inferFlowConfig) + } + return m.EtcdConfig.UpdateConfig(serviceName, configId, inferFlowConfig) case editRequestType: return m.EtcdConfig.UpdateConfig(serviceName, configId, inferFlowConfig) case deleteRequestType: @@ -683,6 +1033,29 @@ func (m *InferFlow) GetAllRequests(request GetAllRequestConfigsRequest) (GetAllR } } + deployableIDsMap := make(map[int]bool) + for _, table := range tables { + payload := table.Payload + if payload.ConfigMapping.DeployableID > 0 { + deployableIDsMap[payload.ConfigMapping.DeployableID] = true + } + } + + deployableIDs := make([]int, 0, len(deployableIDsMap)) + for id := range 
deployableIDsMap { + deployableIDs = append(deployableIDs, id) + } + + serviceDeployables, err := m.ServiceDeployableConfigRepo.GetByIds(deployableIDs) + if err != nil { + return GetAllRequestConfigsResponse{}, errors.New("failed to get service deployable configs: " + err.Error()) + } + + deployableMap := make(map[int]string) + for _, sd := range serviceDeployables { + deployableMap[sd.ID] = sd.Name + } + requestConfigs := make([]RequestConfig, len(tables)) for i, table := range tables { @@ -691,14 +1064,12 @@ func (m *InferFlow) GetAllRequests(request GetAllRequestConfigsRequest) (GetAllR ConfigValue := AdaptFromDbToInferFlowConfig(payload.ConfigValue) - serviceDeployableID := ConfigMapping.DeployableID - serviceDeployableTable, err := m.ServiceDeployableConfigRepo.GetById(int(serviceDeployableID)) - if err != nil { - return GetAllRequestConfigsResponse{}, errors.New("failed to get service deployable config by id: " + err.Error()) + if name, exists := deployableMap[ConfigMapping.DeployableID]; exists { + ConfigMapping.DeployableName = name + } else { + ConfigMapping.DeployableName = "Unknown" } - ConfigMapping.DeployableName = serviceDeployableTable.Name - requestConfigs[i] = RequestConfig{ RequestID: table.RequestID, Payload: Payload{ @@ -728,37 +1099,45 @@ func (m *InferFlow) GetAll() (GetAllResponse, error) { tables, err := m.InferFlowConfigRepo.GetAll() if err != nil { - return GetAllResponse{}, errors.New("failed to get all infer flow configs: " + err.Error()) + return GetAllResponse{}, errors.New("failed to get all inferflow configs: " + err.Error()) + } + + discoveryIDs := make([]int, 0, len(tables)) + for _, table := range tables { + discoveryIDs = append(discoveryIDs, int(table.DiscoveryID)) + } + + discoveryMap, serviceDeployableMap, err := m.batchFetchDiscoveryConfigs(discoveryIDs) + if err != nil { + return GetAllResponse{}, errors.New(err.Error()) + } + + ringMasterConfigs, err := m.batchFetchRingMasterConfigs(serviceDeployableMap) + if err != nil { + 
return GetAllResponse{}, errors.New("failed to batch fetch ringmaster configs: " + err.Error()) } responseConfigs := make([]ConfigTable, len(tables)) for i, table := range tables { - disocveryTable, err := m.DiscoveryConfigRepo.GetById(int(table.DiscoveryID)) - if err != nil { - return GetAllResponse{}, errors.New("failed to get discovery config by id: " + err.Error()) + disocveryTable, exists := discoveryMap[int(table.DiscoveryID)] + if !exists { + return GetAllResponse{}, errors.New("failed to find discovery config by id: " + strconv.Itoa(int(table.DiscoveryID))) } - serviceDeployableID := disocveryTable.ServiceDeployableID - serviceDeployableTable, err := m.ServiceDeployableConfigRepo.GetById(int(serviceDeployableID)) - if err != nil { - return GetAllResponse{}, errors.New("failed to get service deployable config by id: " + err.Error()) + serviceDeployableTable, exists := serviceDeployableMap[disocveryTable.ServiceDeployableID] + if !exists { + return GetAllResponse{}, errors.New("failed to find service deployable config by id: " + strconv.Itoa(disocveryTable.ServiceDeployableID)) } ConfigValue := AdaptFromDbToInferFlowConfig(table.ConfigValue) - infraConfig := m.infrastructureHandler.GetConfig(serviceDeployableTable.Name, m.workingEnv) - // Convert to mainHandler.Config for compatibility - DeployableConfig := mainHandler.Config{ - MinReplica: infraConfig.MinReplica, - MaxReplica: infraConfig.MaxReplica, - RunningStatus: infraConfig.RunningStatus, - } + ringMasterConfig := ringMasterConfigs[serviceDeployableTable.ID] responseConfigs[i] = ConfigTable{ ConfigID: table.ConfigID, ConfigValue: ConfigValue, Host: serviceDeployableTable.Host, - DeployableRunningStatus: DeployableConfig.RunningStatus == "true", + DeployableRunningStatus: ringMasterConfig.RunningStatus == "true", MonitoringUrl: serviceDeployableTable.MonitoringUrl, CreatedBy: table.CreatedBy, UpdatedBy: table.UpdatedBy, @@ -768,6 +1147,7 @@ func (m *InferFlow) GetAll() (GetAllResponse, error) { Tested: 
table.TestResults.Tested, Message: table.TestResults.Message, }, + SourceConfigID: table.SourceConfigID, } } @@ -779,6 +1159,81 @@ func (m *InferFlow) GetAll() (GetAllResponse, error) { return response, nil } +func (m *InferFlow) batchFetchDiscoveryConfigs(discoveryIDs []int) ( + map[int]*discovery_config.DiscoveryConfig, + map[int]*service_deployable_config.ServiceDeployableConfig, + error, +) { + emptyDiscoveryMap := make(map[int]*discovery_config.DiscoveryConfig) + emptyServiceDeployableMap := make(map[int]*service_deployable_config.ServiceDeployableConfig) + + if len(discoveryIDs) == 0 { + return emptyDiscoveryMap, emptyServiceDeployableMap, nil + } + + discoveryConfigs, err := m.DiscoveryConfigRepo.GetByServiceDeployableIDs(discoveryIDs) + if err != nil { + return nil, nil, fmt.Errorf("failed to get discovery configs: %w", err) + } + + discoveryMap := make(map[int]*discovery_config.DiscoveryConfig) + for i := range discoveryConfigs { + discoveryMap[discoveryConfigs[i].ID] = &discoveryConfigs[i] + } + + serviceDeployableIDsMap := make(map[int]bool) + for _, dc := range discoveryConfigs { + serviceDeployableIDsMap[dc.ServiceDeployableID] = true + } + + serviceDeployableIDs := make([]int, 0, len(serviceDeployableIDsMap)) + for id := range serviceDeployableIDsMap { + serviceDeployableIDs = append(serviceDeployableIDs, id) + } + + if len(serviceDeployableIDs) == 0 { + return discoveryMap, emptyServiceDeployableMap, nil + } + + serviceDeployables, err := m.ServiceDeployableConfigRepo.GetByIds(serviceDeployableIDs) + if err != nil { + return nil, nil, fmt.Errorf("failed to get service deployable configs: %w", err) + } + + serviceDeployableMap := make(map[int]*service_deployable_config.ServiceDeployableConfig) + for i := range serviceDeployables { + serviceDeployableMap[serviceDeployables[i].ID] = &serviceDeployables[i] + } + + return discoveryMap, serviceDeployableMap, nil +} + +func (m *InferFlow) batchFetchRingMasterConfigs(serviceDeployables 
map[int]*service_deployable_config.ServiceDeployableConfig) (map[int]infrastructurehandler.Config, error) { + ringMasterConfigs := make(map[int]infrastructurehandler.Config) + var mu sync.Mutex + var wg sync.WaitGroup + + semaphore := make(chan struct{}, 10) + + for id, deployable := range serviceDeployables { + wg.Add(1) + go func(deployableID int, sd *service_deployable_config.ServiceDeployableConfig) { + defer wg.Done() + semaphore <- struct{}{} + defer func() { <-semaphore }() + + config := m.infrastructureHandler.GetConfig(sd.Name, inferflowPkg.AppEnv) + + mu.Lock() + ringMasterConfigs[deployableID] = config + mu.Unlock() + }(id, deployable) + } + + wg.Wait() + return ringMasterConfigs, nil +} + func (m *InferFlow) ValidateRequest(request ValidateRequest, token string) (Response, error) { tables, err := m.InferFlowRequestRepo.GetAll() if err != nil { @@ -858,6 +1313,15 @@ func ValidateInferFlowConfig(config InferflowConfig, token string) (Response, er } func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) (Response, error) { + outputs := mapset.NewSet[string]() + deployableConfig, err := m.ServiceDeployableConfigRepo.GetById(request.ConfigMapping.DeployableID) + if err != nil { + return Response{ + Error: "Failed to fetch deployable config for the request", + Data: Message{Message: emptyResponse}, + }, errors.New("Failed to fetch deployable config for the request") + } + permissibleEndpoints := m.EtcdConfig.GetConfiguredEndpoints(deployableConfig.Name) for _, ranker := range request.Rankers { if len(ranker.EntityID) == 0 { return Response{ @@ -865,6 +1329,16 @@ func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) (Response, er Data: Message{Message: emptyResponse}, }, errors.New("Entity ID is not set for model: " + ranker.ModelName) } + if !permissibleEndpoints.Contains(ranker.EndPoint) { + errorMsg := fmt.Sprintf( + "invalid endpoint: %s chosen for service deployable: %s for model: %s", + ranker.EndPoint, deployableConfig.Name, 
ranker.ModelName, + ) + return Response{ + Error: errorMsg, + Data: Message{Message: emptyResponse}, + }, errors.New(errorMsg) + } for _, output := range ranker.Outputs { if len(output.ModelScores) != len(output.ModelScoresDims) { return Response{ @@ -872,6 +1346,15 @@ func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) (Response, er Data: Message{Message: emptyResponse}, }, errors.New("model scores and model scores dims are not equal for model: " + ranker.ModelName) } + for _, modelScore := range output.ModelScores { + if outputs.Contains(modelScore) { + return Response{ + Error: "duplicate model scores: " + modelScore + " for model: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("duplicate model scores: " + modelScore + " for model: " + ranker.ModelName) + } + outputs.Add(modelScore) + } } } @@ -897,6 +1380,56 @@ func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) (Response, er }, errors.New("invalid eq variable: " + value) } } + if outputs.Contains(reRanker.Score) { + return Response{ + Error: "duplicate score: " + reRanker.Score + " for reRanker: " + reRanker.Score, + Data: Message{Message: emptyResponse}, + }, errors.New("duplicate score: " + reRanker.Score + " for reRanker: " + reRanker.Score) + } + outputs.Add(reRanker.Score) + } + + // Validate MODEL_FEATURE list + for _, ranker := range request.Rankers { + for _, input := range ranker.Inputs { + for _, feature := range input.Features { + featureParts := strings.Split(feature, PIPE_DELIMITER) + if len(featureParts) != 2 { + return Response{ + Error: "invalid feature: " + feature + " in input features of ranker: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("invalid feature: " + feature + " in input features of ranker: " + ranker.ModelName) + } + if strings.Contains(featureParts[0], MODEL_FEATURE) { + if !outputs.Contains(featureParts[1]) { + return Response{ + Error: "model score " + featureParts[1] + " is not 
found in other model scores of ranker: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("model score " + featureParts[1] + " is not found in other model scores of ranker: " + ranker.ModelName) + } + } + } + } + } + + for _, reRanker := range request.ReRankers { + for _, feature := range reRanker.EqVariables { + featureParts := strings.Split(feature, PIPE_DELIMITER) + if len(featureParts) != 2 { + return Response{ + Error: "invalid feature: " + feature, + Data: Message{Message: emptyResponse}, + }, errors.New("invalid feature: " + feature) + } + if strings.Contains(featureParts[0], MODEL_FEATURE) { + if !outputs.Contains(featureParts[1]) { + return Response{ + Error: "model score " + featureParts[1] + " is not found in other model scores of re ranker: " + strconv.Itoa(reRanker.EqID), + Data: Message{Message: emptyResponse}, + }, errors.New("model score " + featureParts[1] + " is not found in other model scores of re ranker: " + strconv.Itoa(reRanker.EqID)) + } + } + } } return Response{ @@ -956,26 +1489,18 @@ func (m *InferFlow) ExecuteFuncitonalTestRequest(request ExecuteRequestFunctiona ep = strings.TrimPrefix(ep, "https://") } ep = strings.TrimSuffix(ep, "/") - - // Check if port is already present - hasPort := false if idx := strings.LastIndex(ep, ":"); idx != -1 { - // Ensure the colon is part of a port, not in a hostname if idx < len(ep)-1 { - hasPort = true + ep = ep[:idx] } } - // Only add port if not already present - if !hasPort { - port := ":8080" - env := strings.ToLower(strings.TrimSpace(inferflowPkg.AppEnv)) - if env == "stg" || env == "int" { - port = ":80" - } - ep = ep + port + port := ":8080" + env := strings.ToLower(strings.TrimSpace(inferflowPkg.AppEnv)) + if env == "stg" || env == "int" { + port = ":80" } - + ep = ep + port return ep }(request.EndPoint) @@ -1018,7 +1543,7 @@ func (m *InferFlow) ExecuteFuncitonalTestRequest(request ExecuteRequestFunctiona protoResponse := &pb.InferflowResponseProto{} - ctx, cancel := 
context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err = grpc.SendGRPCRequest(ctx, conn, inferFlowRetrieveModelScoreMethod, protoRequest, protoResponse, md) @@ -1053,7 +1578,9 @@ func (m *InferFlow) ExecuteFuncitonalTestRequest(request ExecuteRequestFunctiona inferFlowConfig, err := m.InferFlowConfigRepo.GetByID(request.RequestBody.ModelConfigID) if err != nil { - fmt.Println("Error getting model proxy config: ", err) + fmt.Println("Error getting inferflow config: ", err) + } else if inferFlowConfig == nil { + log.Error().Msgf("inferflow config '%s' does not exist in DB", request.RequestBody.ModelConfigID) } else { if response.Error != emptyResponse { inferFlowConfig.TestResults = inferflow.TestResults{ @@ -1068,7 +1595,7 @@ func (m *InferFlow) ExecuteFuncitonalTestRequest(request ExecuteRequestFunctiona } err = m.InferFlowConfigRepo.Update(inferFlowConfig) if err != nil { - fmt.Println("Error updating model proxy config: ", err) + fmt.Println("Error updating inferflow config: ", err) } } @@ -1117,8 +1644,56 @@ func (m *InferFlow) GetLatestRequest(requestID string) (GetLatestRequestResponse }, nil } +func (m *InferFlow) GetDerivedConfigID(configID string, deployableID int) (string, error) { + serviceDeployableConfig, err := m.ServiceDeployableConfigRepo.GetById(deployableID) + if err != nil { + return "", fmt.Errorf("failed to fetch service deployable config for name generation: %w", err) + } + deployableTag := serviceDeployableConfig.DeployableTag + if deployableTag == "" { + return configID, nil + } + + derivedConfigID := configID + deployableTagDelimiter + deployableTag + deployableTagDelimiter + scaleupTag + return derivedConfigID, nil +} + func (m *InferFlow) GetLoggingTTL() (GetLoggingTTLResponse, error) { return GetLoggingTTLResponse{ Data: []int{30, 60, 90}, }, nil } + +func (m *InferFlow) GetFeatureSchema(request FeatureSchemaRequest) 
(FeatureSchemaResponse, error) { + version, err := strconv.Atoi(request.Version) + if err != nil { + return FeatureSchemaResponse{ + Data: []inferflow.SchemaComponents{}, + }, err + } + inferflowRequests, err := m.InferFlowRequestRepo.GetByConfigIDandVersion(request.ModelConfigId, version) + if err != nil { + log.Error().Err(err).Str("model_config_id", request.ModelConfigId).Msg("Failed to get inferflow config") + return FeatureSchemaResponse{ + Data: []inferflow.SchemaComponents{}, + }, err + } + inferflowConfig := inferflowRequests[0].Payload + componentConfig := &inferflowConfig.ConfigValue.ComponentConfig + responseConfig := &inferflowConfig.ConfigValue.ResponseConfig + + response := BuildFeatureSchemaFromInferflow(componentConfig, responseConfig) + + if responseConfig.LogSelectiveFeatures { + responseSchemaComponents := ProcessResponseConfigFromInferflow(responseConfig, response) + log.Info().Str("model_config_id", request.ModelConfigId).Int("schema_components_count", len(responseSchemaComponents)).Msg("Successfully generated feature schema") + return FeatureSchemaResponse{ + Data: responseSchemaComponents, + }, nil + } + + log.Info().Str("model_config_id", request.ModelConfigId).Int("schema_components_count", len(response)).Msg("Successfully generated feature schema") + return FeatureSchemaResponse{ + Data: response, + }, nil +} diff --git a/horizon/internal/inferflow/handler/models.go b/horizon/internal/inferflow/handler/models.go index c595ba93..95bac520 100644 --- a/horizon/internal/inferflow/handler/models.go +++ b/horizon/internal/inferflow/handler/models.go @@ -52,6 +52,7 @@ type ResponseConfig struct { ResponseFeatures []string `json:"response_features"` LogSelectiveFeatures bool `json:"log_features"` LogBatchSize int `json:"log_batch_size"` + LoggingTTL int `json:"logging_ttl"` } type ConfigMapping struct { @@ -60,6 +61,7 @@ type ConfigMapping struct { DeployableID int `json:"deployable_id,omitempty"` DeployableName string 
`json:"deployable_name,omitempty"` ResponseDefaultValues []string `json:"response_default_values"` + SourceConfigID string `json:"source_config_id"` } type OnboardPayload struct { @@ -84,6 +86,7 @@ func (r InferflowOnboardRequest) GetConfigMapping() dbModel.ConfigMapping { ConnectionConfigID: r.Payload.ConfigMapping.ConnectionConfigID, DeployableID: r.Payload.ConfigMapping.DeployableID, ResponseDefaultValues: r.Payload.ConfigMapping.ResponseDefaultValues, + SourceConfigID: r.Payload.ConfigMapping.SourceConfigID, } } @@ -138,6 +141,7 @@ type FinalResponseConfig struct { Features []string `json:"features"` LogSelectiveFeatures bool `json:"log_features"` LogBatchSize int `json:"log_batch_size"` + LoggingTTL int `json:"logging_ttl"` } type FSKey struct { @@ -180,14 +184,23 @@ type RTPComponent struct { CompCacheEnabled bool `json:"comp_cache_enabled"` } +type SeenScoreComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id,omitempty"` + ColNamePrefix string `json:"col_name_prefix,omitempty"` + FSKeys []FSKey `json:"fs_keys"` + FSRequest *FSRequest `json:"fs_request"` +} + type ComponentConfig struct { - CacheEnabled bool `json:"cache_enabled"` - CacheTTL int `json:"cache_ttl"` - CacheVersion int `json:"cache_version"` - FeatureComponents []FeatureComponent `json:"feature_components"` - RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` - PredatorComponents []PredatorComponent `json:"predator_components"` - NumerixComponents []NumerixComponent `json:"numerix_components"` + CacheEnabled bool `json:"cache_enabled"` + CacheTTL int `json:"cache_ttl"` + CacheVersion int `json:"cache_version"` + FeatureComponents []FeatureComponent `json:"feature_components"` + RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` + SeenScoreComponents []SeenScoreComponent `json:"seen_score_components"` + PredatorComponents []PredatorComponent `json:"predator_components"` + 
NumerixComponents []NumerixComponent `json:"numerix_components"` } type DagExecutionConfig struct { @@ -231,6 +244,7 @@ func (r PromoteConfigRequest) GetConfigMapping() dbModel.ConfigMapping { ConnectionConfigID: r.Payload.ConfigMapping.ConnectionConfigID, DeployableID: r.Payload.ConfigMapping.DeployableID, ResponseDefaultValues: r.Payload.ConfigMapping.ResponseDefaultValues, + SourceConfigID: r.Payload.ConfigMapping.SourceConfigID, } } @@ -245,6 +259,7 @@ func (r EditConfigOrCloneConfigRequest) GetConfigMapping() dbModel.ConfigMapping ConnectionConfigID: r.Payload.ConfigMapping.ConnectionConfigID, DeployableID: r.Payload.ConfigMapping.DeployableID, ResponseDefaultValues: r.Payload.ConfigMapping.ResponseDefaultValues, + SourceConfigID: r.Payload.ConfigMapping.SourceConfigID, } } @@ -259,6 +274,7 @@ type ScaleUpConfigPayload struct { ConfigValue InferflowConfig `json:"config_value"` ConfigMapping ConfigMapping `json:"config_mapping"` LoggingPerc int `json:"logging_perc"` + LoggingTTL int `json:"logging_ttl"` ModelNameToEndPointMap []ModelNameToEndPointMap `json:"proposed_model_endpoints"` } @@ -353,6 +369,7 @@ type ConfigTable struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` TestResults TestResults `json:"test_results"` + SourceConfigID string `json:"source_config_id"` } type TestResults struct { @@ -445,3 +462,12 @@ type RTPFeatureGroup struct { Features []string `json:"features"` DataType string `json:"dataType"` } + +type FeatureSchemaRequest struct { + ModelConfigId string `json:"model_config_id"` + Version string `json:"version"` +} + +type FeatureSchemaResponse struct { + Data []dbModel.SchemaComponents `json:"data"` +} diff --git a/horizon/internal/inferflow/handler/schema_adapter.go b/horizon/internal/inferflow/handler/schema_adapter.go new file mode 100644 index 00000000..c0fbfa57 --- /dev/null +++ b/horizon/internal/inferflow/handler/schema_adapter.go @@ -0,0 +1,267 @@ +package handler + +import ( + 
"github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" + "github.com/Meesho/BharatMLStack/horizon/pkg/configschemaclient" +) + +// BuildFeatureSchemaFromInferflow builds a feature schema from the component and response configs. +// It converts inferflow types to configschemaclient types, calls the client, and converts back. +func BuildFeatureSchemaFromInferflow(componentConfig *inferflow.ComponentConfig, responseConfig *inferflow.ResponseConfig) []inferflow.SchemaComponents { + if componentConfig == nil { + return nil + } + + clientSchema := configschemaclient.BuildFeatureSchema( + toClientComponentConfig(componentConfig), + toClientResponseConfig(responseConfig), + ) + + return toInferflowSchemaComponents(clientSchema) +} + +// ProcessResponseConfigFromInferflow processes the response config and builds schema components. +func ProcessResponseConfigFromInferflow(responseConfig *inferflow.ResponseConfig, schemaComponents []inferflow.SchemaComponents) []inferflow.SchemaComponents { + if responseConfig == nil || len(responseConfig.Features) == 0 { + return nil + } + + clientSchemaComponents := make([]configschemaclient.SchemaComponents, len(schemaComponents)) + for i, c := range schemaComponents { + clientSchemaComponents[i] = configschemaclient.SchemaComponents{ + FeatureName: c.FeatureName, + FeatureType: c.FeatureType, + FeatureSize: c.FeatureSize, + } + } + + result := configschemaclient.ProcessResponseConfig( + toClientResponseConfig(responseConfig), + clientSchemaComponents, + ) + + return toInferflowSchemaComponents(result) +} + +func toClientComponentConfig(config *inferflow.ComponentConfig) *configschemaclient.ComponentConfig { + if config == nil { + return nil + } + + return &configschemaclient.ComponentConfig{ + CacheEnabled: config.CacheEnabled, + CacheTTL: config.CacheTTL, + CacheVersion: config.CacheVersion, + FeatureComponents: toClientFeatureComponents(config.FeatureComponents), + RTPComponents: 
toClientRTPComponents(config.RTPComponents), + PredatorComponents: toClientPredatorComponents(config.PredatorComponents), + NumerixComponents: toClientNumerixComponents(config.NumerixComponents), + } +} + +func toClientResponseConfig(config *inferflow.ResponseConfig) *configschemaclient.ResponseConfig { + if config == nil { + return nil + } + + return &configschemaclient.ResponseConfig{ + LoggingPerc: config.LoggingPerc, + ModelSchemaPerc: config.ModelSchemaPerc, + Features: config.Features, + LogSelectiveFeatures: config.LogSelectiveFeatures, + LogBatchSize: config.LogBatchSize, + } +} + +func toClientNumerixComponents(components []inferflow.NumerixComponent) []configschemaclient.NumerixComponent { + if len(components) == 0 { + return nil + } + + result := make([]configschemaclient.NumerixComponent, len(components)) + for i, c := range components { + result[i] = configschemaclient.NumerixComponent{ + Component: c.Component, + ComponentID: c.ComponentID, + ScoreCol: c.ScoreCol, + ComputeID: c.ComputeID, + ScoreMapping: c.ScoreMapping, + DataType: c.DataType, + } + } + return result +} + +func toClientFeatureComponents(components []inferflow.FeatureComponent) []configschemaclient.FeatureComponent { + if len(components) == 0 { + return nil + } + + result := make([]configschemaclient.FeatureComponent, len(components)) + for i, c := range components { + result[i] = configschemaclient.FeatureComponent{ + Component: c.Component, + ComponentID: c.ComponentID, + ColNamePrefix: c.ColNamePrefix, + CompCacheEnabled: c.CompCacheEnabled, + CompCacheTTL: c.CompCacheTTL, + CompositeID: c.CompositeID, + FSKeys: toClientFSKeys(c.FSKeys), + FSRequest: toClientFSRequest(c.FSRequest), + FSFlattenRespKeys: c.FSFlattenRespKeys, + } + } + return result +} + +func toClientRTPComponents(components []inferflow.RTPComponent) []configschemaclient.RTPComponent { + if len(components) == 0 { + return nil + } + + result := make([]configschemaclient.RTPComponent, len(components)) + for i, c := 
range components { + result[i] = configschemaclient.RTPComponent{ + Component: c.Component, + ComponentID: c.ComponentID, + CompositeID: c.CompositeID, + FSKeys: toClientFSKeys(c.FSKeys), + FSRequest: toClientFSRequest(c.FSRequest), + FSFlattenRespKeys: c.FSFlattenRespKeys, + ColNamePrefix: c.ColNamePrefix, + CompCacheEnabled: c.CompCacheEnabled, + } + } + return result +} + +func toClientPredatorComponents(components []inferflow.PredatorComponent) []configschemaclient.PredatorComponent { + if len(components) == 0 { + return nil + } + + result := make([]configschemaclient.PredatorComponent, len(components)) + for i, c := range components { + result[i] = configschemaclient.PredatorComponent{ + Component: c.Component, + ComponentID: c.ComponentID, + ModelName: c.ModelName, + ModelEndPoint: c.ModelEndPoint, + Calibration: c.Calibration, + Deadline: c.Deadline, + BatchSize: c.BatchSize, + Inputs: toClientPredatorInputs(c.Inputs), + Outputs: toClientPredatorOutputs(c.Outputs), + RoutingConfig: toClientRoutingConfigs(c.RoutingConfig), + } + } + return result +} + +func toClientFSKeys(keys []inferflow.FSKey) []configschemaclient.FSKey { + if len(keys) == 0 { + return nil + } + + result := make([]configschemaclient.FSKey, len(keys)) + for i, k := range keys { + result[i] = configschemaclient.FSKey{ + Schema: k.Schema, + Col: k.Col, + } + } + return result +} + +func toClientFSRequest(req *inferflow.FSRequest) *configschemaclient.FSRequest { + if req == nil { + return nil + } + + return &configschemaclient.FSRequest{ + Label: req.Label, + FeatureGroups: toClientFSFeatureGroups(req.FeatureGroups), + } +} + +func toClientFSFeatureGroups(groups []inferflow.FSFeatureGroup) []configschemaclient.FSFeatureGroup { + if len(groups) == 0 { + return nil + } + + result := make([]configschemaclient.FSFeatureGroup, len(groups)) + for i, g := range groups { + result[i] = configschemaclient.FSFeatureGroup{ + Label: g.Label, + Features: g.Features, + DataType: g.DataType, + } + } + return 
result +} + +func toClientPredatorInputs(inputs []inferflow.PredatorInput) []configschemaclient.PredatorInput { + if len(inputs) == 0 { + return nil + } + + result := make([]configschemaclient.PredatorInput, len(inputs)) + for i, input := range inputs { + result[i] = configschemaclient.PredatorInput{ + Name: input.Name, + Features: input.Features, + Dims: input.Dims, + DataType: input.DataType, + } + } + return result +} + +func toClientPredatorOutputs(outputs []inferflow.PredatorOutput) []configschemaclient.PredatorOutput { + if len(outputs) == 0 { + return nil + } + + result := make([]configschemaclient.PredatorOutput, len(outputs)) + for i, output := range outputs { + result[i] = configschemaclient.PredatorOutput{ + Name: output.Name, + ModelScores: output.ModelScores, + ModelScoresDims: output.ModelScoresDims, + DataType: output.DataType, + } + } + return result +} + +func toClientRoutingConfigs(configs []inferflow.RoutingConfig) []configschemaclient.RoutingConfig { + if len(configs) == 0 { + return nil + } + + result := make([]configschemaclient.RoutingConfig, len(configs)) + for i, c := range configs { + result[i] = configschemaclient.RoutingConfig{ + ModelName: c.ModelName, + ModelEndpoint: c.ModelEndpoint, + RoutingPercentage: c.RoutingPercentage, + } + } + return result +} + +func toInferflowSchemaComponents(components []configschemaclient.SchemaComponents) []inferflow.SchemaComponents { + if len(components) == 0 { + return nil + } + + result := make([]inferflow.SchemaComponents, len(components)) + for i, c := range components { + result[i] = inferflow.SchemaComponents{ + FeatureName: c.FeatureName, + FeatureType: c.FeatureType, + FeatureSize: c.FeatureSize, + } + } + return result +} diff --git a/horizon/internal/inferflow/route/router.go b/horizon/internal/inferflow/route/router.go index 8483c6a8..429ed761 100644 --- a/horizon/internal/inferflow/route/router.go +++ b/horizon/internal/inferflow/route/router.go @@ -25,6 +25,7 @@ func Init() { 
register.GET("/logging-ttl", controller.NewConfigController().GetLoggingTTL) register.PATCH("/delete", controller.NewConfigController().Delete) register.GET("/latestRequest/:config_id", controller.NewConfigController().GetLatestRequest) + register.GET("/get_feature_schema", controller.NewConfigController().GetFeatureSchema) } discovery := v1.Group("/inferflow-config-discovery") diff --git a/horizon/internal/repositories/sql/discoveryconfig/sql.go b/horizon/internal/repositories/sql/discoveryconfig/sql.go index 37fb3180..39577085 100644 --- a/horizon/internal/repositories/sql/discoveryconfig/sql.go +++ b/horizon/internal/repositories/sql/discoveryconfig/sql.go @@ -2,6 +2,7 @@ package discoveryconfig import ( "errors" + "time" "github.com/Meesho/BharatMLStack/horizon/pkg/infra" "gorm.io/gorm" @@ -16,8 +17,12 @@ type DiscoveryConfigRepository interface { GetByToken(token string) ([]DiscoveryConfig, error) GetById(configId int) (*DiscoveryConfig, error) GetByServiceDeployableID(serviceDeployableID int) ([]DiscoveryConfig, error) + GetByServiceDeployableIDs(serviceDeployableIDs []int) ([]DiscoveryConfig, error) DB() *gorm.DB WithTx(tx *gorm.DB) DiscoveryConfigRepository + DeleteByIDTx(tx *gorm.DB, id int) error + ReactivateByIDTx(tx *gorm.DB, id int) error + DeactivateByID(id int, updatedBy string) error } type discoveryConfigRepo struct { @@ -102,3 +107,30 @@ func (r *discoveryConfigRepo) WithTx(tx *gorm.DB) DiscoveryConfigRepository { db: tx, } } + +func (r *discoveryConfigRepo) GetByServiceDeployableIDs(serviceDeployableIDs []int) ([]DiscoveryConfig, error) { + if len(serviceDeployableIDs) == 0 { + return []DiscoveryConfig{}, nil + } + var configs []DiscoveryConfig + err := r.db.Where("id IN ?", serviceDeployableIDs).Find(&configs).Error + return configs, err +} + +func (r *discoveryConfigRepo) DeleteByIDTx(tx *gorm.DB, id int) error { + return tx.Where("id = ?", id).Delete(&DiscoveryConfig{}).Error +} + +func (r *discoveryConfigRepo) ReactivateByIDTx(tx *gorm.DB, id 
int) error { + return tx.Model(&DiscoveryConfig{}).Where("id = ?", id).Update("active", true).Error +} + +func (r *discoveryConfigRepo) DeactivateByID(id int, updatedBy string) error { + return r.db.Model(&DiscoveryConfig{}). + Where("id = ?", id). + Updates(map[string]interface{}{ + "active": false, + "updated_by": updatedBy, + "updated_at": time.Now(), + }).Error +} diff --git a/horizon/internal/repositories/sql/inferflow/config/repository.go b/horizon/internal/repositories/sql/inferflow/config/repository.go index 4370f0aa..de615420 100644 --- a/horizon/internal/repositories/sql/inferflow/config/repository.go +++ b/horizon/internal/repositories/sql/inferflow/config/repository.go @@ -18,6 +18,9 @@ type Repository interface { Update(table *Table) error FindByDiscoveryIDsAndCreatedBefore(discoveryIDs []int, daysAgo int) ([]Table, error) Deactivate(configID string) error + DeleteByConfigIDTx(tx *gorm.DB, configID string) error + GetLatestInactiveByConfigID(tx *gorm.DB, configID string) (*Table, error) + ReactivateByIDTx(tx *gorm.DB, id int, updatedBy string) error } type InferflowConfig struct { @@ -84,8 +87,17 @@ func (g *InferflowConfig) GetAll() ([]Table, error) { } func (g *InferflowConfig) GetByID(configID string) (table *Table, err error) { - result := g.db.Where("config_id = ? and active = ?", configID, true).First(&table) - return table, result.Error + result := g.db.Where("config_id = ? and active = ?", configID, true). + Order("updated_at DESC"). 
+ First(&table) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, result.Error + } + return table, nil } func (g *InferflowConfig) DoesConfigIDExist(configID string) (bool, error) { @@ -119,3 +131,44 @@ func (g *InferflowConfig) Deactivate(configID string) error { return g.db.Model(&Table{}).Where("config_id = ?", configID).Update("active", false).Error } + +func (g *InferflowConfig) DeleteByConfigIDTx(tx *gorm.DB, configID string) error { + result := tx.Where("config_id = ?", configID).Delete(&Table{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errors.New("no config found with the given config_id") + } + return nil +} + +func (g *InferflowConfig) GetLatestInactiveByConfigID(tx *gorm.DB, configID string) (*Table, error) { + var table Table + err := tx.Where("config_id = ? AND active = ?", configID, false). + Order("updated_at DESC"). 
+ First(&table).Error + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &table, nil +} + +func (g *InferflowConfig) ReactivateByIDTx(tx *gorm.DB, id int, updatedBy string) error { + result := tx.Model(&Table{}).Where("id = ?", id).Updates(map[string]interface{}{ + "active": true, + "updated_by": updatedBy, + "updated_at": time.Now(), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errors.New("no config found with the given id") + } + return nil +} diff --git a/horizon/internal/repositories/sql/inferflow/config/table.go b/horizon/internal/repositories/sql/inferflow/config/table.go index de43f8cb..ef80a759 100644 --- a/horizon/internal/repositories/sql/inferflow/config/table.go +++ b/horizon/internal/repositories/sql/inferflow/config/table.go @@ -26,6 +26,7 @@ type Table struct { CreatedAt time.Time `gorm:"not null"` UpdatedAt time.Time TestResults inferflow.TestResults `gorm:"type:json"` + SourceConfigID string `gorm:"column:source_config_id"` } func (Table) TableName() string { diff --git a/horizon/internal/repositories/sql/inferflow/models.go b/horizon/internal/repositories/sql/inferflow/models.go index 68ecce49..0ac32f19 100644 --- a/horizon/internal/repositories/sql/inferflow/models.go +++ b/horizon/internal/repositories/sql/inferflow/models.go @@ -58,6 +58,7 @@ type ResponseConfig struct { Features []string `json:"features"` LogSelectiveFeatures bool `json:"log_features"` LogBatchSize int `json:"log_batch_size"` + LoggingTTL int `json:"logging_ttl"` } type FSKey struct { @@ -99,14 +100,23 @@ type RTPComponent struct { CompCacheEnabled bool `json:"comp_cache_enabled"` } +type SeenScoreComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id,omitempty"` + ColNamePrefix string `json:"col_name_prefix,omitempty"` + FSKeys []FSKey `json:"fs_keys"` + FSRequest *FSRequest `json:"fs_request"` +} + type 
ComponentConfig struct { - CacheEnabled bool `json:"cache_enabled"` - CacheTTL int `json:"cache_ttl"` - CacheVersion int `json:"cache_version"` - FeatureComponents []FeatureComponent `json:"feature_components"` - RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` - PredatorComponents []PredatorComponent `json:"predator_components"` - NumerixComponents []NumerixComponent `json:"numerix_components"` + CacheEnabled bool `json:"cache_enabled"` + CacheTTL int `json:"cache_ttl"` + CacheVersion int `json:"cache_version"` + FeatureComponents []FeatureComponent `json:"feature_components"` + RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` + PredatorComponents []PredatorComponent `json:"predator_components"` + NumerixComponents []NumerixComponent `json:"numerix_components"` + SeenScoreComponents []SeenScoreComponent `json:"seen_score_components"` } type InferflowConfig struct { @@ -120,6 +130,7 @@ type ConfigMapping struct { ConnectionConfigID int `json:"connection_config_id"` DeployableID int `json:"deployable_id"` ResponseDefaultValues []string `json:"response_default_values"` + SourceConfigID string `json:"source_config_id"` } type OnboardPayload struct { @@ -234,3 +245,13 @@ func (t *TestResults) Scan(value interface{}) error { func (t TestResults) Value() (driver.Value, error) { return json.Marshal(t) } + +type GetSchemaResponse struct { + Components []SchemaComponents +} + +type SchemaComponents struct { + FeatureName string `json:"feature_name"` + FeatureType string `json:"feature_type"` + FeatureSize any `json:"feature_size"` +} diff --git a/horizon/internal/repositories/sql/inferflow/request/repository.go b/horizon/internal/repositories/sql/inferflow/request/repository.go index dbe5d34d..f124ec04 100644 --- a/horizon/internal/repositories/sql/inferflow/request/repository.go +++ b/horizon/internal/repositories/sql/inferflow/request/repository.go @@ -21,8 +21,10 @@ type Repository interface { 
DoesConfigIdExistWithRequestType(configID string, requestType string) (bool, error) CurrentRequestStatus(requestID uint) (string, error) Deactivate(configID string) error - GetLatestPendingRequestByConfigID(configID string) ([]Table, error) + GetByConfigIDandVersion(configID string, version int) ([]Table, error) GetApprovedRequestsByConfigID(configID string) ([]Table, error) + GetLatestPendingRequestByConfigID(configID string) ([]Table, error) + GetRequestByID(requestID uint) (Table, error) } type InferflowRequest struct { @@ -91,6 +93,18 @@ func (g *InferflowRequest) GetByUser(email string) ([]Table, error) { return tables, result.Error } +func (g *InferflowRequest) GetByConfigIDandVersion(configID string, version int) ([]Table, error) { + var tables []Table + result := g.db.Where("config_id = ? AND version = ? AND status = 'APPROVED'", configID, version).Find(&tables) + if result.Error != nil { + return nil, result.Error + } + if len(tables) == 0 { + return nil, errors.New("no request found with the given config_id and version") + } + return tables, result.Error +} + func (g *InferflowRequest) DoesRecordExist(query string, args ...interface{}) (bool, error) { var table Table result := g.db.Where(query, args...).First(&table) @@ -112,7 +126,7 @@ func (g *InferflowRequest) DoesRequestIDExist(requestID uint) (bool, error) { } func (g *InferflowRequest) DoesConfigIdExistWithRequestType(configID string, requestType string) (bool, error) { - return g.DoesRecordExist("config_id = ? AND request_type = ? AND STATUS != 'REJECTED'", configID, requestType) + return g.DoesRecordExist("config_id = ? AND request_type = ? 
AND STATUS = 'PENDING APPROVAL'", configID, requestType) } func (g *InferflowRequest) CurrentRequestStatus(requestID uint) (string, error) { @@ -144,3 +158,15 @@ func (g *InferflowRequest) GetLatestPendingRequestByConfigID(configID string) ([ Find(&tables) return tables, result.Error } + +func (g *InferflowRequest) GetRequestByID(requestID uint) (Table, error) { + var table Table + result := g.db.Where("request_id = ?", requestID).First(&table) + if result.Error != nil { + if result.Error == gorm.ErrRecordNotFound { + return Table{}, nil + } + return Table{}, result.Error + } + return table, nil +} diff --git a/horizon/internal/repositories/sql/servicedeployableconfig/sql.go b/horizon/internal/repositories/sql/servicedeployableconfig/sql.go index d8d12260..09744fe9 100644 --- a/horizon/internal/repositories/sql/servicedeployableconfig/sql.go +++ b/horizon/internal/repositories/sql/servicedeployableconfig/sql.go @@ -18,6 +18,7 @@ type ServiceDeployableRepository interface { GetByWorkflowStatus(status string) ([]ServiceDeployableConfig, error) GetByDeployableHealth(health string) ([]ServiceDeployableConfig, error) GetByNameAndService(name, service string) (*ServiceDeployableConfig, error) + GetByIds(ids []int) ([]ServiceDeployableConfig, error) } type serviceDeployableRepo struct { @@ -98,3 +99,12 @@ func (r *serviceDeployableRepo) GetByNameAndService(name, service string) (*Serv } return &deployable, nil } + +func (r *serviceDeployableRepo) GetByIds(ids []int) ([]ServiceDeployableConfig, error) { + if len(ids) == 0 { + return []ServiceDeployableConfig{}, nil + } + var deployables []ServiceDeployableConfig + err := r.db.Where("id IN ?", ids).Find(&deployables).Error + return deployables, err +} diff --git a/horizon/internal/repositories/sql/servicedeployableconfig/table.go b/horizon/internal/repositories/sql/servicedeployableconfig/table.go index 9786ab91..a9132d6e 100644 --- a/horizon/internal/repositories/sql/servicedeployableconfig/table.go +++ 
b/horizon/internal/repositories/sql/servicedeployableconfig/table.go @@ -28,6 +28,8 @@ type ServiceDeployableConfig struct { DeploymentRunID string DeployableHealth string `gorm:"type:ENUM('DEPLOYMENT_REASON_ARGO_APP_HEALTH_DEGRADED', 'DEPLOYMENT_REASON_ARGO_APP_HEALTHY')"` WorkFlowStatus string `gorm:"type:ENUM('WORKFLOW_COMPLETED' , 'WORKFLOW_NOT_FOUND' , 'WORKFLOW_RUNNING','WORKFLOW_FAILED' ,'WORKFLOW_NOT_STARTED' )"` + OverrideTesting bool `gorm:"default:false"` + DeployableTag string `gorm:"column:deployable_tag"` } func (ServiceDeployableConfig) TableName() string { diff --git a/horizon/pkg/configschemaclient/client.go b/horizon/pkg/configschemaclient/client.go new file mode 100644 index 00000000..d372f616 --- /dev/null +++ b/horizon/pkg/configschemaclient/client.go @@ -0,0 +1,241 @@ +package configschemaclient + +import ( + "strings" +) + +// BuildFeatureSchema builds a feature schema from the component and response configs. +// It processes components in order: FS → RTP → Numerix Output → Predator Output → Numerix Input → Predator Input +func BuildFeatureSchema(componentConfig *ComponentConfig, responseConfig *ResponseConfig) []SchemaComponents { + if componentConfig == nil { + return nil + } + + existingFeatures := make(map[string]bool) + var response []SchemaComponents + + addUniqueComponents := func(components []SchemaComponents) { + for _, component := range components { + if !existingFeatures[component.FeatureName] { + response = append(response, component) + existingFeatures[component.FeatureName] = true + } + } + } + + addOrUpdateComponents := func(components []SchemaComponents) { + for _, component := range components { + if !existingFeatures[component.FeatureName] { + component.FeatureType = "String" + response = append(response, component) + existingFeatures[component.FeatureName] = true + } + } + } + + // 1. FS (Feature Store) + addUniqueComponents(processFS(componentConfig.FeatureComponents)) + + // 2. 
RTP (Real Time Pricing) + addUniqueComponents(processRTP(componentConfig.RTPComponents)) + + // 3. Numerix Output + addUniqueComponents(processNumerixOutput(componentConfig.NumerixComponents)) + + // 4. Predator Output + addUniqueComponents(processPredatorOutput(componentConfig.PredatorComponents)) + + // 5. Numerix Input (only add if not already present) + addOrUpdateComponents(processNumerixInput(componentConfig.NumerixComponents)) + + // 6. Predator Input (only add if not already present) + addOrUpdateComponents(processPredatorInput(componentConfig.PredatorComponents)) + + return response +} + +func processNumerixOutput(numerixComponents []NumerixComponent) []SchemaComponents { + if len(numerixComponents) == 0 { + return nil + } + + var response []SchemaComponents + for _, numerixComponent := range numerixComponents { + response = append(response, SchemaComponents{ + FeatureName: numerixComponent.ScoreCol, + FeatureType: numerixComponent.DataType, + FeatureSize: 1, + }) + } + return response +} + +func processNumerixInput(numerixComponents []NumerixComponent) []SchemaComponents { + if len(numerixComponents) == 0 { + return nil + } + + var response []SchemaComponents + for _, numerixComponent := range numerixComponents { + for numerixInput, featureName := range numerixComponent.ScoreMapping { + inputParts := strings.Split(numerixInput, "@") + response = append(response, SchemaComponents{ + FeatureName: featureName, + FeatureType: inputParts[1], + FeatureSize: 1, + }) + } + } + return response +} + +func getFeatureName(prefix, entityLabel, fgLabel, feature string) string { + featureName := "" + if prefix != "" { + featureName = prefix + } + if entityLabel != "" { + featureName = featureName + entityLabel + ":" + } + if fgLabel != "" { + featureName = featureName + fgLabel + ":" + } + return featureName + feature +} + +func processFS(featureComponents []FeatureComponent) []SchemaComponents { + if len(featureComponents) == 0 { + return nil + } + + var response 
[]SchemaComponents + for _, featureComponent := range featureComponents { + if featureComponent.FSRequest == nil { + continue + } + for _, featureGroup := range featureComponent.FSRequest.FeatureGroups { + for _, feature := range featureGroup.Features { + response = append(response, SchemaComponents{ + FeatureName: getFeatureName(featureComponent.ColNamePrefix, featureComponent.FSRequest.Label, featureGroup.Label, feature), + FeatureType: featureGroup.DataType, + FeatureSize: 1, + }) + } + } + } + return response +} + +func processRTP(rtpComponents []RTPComponent) []SchemaComponents { + if len(rtpComponents) == 0 { + return nil + } + + var response []SchemaComponents + for _, rtpComponent := range rtpComponents { + if rtpComponent.FSRequest == nil { + continue + } + for _, featureGroup := range rtpComponent.FSRequest.FeatureGroups { + for _, feature := range featureGroup.Features { + response = append(response, SchemaComponents{ + FeatureName: getFeatureName(rtpComponent.ColNamePrefix, rtpComponent.FSRequest.Label, featureGroup.Label, feature), + FeatureType: featureGroup.DataType, + FeatureSize: 1, + }) + } + } + } + return response +} + +func processPredatorOutput(predatorComponents []PredatorComponent) []SchemaComponents { + if len(predatorComponents) == 0 { + return nil + } + + var response []SchemaComponents + for _, predatorComponent := range predatorComponents { + for _, output := range predatorComponent.Outputs { + for index, modelScore := range output.ModelScores { + var featureSize any = 1 + dataType := output.DataType + if index < len(output.ModelScoresDims) { + featureSize, dataType = getPredatorFeatureTypeAndSize(output.DataType, output.ModelScoresDims[index]) + } + response = append(response, SchemaComponents{ + FeatureName: modelScore, + FeatureType: dataType, + FeatureSize: featureSize, + }) + } + } + } + return response +} + +func processPredatorInput(predatorComponents []PredatorComponent) []SchemaComponents { + if len(predatorComponents) == 0 { + 
return nil + } + + var response []SchemaComponents + for _, predatorComponent := range predatorComponents { + for _, input := range predatorComponent.Inputs { + for _, feature := range input.Features { + size, dataType := getPredatorFeatureTypeAndSize(input.DataType, input.Dims) + response = append(response, SchemaComponents{ + FeatureName: feature, + FeatureType: dataType, + FeatureSize: size, + }) + } + } + } + return response +} + +func getPredatorFeatureTypeAndSize(dataType string, shape []int) (int, string) { + if len(shape) == 1 && shape[0] == 1 { + return 1, dataType + } + if len(shape) == 2 && shape[0] == -1 { + return shape[1], dataType + "Vector" + } + if len(shape) > 0 { + return shape[0], dataType + "Vector" + } + return 1, dataType +} + +// ProcessResponseConfig processes the response config and builds schema components +// based on the features specified in the response config. +func ProcessResponseConfig(responseConfig *ResponseConfig, schemaComponents []SchemaComponents) []SchemaComponents { + if responseConfig == nil || len(responseConfig.Features) == 0 { + return nil + } + + var response []SchemaComponents + + schemaMap := make(map[string]SchemaComponents) + for _, component := range schemaComponents { + schemaMap[component.FeatureName] = component + } + + for _, feature := range responseConfig.Features { + if existingComponent, exists := schemaMap[feature]; exists { + response = append(response, SchemaComponents{ + FeatureName: feature, + FeatureType: existingComponent.FeatureType, + FeatureSize: existingComponent.FeatureSize, + }) + } else { + response = append(response, SchemaComponents{ + FeatureName: feature, + FeatureType: "String", + FeatureSize: 1, + }) + } + } + return response +} diff --git a/horizon/pkg/configschemaclient/types.go b/horizon/pkg/configschemaclient/types.go new file mode 100644 index 00000000..6f63b7be --- /dev/null +++ b/horizon/pkg/configschemaclient/types.go @@ -0,0 +1,119 @@ +package configschemaclient + +// 
SchemaComponents represents a feature schema component +type SchemaComponents struct { + FeatureName string `json:"feature_name"` + FeatureType string `json:"feature_type"` + FeatureSize any `json:"feature_size"` +} + +// ComponentConfig contains all component configurations +type ComponentConfig struct { + CacheEnabled bool `json:"cache_enabled"` + CacheTTL int `json:"cache_ttl"` + CacheVersion int `json:"cache_version"` + FeatureComponents []FeatureComponent `json:"feature_components"` + RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` + PredatorComponents []PredatorComponent `json:"predator_components"` + NumerixComponents []NumerixComponent `json:"numerix_components"` +} + +// ResponseConfig contains response configuration +type ResponseConfig struct { + LoggingPerc int `json:"logging_perc"` + ModelSchemaPerc int `json:"model_schema_features_perc"` + Features []string `json:"features"` + LogSelectiveFeatures bool `json:"log_features"` + LogBatchSize int `json:"log_batch_size"` +} + +// NumerixComponent represents a Numerix/Numerix component +type NumerixComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id"` + ScoreCol string `json:"score_col"` + ComputeID string `json:"compute_id"` + ScoreMapping map[string]string `json:"score_mapping"` + DataType string `json:"data_type"` +} + +// FeatureComponent represents a feature store component +type FeatureComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id"` + ColNamePrefix string `json:"col_name_prefix,omitempty"` + CompCacheEnabled bool `json:"comp_cache_enabled"` + CompCacheTTL int `json:"comp_cache_ttl,omitempty"` + CompositeID bool `json:"composite_id,omitempty"` + FSKeys []FSKey `json:"fs_keys"` + FSRequest *FSRequest `json:"fs_request"` + FSFlattenRespKeys []string `json:"fs_flatten_resp_keys"` +} + +// RTPComponent represents a real-time pricing component +type RTPComponent struct { + 
Component string `json:"component"` + ComponentID string `json:"component_id"` + CompositeID bool `json:"composite_id"` + FSKeys []FSKey `json:"fs_keys"` + FSRequest *FSRequest `json:"fs_request"` + FSFlattenRespKeys []string `json:"fs_flatten_resp_keys"` + ColNamePrefix string `json:"col_name_prefix"` + CompCacheEnabled bool `json:"comp_cache_enabled"` +} + +// PredatorComponent represents a Predator model component +type PredatorComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id"` + ModelName string `json:"model_name"` + ModelEndPoint string `json:"model_end_point"` + Calibration string `json:"calibration,omitempty"` + Deadline int `json:"deadline"` + BatchSize int `json:"batch_size"` + Inputs []PredatorInput `json:"inputs"` + Outputs []PredatorOutput `json:"outputs"` + RoutingConfig []RoutingConfig `json:"route_config,omitempty"` +} + +// PredatorInput represents input configuration for Predator +type PredatorInput struct { + Name string `json:"name"` + Features []string `json:"features"` + Dims []int `json:"shape"` + DataType string `json:"data_type"` +} + +// PredatorOutput represents output configuration for Predator +type PredatorOutput struct { + Name string `json:"name"` + ModelScores []string `json:"model_scores"` + ModelScoresDims [][]int `json:"model_scores_dims"` + DataType string `json:"data_type"` +} + +// RoutingConfig represents routing configuration +type RoutingConfig struct { + ModelName string `json:"model_name"` + ModelEndpoint string `json:"model_endpoint"` + RoutingPercentage float32 `json:"routing_percentage"` +} + +// FSKey represents a feature store key +type FSKey struct { + Schema string `json:"schema"` + Col string `json:"col"` +} + +// FSRequest represents a feature store request +type FSRequest struct { + Label string `json:"label"` + FeatureGroups []FSFeatureGroup `json:"featureGroups"` +} + +// FSFeatureGroup represents a feature group +type FSFeatureGroup struct { + Label string 
`json:"label"` + Features []string `json:"features"` + DataType string `json:"data_type"` +} diff --git a/inferflow/handlers/inferflow/inferflow.go b/inferflow/handlers/inferflow/inferflow.go index 4ea6ab7e..550441cc 100644 --- a/inferflow/handlers/inferflow/inferflow.go +++ b/inferflow/handlers/inferflow/inferflow.go @@ -66,7 +66,7 @@ func InitInferflowHandler(configs *configs.AppConfigs) { }, }, } - logger.Info("Model Proxy handler initialized") + logger.Info("Inferflow handler initialized") } func ReloadModelConfigMapAndRegisterComponents() error { diff --git a/trufflehog/trufflehog-hook.sh b/trufflehog/trufflehog-hook.sh deleted file mode 100755 index 2825d238..00000000 --- a/trufflehog/trufflehog-hook.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -OUTPUT=$(trufflehog git file://. --since-commit HEAD --branch=$(git rev-parse --abbrev-ref HEAD) --no-update --json --results=verified 2>/dev/null) - -if echo "$OUTPUT" | grep -q "\"Verified\":true"; then - METADATA_COUNT=$(echo "$OUTPUT" | grep -o "SourceMetadata" | wc -l | xargs) - echo "🚨 $METADATA_COUNT Verified secret/s found! Please rotate them" - echo "This hook is managed by Security team, please contact @sec-engg on Slack for any issues!" - echo ""; echo "🔍 Detected Secrets:"; echo "$OUTPUT" | sed "s/}{/}\\n{/g" | jq -r "." 
- - - REPO_NAME=$(basename "$(git rev-parse --show-toplevel)") - BRANCH_NAME=$(git rev-parse --abbrev-ref HEAD) - USER_NAME=$(git config user.name) - USER_EMAIL=$(git config user.email) - - echo "$OUTPUT" | sed "s/}{/}\\n{/g" | while read -r finding; do - [ "$(echo "$finding" | jq -r '.Verified')" = true ] || continue - DETECTOR=$(echo "$finding" | jq -r ".DetectorName // \"unknown\"") - COMMIT=$(echo "$finding" | jq -r ".SourceMetadata.Data.Git.commit // \"unknown\"") - FILE=$(echo "$finding" | jq -r ".SourceMetadata.Data.Git.file // \"unknown\"") - LINE=$(echo "$finding" | jq -r ".SourceMetadata.Data.Git.line // \"unknown\"") - EMAIL=$(echo "$finding" | jq -r ".SourceMetadata.Data.Git.email // \"None\"") - - CMD64=$(cat < Date: Mon, 2 Feb 2026 16:50:31 +0530 Subject: [PATCH 02/24] internal components interface methods --- horizon/internal/inferflow/handler/adaptor.go | 255 +--------- .../component_builder_internal_stub.go | 105 ++++ .../inferflow/handler/config_builder.go | 470 +++--------------- .../internal/inferflow/handler/inferflow.go | 2 +- .../inferflow/handler/internal_components.go | 93 ++++ 5 files changed, 293 insertions(+), 632 deletions(-) create mode 100644 horizon/internal/inferflow/handler/component_builder_internal_stub.go create mode 100644 horizon/internal/inferflow/handler/internal_components.go diff --git a/horizon/internal/inferflow/handler/adaptor.go b/horizon/internal/inferflow/handler/adaptor.go index 2246ad8d..7b7f0bec 100644 --- a/horizon/internal/inferflow/handler/adaptor.go +++ b/horizon/internal/inferflow/handler/adaptor.go @@ -21,9 +21,9 @@ func AdaptOnboardRequestToDBPayload(req interface{}, inferflowConfig InferflowCo dbNumerixComponents := AdaptToDBNumerixComponent(inferflowConfig) - dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) - - dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + // Use interface for internal component adaptation (RTP, SeenScore) + dbRTPComponents := 
InternalComponentBuilderInstance.AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := InternalComponentBuilderInstance.AdaptToDBSeenScoreComponent(inferflowConfig) featureComponents := AdaptToDBFeatureComponent(inferflowConfig) @@ -48,9 +48,9 @@ func AdaptEditRequestToDBPayload(req interface{}, inferflowConfig InferflowConfi dbNumerixComponents := AdaptToDBNumerixComponent(inferflowConfig) - dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) - - dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + // Use interface for internal component adaptation (RTP, SeenScore) + dbRTPComponents := InternalComponentBuilderInstance.AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := InternalComponentBuilderInstance.AdaptToDBSeenScoreComponent(inferflowConfig) featureComponents := AdaptToDBFeatureComponent(inferflowConfig) @@ -76,9 +76,9 @@ func AdaptCloneConfigRequestToDBPayload(req interface{}, inferflowConfig Inferfl dbNumerixComponents := AdaptToDBNumerixComponent(inferflowConfig) - dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) - - dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + // Use interface for internal component adaptation (RTP, SeenScore) + dbRTPComponents := InternalComponentBuilderInstance.AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := InternalComponentBuilderInstance.AdaptToDBSeenScoreComponent(inferflowConfig) featureComponents := AdaptToDBFeatureComponent(inferflowConfig) @@ -105,9 +105,9 @@ func AdaptPromoteRequestToDBPayload(req interface{}, requestPayload RequestConfi dbNumerixComponents := AdaptToDBNumerixComponent(inferflowConfig) - dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) - - dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + // Use interface for internal component adaptation (RTP, SeenScore) + dbRTPComponents := InternalComponentBuilderInstance.AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := 
InternalComponentBuilderInstance.AdaptToDBSeenScoreComponent(inferflowConfig) featureComponents := AdaptToDBFeatureComponent(inferflowConfig) @@ -135,9 +135,9 @@ func AdaptScaleUpRequestToDBPayload(req interface{}, requestPayload RequestConfi dbNumerixComponents := AdaptToDBNumerixComponent(inferflowConfig) - dbRTPComponents := AdaptToDBRTPComponent(inferflowConfig) - - dbSeenScoreComponents := AdaptToDBSeenScoreComponent(inferflowConfig) + // Use interface for internal component adaptation (RTP, SeenScore) + dbRTPComponents := InternalComponentBuilderInstance.AdaptToDBRTPComponent(inferflowConfig) + dbSeenScoreComponents := InternalComponentBuilderInstance.AdaptToDBSeenScoreComponent(inferflowConfig) featureComponents := AdaptToDBFeatureComponent(inferflowConfig) @@ -247,79 +247,6 @@ func AdaptToDBNumerixComponent(inferflowConfig InferflowConfig) []dbModel.Numeri return NumerixComponents } -func AdaptToDBRTPComponent(inferflowConfig InferflowConfig) []dbModel.RTPComponent { - var rtpComponents []dbModel.RTPComponent - for _, rtpComponent := range inferflowConfig.ComponentConfig.RTPComponents { - fsKeys := make([]dbModel.FSKey, len(rtpComponent.FSKeys)) - for i, key := range rtpComponent.FSKeys { - fsKeys[i] = dbModel.FSKey{ - Schema: key.Schema, - Col: key.Col, - } - } - fsFeatureGroups := make([]dbModel.FSFeatureGroup, len(rtpComponent.FeatureRequest.FeatureGroups)) - for i, grp := range rtpComponent.FeatureRequest.FeatureGroups { - fsFeatureGroups[i] = dbModel.FSFeatureGroup{ - Label: grp.Label, - Features: grp.Features, - DataType: grp.DataType, - } - } - fsRequest := dbModel.FSRequest{ - Label: rtpComponent.FeatureRequest.Label, - FeatureGroups: fsFeatureGroups, - } - dbRTPComponent := dbModel.RTPComponent{ - Component: rtpComponent.Component, - ComponentID: rtpComponent.ComponentID, - CompositeID: rtpComponent.CompositeID, - FSKeys: fsKeys, - FSRequest: &fsRequest, - FSFlattenRespKeys: rtpComponent.FSFlattenRespKeys, - ColNamePrefix: 
rtpComponent.ColNamePrefix, - CompCacheEnabled: rtpComponent.CompCacheEnabled, - } - rtpComponents = append(rtpComponents, dbRTPComponent) - } - return rtpComponents -} - -func AdaptToDBSeenScoreComponent(inferflowConfig InferflowConfig) []dbModel.SeenScoreComponent { - var seenScoreComponents []dbModel.SeenScoreComponent - for _, seenScoreComponent := range inferflowConfig.ComponentConfig.SeenScoreComponents { - fsKeys := make([]dbModel.FSKey, len(seenScoreComponent.FSKeys)) - for i, key := range seenScoreComponent.FSKeys { - fsKeys[i] = dbModel.FSKey{ - Schema: key.Schema, - Col: key.Col, - } - } - var fsRequest *dbModel.FSRequest - if seenScoreComponent.FSRequest != nil { - fsFeatureGroups := make([]dbModel.FSFeatureGroup, len(seenScoreComponent.FSRequest.FeatureGroups)) - for i, grp := range seenScoreComponent.FSRequest.FeatureGroups { - fsFeatureGroups[i] = dbModel.FSFeatureGroup{ - Label: grp.Label, - Features: grp.Features, - DataType: grp.DataType, - } - } - fsRequest = &dbModel.FSRequest{ - Label: seenScoreComponent.FSRequest.Label, - FeatureGroups: fsFeatureGroups, - } - } - dbSeenScoreComponent := dbModel.SeenScoreComponent{ - Component: seenScoreComponent.Component, - ComponentID: seenScoreComponent.ComponentID, - ColNamePrefix: seenScoreComponent.ColNamePrefix, - FSKeys: fsKeys, - FSRequest: fsRequest, - } - seenScoreComponents = append(seenScoreComponents, dbSeenScoreComponent) - } - return seenScoreComponents -} func AdaptToDBFeatureComponent(inferflowConfig InferflowConfig) []dbModel.FeatureComponent { var featureComponents []dbModel.FeatureComponent @@ -552,8 +479,8 @@ func AdaptFromDbToComponentConfig(dbComponentConfig dbModel.ComponentConfig) *Co FeatureComponents: AdaptFromDbToFeatureComponent(dbComponentConfig.FeatureComponents), PredatorComponents: AdaptFromDbToPredatorComponent(dbComponentConfig.PredatorComponents), NumerixComponents: AdaptFromDbToNumerixComponent(dbComponentConfig.NumerixComponents), - RTPComponents: 
AdaptFromDbToRTPComponent(dbComponentConfig.RTPComponents), - SeenScoreComponents: AdaptFromDbToSeenScoreComponent(dbComponentConfig.SeenScoreComponents), + RTPComponents: InternalComponentBuilderInstance.AdaptFromDbToRTPComponent(dbComponentConfig.RTPComponents), + SeenScoreComponents: InternalComponentBuilderInstance.AdaptFromDbToSeenScoreComponent(dbComponentConfig.SeenScoreComponents), } } @@ -642,78 +569,6 @@ func AdaptFromDbToNumerixComponent(dbNumerixComponents []dbModel.NumerixComponen return NumerixComponents } -func AdaptFromDbToRTPComponent(dbRTPComponents []dbModel.RTPComponent) []RTPComponent { - - var rtpComponents []RTPComponent - for _, rtpComponent := range dbRTPComponents { - fsKeys := make([]FSKey, len(rtpComponent.FSKeys)) - for i, key := range rtpComponent.FSKeys { - fsKeys[i] = FSKey{ - Schema: key.Schema, - Col: key.Col, - } - } - fsFeatureGroups := make([]FSFeatureGroup, len(rtpComponent.FSRequest.FeatureGroups)) - for i, grp := range rtpComponent.FSRequest.FeatureGroups { - fsFeatureGroups[i] = FSFeatureGroup{ - Label: grp.Label, - Features: grp.Features, - DataType: grp.DataType, - } - } - fsRequest := FSRequest{ - Label: rtpComponent.FSRequest.Label, - FeatureGroups: fsFeatureGroups, - } - rtpComponents = append(rtpComponents, RTPComponent{ - Component: rtpComponent.Component, - ComponentID: rtpComponent.ComponentID, - CompositeID: rtpComponent.CompositeID, - FSKeys: fsKeys, - FeatureRequest: &fsRequest, - FSFlattenRespKeys: rtpComponent.FSFlattenRespKeys, - ColNamePrefix: rtpComponent.ColNamePrefix, - CompCacheEnabled: rtpComponent.CompCacheEnabled, - }) - } - return rtpComponents -} - -func AdaptFromDbToSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []SeenScoreComponent { - var seenScoreComponents []SeenScoreComponent - for _, seenScoreComponent := range dbSeenScoreComponents { - fsKeys := make([]FSKey, len(seenScoreComponent.FSKeys)) - for i, key := range seenScoreComponent.FSKeys { - fsKeys[i] = FSKey{ - 
Schema: key.Schema, - Col: key.Col, - } - } - var fsRequest *FSRequest - if seenScoreComponent.FSRequest != nil { - fsFeatureGroups := make([]FSFeatureGroup, len(seenScoreComponent.FSRequest.FeatureGroups)) - for i, grp := range seenScoreComponent.FSRequest.FeatureGroups { - fsFeatureGroups[i] = FSFeatureGroup{ - Label: grp.Label, - Features: grp.Features, - DataType: grp.DataType, - } - } - fsRequest = &FSRequest{ - Label: seenScoreComponent.FSRequest.Label, - FeatureGroups: fsFeatureGroups, - } - } - seenScoreComponents = append(seenScoreComponents, SeenScoreComponent{ - Component: seenScoreComponent.Component, - ComponentID: seenScoreComponent.ComponentID, - ColNamePrefix: seenScoreComponent.ColNamePrefix, - FSKeys: fsKeys, - FSRequest: fsRequest, - }) - } - return seenScoreComponents -} func AdaptFromDbToFeatureComponent(dbFeatureComponents []dbModel.FeatureComponent) []FeatureComponent { var featureComponents []FeatureComponent @@ -780,8 +635,8 @@ func AdaptToEtcdComponentConfig(dbComponentConfig dbModel.ComponentConfig) etcdM FeatureComponents: AdaptToEtcdFeatureComponent(dbComponentConfig.FeatureComponents), PredatorComponents: AdaptToEtcdPredatorComponent(dbComponentConfig.PredatorComponents), NumerixComponents: AdaptToEtcdNumerixComponent(dbComponentConfig.NumerixComponents), - RTPComponents: AdaptToEtcdRTPComponent(dbComponentConfig.RTPComponents), - SeenScoreComponents: AdaptToEtcdSeenScoreComponent(dbComponentConfig.SeenScoreComponents), + RTPComponents: InternalComponentBuilderInstance.AdaptToEtcdRTPComponent(dbComponentConfig.RTPComponents), + SeenScoreComponents: InternalComponentBuilderInstance.AdaptToEtcdSeenScoreComponent(dbComponentConfig.SeenScoreComponents), } } @@ -868,78 +723,6 @@ func AdaptToEtcdNumerixComponent(dbNumerixComponents []dbModel.NumerixComponent) return NumerixComponents } -func AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcdModel.RTPComponent { - - var rtpComponents []etcdModel.RTPComponent - for _, 
rtpComponent := range dbRTPComponents { - fsKeys := make([]etcdModel.FSKey, len(rtpComponent.FSKeys)) - for i, key := range rtpComponent.FSKeys { - fsKeys[i] = etcdModel.FSKey{ - Schema: key.Schema, - Col: key.Col, - } - } - fsFeatureGroups := make([]etcdModel.FSFeatureGroup, len(rtpComponent.FSRequest.FeatureGroups)) - for i, grp := range rtpComponent.FSRequest.FeatureGroups { - fsFeatureGroups[i] = etcdModel.FSFeatureGroup{ - Label: grp.Label, - Features: grp.Features, - DataType: grp.DataType, - } - } - fsRequest := etcdModel.FSRequest{ - Label: rtpComponent.FSRequest.Label, - FeatureGroups: fsFeatureGroups, - } - rtpComponents = append(rtpComponents, etcdModel.RTPComponent{ - Component: rtpComponent.Component, - ComponentID: rtpComponent.ComponentID, - CompositeID: rtpComponent.CompositeID, - FSKeys: fsKeys, - FSRequest: &fsRequest, - FSFlattenRespKeys: rtpComponent.FSFlattenRespKeys, - ColNamePrefix: rtpComponent.ColNamePrefix, - CompCacheEnabled: rtpComponent.CompCacheEnabled, - }) - } - return rtpComponents -} - -func AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcdModel.SeenScoreComponent { - var seenScoreComponents []etcdModel.SeenScoreComponent - for _, seenScoreComponent := range dbSeenScoreComponents { - fsKeys := make([]etcdModel.FSKey, len(seenScoreComponent.FSKeys)) - for i, key := range seenScoreComponent.FSKeys { - fsKeys[i] = etcdModel.FSKey{ - Schema: key.Schema, - Col: key.Col, - } - } - var fsRequest *etcdModel.FSRequest - if seenScoreComponent.FSRequest != nil { - fsFeatureGroups := make([]etcdModel.FSFeatureGroup, len(seenScoreComponent.FSRequest.FeatureGroups)) - for i, grp := range seenScoreComponent.FSRequest.FeatureGroups { - fsFeatureGroups[i] = etcdModel.FSFeatureGroup{ - Label: grp.Label, - Features: grp.Features, - DataType: grp.DataType, - } - } - fsRequest = &etcdModel.FSRequest{ - Label: seenScoreComponent.FSRequest.Label, - FeatureGroups: fsFeatureGroups, - } - } - seenScoreComponents = 
append(seenScoreComponents, etcdModel.SeenScoreComponent{ - Component: seenScoreComponent.Component, - ComponentID: seenScoreComponent.ComponentID, - ColNamePrefix: seenScoreComponent.ColNamePrefix, - FSKeys: fsKeys, - FSRequest: fsRequest, - }) - } - return seenScoreComponents -} func AdaptToEtcdFeatureComponent(dbFeatureComponents []dbModel.FeatureComponent) []etcdModel.FeatureComponent { var featureComponents []etcdModel.FeatureComponent diff --git a/horizon/internal/inferflow/handler/component_builder_internal_stub.go b/horizon/internal/inferflow/handler/component_builder_internal_stub.go new file mode 100644 index 00000000..0c656139 --- /dev/null +++ b/horizon/internal/inferflow/handler/component_builder_internal_stub.go @@ -0,0 +1,105 @@ +//go:build !meesho + +package handler + +import ( + etcd "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" + etcdModel "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" + dbModel "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" + mapset "github.com/deckarep/golang-set/v2" +) + +// internalComponentBuilderStub is the stub implementation for open-source builds. +// It has NO knowledge of RTP, SEEN Score, or any other internal features. +// All internal processing is a no-op. 
+type internalComponentBuilderStub struct{} + +func init() { + InternalComponentBuilderInstance = &internalComponentBuilderStub{} +} + +// IsEnabled returns false - internal components not available in open-source builds +func (s *internalComponentBuilderStub) IsEnabled() bool { + return false +} + +// ProcessFeatures returns empty results - no internal features in open-source builds +func (s *internalComponentBuilderStub) ProcessFeatures( + initialFeatures mapset.Set[string], + featureDataTypes map[string]string, +) (mapset.Set[string], map[string]string, error) { + return mapset.NewSet[string](), make(map[string]string), nil +} + +// ClassifyFeature returns false - no features are internal in open-source builds +func (s *internalComponentBuilderStub) ClassifyFeature(feature string) (string, bool) { + return "", false +} + +// GetInternalComponents returns empty slices - no internal components in open-source builds +func (s *internalComponentBuilderStub) GetInternalComponents( + request InferflowOnboardRequest, + internalFeatures mapset.Set[string], + etcdConfig etcd.Manager, + token string, +) ([]RTPComponent, []SeenScoreComponent, error) { + return []RTPComponent{}, []SeenScoreComponent{}, nil +} + +// FetchInternalComponentFeatures returns empty results - no internal components in open-source builds +func (s *internalComponentBuilderStub) FetchInternalComponentFeatures( + internalFeatures mapset.Set[string], + etcdConfig etcd.Manager, +) (mapset.Set[string], mapset.Set[string], map[string]string, error) { + return mapset.NewSet[string](), mapset.NewSet[string](), make(map[string]string), nil +} + +// FetchMissingInternalDataTypes is a no-op - no internal features in open-source builds +func (s *internalComponentBuilderStub) FetchMissingInternalDataTypes( + featureToDataType map[string]string, + internalFeatures mapset.Set[string], +) error { + return nil +} + +// AddInternalDependenciesToDAG is a no-op - no internal components in open-source builds +func (s 
*internalComponentBuilderStub) AddInternalDependenciesToDAG( + rtpComponents []RTPComponent, + seenScoreComponents []SeenScoreComponent, + featureComponents []FeatureComponent, + dagConfig *DagExecutionConfig, +) { + // No-op for open-source builds +} + +// ============= Adaptor Stub Methods ============= + +// AdaptToDBRTPComponent returns empty slice - no RTP components in open-source builds +func (s *internalComponentBuilderStub) AdaptToDBRTPComponent(inferflowConfig InferflowConfig) []dbModel.RTPComponent { + return []dbModel.RTPComponent{} +} + +// AdaptToDBSeenScoreComponent returns empty slice - no SeenScore components in open-source builds +func (s *internalComponentBuilderStub) AdaptToDBSeenScoreComponent(inferflowConfig InferflowConfig) []dbModel.SeenScoreComponent { + return []dbModel.SeenScoreComponent{} +} + +// AdaptFromDbToRTPComponent returns empty slice - no RTP components in open-source builds +func (s *internalComponentBuilderStub) AdaptFromDbToRTPComponent(dbRTPComponents []dbModel.RTPComponent) []RTPComponent { + return []RTPComponent{} +} + +// AdaptFromDbToSeenScoreComponent returns empty slice - no SeenScore components in open-source builds +func (s *internalComponentBuilderStub) AdaptFromDbToSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []SeenScoreComponent { + return []SeenScoreComponent{} +} + +// AdaptToEtcdRTPComponent returns empty slice - no RTP components in open-source builds +func (s *internalComponentBuilderStub) AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcdModel.RTPComponent { + return []etcdModel.RTPComponent{} +} + +// AdaptToEtcdSeenScoreComponent returns empty slice - no SeenScore components in open-source builds +func (s *internalComponentBuilderStub) AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcdModel.SeenScoreComponent { + return []etcdModel.SeenScoreComponent{} +} diff --git a/horizon/internal/inferflow/handler/config_builder.go 
b/horizon/internal/inferflow/handler/config_builder.go index 2dd64496..c3cf507c 100644 --- a/horizon/internal/inferflow/handler/config_builder.go +++ b/horizon/internal/inferflow/handler/config_builder.go @@ -6,8 +6,6 @@ import ( "strconv" "strings" - "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" - inferflow "github.com/Meesho/BharatMLStack/horizon/internal/inferflow" ofsHandler "github.com/Meesho/BharatMLStack/horizon/internal/online-feature-store/handler" etcd "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" @@ -21,10 +19,8 @@ const ( MODEL_FEATURE = "MODEL" OFFLINE_FEATURE = "OFFLINE" CALIBRATION = "CALIBRATION" - RTP_FEATURE = "RTP" PCTR_CALIBRATION = "PCTR_CALIBRATION" PCVR_CALIBRATION = "PCVR_CALIBRATION" - SEEN_SCORE = "SEEN_SCORE_FEATURE" PIPE_DELIMITER = "|" UNDERSCORE_DELIMITER = "_" COLON_DELIMITER = ":" @@ -33,21 +29,17 @@ const ( featureClassOnline = "online" featureClassDefault = "default" featureClassModel = "model" - featureClassRtp = "rtp" featureClassPCVRCalibration = "pcvr_calibration" featureClassPCTRCalibration = "pctr_calibration" - featureClassSeenScore = "seen_score" featureClassInvalid = "invalid" COMPONENT_NAME_PREFIX = "composite_key_gen_" FEATURE_INITIALIZER = "feature_initializer" - SeenScoreComponentName = "product_seen_score" - SeenScoreDefaultDataType = "DataTypeString" ) type FeatureLists struct { allFeatureList mapset.Set[string] - rtpFeatures, pcvrCalibrationFeatures, pctrCalibrationFeatures, seenScoreFeatures mapset.Set[string] + pcvrCalibrationFeatures, pctrCalibrationFeatures mapset.Set[string] featureToDataType, predatorAndIrisOutputsToDataType, offlineToOnlineMapping map[string]string } @@ -59,34 +51,25 @@ type ClassifiedFeatures struct { DefaultFeatures mapset.Set[string] - RTPFeatures mapset.Set[string] - PCTRCalibrationFeatures mapset.Set[string] PCVRCalibrationFeatures mapset.Set[string] - SeenScoreFeatures mapset.Set[string] - FeatureToDataType map[string]string } type AllComponents 
struct { FeatureComponents []FeatureComponent - RTPComponents []RTPComponent - IrisComponents []NumerixComponent PredatorComponents []PredatorComponent - - SeenScoreComponents []SeenScoreComponent } func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token string) (InferflowConfig, error) { - // RTP client is initialized in externalcall.Init() entityIDs := extractEntityIDs(request) - featureList, featureToDataType, rtpFeatures, pcvrCalibrationFeatures, pctrCalibrationFeatures, predatorAndNumerixOutputsToDataType, offlineToOnlineMapping, err := GetFeatureList(request, m.EtcdConfig, token, entityIDs) + featureList, featureToDataType, internalFeatures, pcvrCalibrationFeatures, pctrCalibrationFeatures, predatorAndNumerixOutputsToDataType, offlineToOnlineMapping, err := GetFeatureList(request, m.EtcdConfig, token, entityIDs) if err != nil { return InferflowConfig{}, err } @@ -106,7 +89,8 @@ func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token st return InferflowConfig{}, err } - rtpComponents, err := GetRTPComponents(request, rtpFeatures, m.EtcdConfig, token) + // Get internal components (RTP, SEEN Score, etc.) 
- only available in meesho builds + rtpComponents, seenScoreComponents, err := InternalComponentBuilderInstance.GetInternalComponents(request, internalFeatures, m.EtcdConfig, token) if err != nil { return InferflowConfig{}, err } @@ -116,12 +100,12 @@ func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token st return InferflowConfig{}, err } - componentConfig, err := GetComponentConfig(featureComponents, rtpComponents, NumerixComponents, predatorComponents) + componentConfig, err := GetComponentConfig(featureComponents, rtpComponents, seenScoreComponents, NumerixComponents, predatorComponents) if err != nil { return InferflowConfig{}, err } - dagExecutionConfig, err := GetDagExecutionConfig(request, featureComponents, rtpComponents, NumerixComponents, predatorComponents, m.EtcdConfig) + dagExecutionConfig, err := GetDagExecutionConfig(request, featureComponents, rtpComponents, seenScoreComponents, NumerixComponents, predatorComponents, m.EtcdConfig) if err != nil { return InferflowConfig{}, err } @@ -141,7 +125,18 @@ func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token st func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, token string, entityIDs map[string]bool) (mapset.Set[string], map[string]string, mapset.Set[string], mapset.Set[string], mapset.Set[string], map[string]string, map[string]string, error) { initialFeatures, featureToDataType, predatorAndIrisOutputsToDataType := extractFeatures(request, entityIDs) - offlineFeatures, onlineFeatures, defaultFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, err := classifyFeatures(initialFeatures, featureToDataType) + // Process internal features first (RTP, SEEN Score, etc.) 
- only available in meesho builds + internalFeatures, internalFeatureToDataType, err := InternalComponentBuilderInstance.ProcessFeatures(initialFeatures, featureToDataType) + if err != nil { + return nil, nil, nil, nil, nil, nil, nil, err + } + + // Remove internal features from initial features before standard classification + for f := range internalFeatures.Iter() { + initialFeatures.Remove(f) + } + + offlineFeatures, onlineFeatures, defaultFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, err := classifyFeatures(initialFeatures, featureToDataType) if err != nil { return nil, nil, nil, nil, nil, nil, nil, err } @@ -153,6 +148,9 @@ func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, to for f, dtype := range newFeatureToDataType { featureToDataType[f] = dtype } + for f, dtype := range internalFeatureToDataType { + featureToDataType[f] = dtype + } features := mapset.NewSet[string]() for offlineFeature, onlineFeature := range offlineToOnlineMapping { @@ -165,34 +163,28 @@ func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, to features.Add(f) } - // Fetch RTP registry once for classification - rtpRegistry, err := GetRTPFeatureGroupDataTypeMap() - if err != nil && inferflow.IsMeeshoEnabled { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to get RTP registry: %w", err) - } - - // Fetch component features from RTP components - rtpComponentFSFeatures, rtpComponentRTPFeatures, rtpComponentFeatureToDataType, err := fetchRTPComponentFeaturesWithClassification(rtpFeatures, etcdConfig, rtpRegistry) + // Fetch internal component features (only available in meesho builds) + internalFSFeatures, newInternalFeatures, internalComponentFeatureToDataType, err := InternalComponentBuilderInstance.FetchInternalComponentFeatures(internalFeatures, etcdConfig) if err != nil { return nil, nil, nil, nil, nil, nil, nil, err } // Add FS features to the main features set - for _, f := range 
rtpComponentFSFeatures.ToSlice() { + for _, f := range internalFSFeatures.ToSlice() { features.Add(f) } - // Add newly discovered RTP features to rtpFeatures set - for _, f := range rtpComponentRTPFeatures.ToSlice() { - rtpFeatures.Add(f) + // Add newly discovered internal features + for _, f := range newInternalFeatures.ToSlice() { + internalFeatures.Add(f) } - for f, dtype := range rtpComponentFeatureToDataType { + for f, dtype := range internalComponentFeatureToDataType { featureToDataType[f] = dtype } // Fetch component features from regular FS components - componentFSFeatures, componentRTPFeatures, newfeatureToDataType, err := fetchComponentFeaturesWithClassification(features, pctrCalibrationFeatures, pcvrCalibrationFeatures, etcdConfig, request.Payload.RealEstate, token, rtpRegistry) + componentFSFeatures, newfeatureToDataType, err := fetchComponentFeatures(features, pctrCalibrationFeatures, pcvrCalibrationFeatures, etcdConfig, request.Payload.RealEstate, token) if err != nil { return nil, nil, nil, nil, nil, nil, nil, err } @@ -202,19 +194,10 @@ func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, to features.Add(f) } - // Add newly discovered RTP features to rtpFeatures set - for _, f := range componentRTPFeatures.ToSlice() { - rtpFeatures.Add(f) - } - for f, dtype := range newfeatureToDataType { featureToDataType[f] = dtype } - // for _, f := range defaultFeatures.ToSlice() { - // features.Add(f) - // } - for _, f := range defaultFeatures.ToSlice() { if _, exists := predatorAndIrisOutputsToDataType[f]; exists { continue @@ -224,11 +207,11 @@ func GetFeatureList(request InferflowOnboardRequest, etcdConfig etcd.Manager, to } } - if err := fetchMissingDatatypes(featureToDataType, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, onlineFeatures, token); err != nil { + if err := fetchMissingDatatypes(featureToDataType, internalFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, onlineFeatures, token); err != nil { 
return nil, nil, nil, nil, nil, nil, nil, err } - return features, featureToDataType, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, predatorAndIrisOutputsToDataType, offlineToOnlineMapping, nil + return features, featureToDataType, internalFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, predatorAndIrisOutputsToDataType, offlineToOnlineMapping, nil } func extractEntityIDs(request InferflowOnboardRequest) map[string]bool { @@ -309,7 +292,7 @@ func extractFeatures(request InferflowOnboardRequest, entityIDs map[string]bool) func fetchMissingDatatypes( featureToDataType map[string]string, - rtpFeatures mapset.Set[string], + internalFeatures mapset.Set[string], pctrCalibrationFeatures mapset.Set[string], pcvrCalibrationFeatures mapset.Set[string], onlineFeatures mapset.Set[string], @@ -327,7 +310,6 @@ func fetchMissingDatatypes( } horizonFeatures := make(map[string]struct{ label, group string }) - rtpFeaturesToFetch := mapset.NewSet[string]() for feature, dtype := range featureToDataType { if dtype != "" { @@ -354,9 +336,8 @@ func fetchMissingDatatypes( continue } - // Check if it's an RTP feature - if rtpFeatures.Contains(feature) { - rtpFeaturesToFetch.Add(feature) + // Skip internal features - they are handled by the internal component builder + if internalFeatures.Contains(feature) { continue } @@ -374,25 +355,9 @@ func fetchMissingDatatypes( } } - // Query RTP API once for all RTP datatypes - if rtpFeaturesToFetch.Cardinality() > 0 { - rtpDataTypeMap, err := GetRTPFeatureGroupDataTypeMap() - if err == nil { - for _, feature := range rtpFeaturesToFetch.ToSlice() { - if dataType, exists := rtpDataTypeMap[feature]; exists { - featureToDataType[feature] = dataType - continue - } - - parts := strings.Split(feature, COLON_DELIMITER) - if len(parts) == 4 { - withoutPrefix := strings.Join(parts[1:], COLON_DELIMITER) - if dataType, exists := rtpDataTypeMap[withoutPrefix]; exists { - featureToDataType[feature] = dataType - } - } - } - } + // Fetch 
missing internal feature data types (only available in meesho builds) + if err := InternalComponentBuilderInstance.FetchMissingInternalDataTypes(featureToDataType, internalFeatures); err != nil { + return err } // Query Horizon API for remaining features grouped by label @@ -433,21 +398,21 @@ func fetchMissingDatatypes( // classifyFeatures classifies features into offline, online and default features // and returns a set of features for each class and a map of feature to data type +// Note: Internal features (RTP, SEEN Score) are handled separately by InternalComponentBuilder func classifyFeatures( featureList mapset.Set[string], featureDataTypes map[string]string, -) (mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], map[string]string, error) { +) (mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], mapset.Set[string], map[string]string, error) { defaultFeatures := mapset.NewSet[string]() modelFeatures := mapset.NewSet[string]() onlineFeatures := mapset.NewSet[string]() offlineFeatures := mapset.NewSet[string]() - rtpFeatures := mapset.NewSet[string]() pctrCalibrationFeatures := mapset.NewSet[string]() pcvrCalibrationFeatures := mapset.NewSet[string]() newFeatureToDataType := make(map[string]string) add := func(name, originalFeature string, featureType string) error { - if err := AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &rtpFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, name, featureType); err != nil { + if err := AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, name, featureType); err != nil { return fmt.Errorf("error classifying feature: %w", err) } newFeatureToDataType[name] = featureDataTypes[originalFeature] @@ -455,26 +420,30 @@ func classifyFeatures( } for feature := range featureList.Iter() { + // Check if this is 
an internal feature - skip if so (handled by InternalComponentBuilder) + if _, isInternal := InternalComponentBuilderInstance.ClassifyFeature(feature); isInternal { + continue + } + transformedFeature, featureType, err := transformFeature(feature) if err != nil { - return nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, err } if err := add(transformedFeature, feature, featureType); err != nil { - return nil, nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, err } } - return offlineFeatures, onlineFeatures, defaultFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, nil + return offlineFeatures, onlineFeatures, defaultFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures, newFeatureToDataType, nil } -func AddFeatureToSet(defaultFeatures, modelFeatures, onlineFeatures, offlineFeatures, rtpFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures *mapset.Set[string], feature string, featureType string) error { +func AddFeatureToSet(defaultFeatures, modelFeatures, onlineFeatures, offlineFeatures, pctrCalibrationFeatures, pcvrCalibrationFeatures *mapset.Set[string], feature string, featureType string) error { allSets := map[string]*mapset.Set[string]{ featureClassDefault: defaultFeatures, featureClassModel: modelFeatures, featureClassOnline: onlineFeatures, featureClassOffline: offlineFeatures, - featureClassRtp: rtpFeatures, featureClassPCTRCalibration: pctrCalibrationFeatures, featureClassPCVRCalibration: pcvrCalibrationFeatures, } @@ -495,17 +464,8 @@ func AddFeatureToSet(defaultFeatures, modelFeatures, onlineFeatures, offlineFeat } // transformFeature transforms the feature to either online, offline or default feature -// and returns the transformed feature and the feature type. The feature can be of these given types: -// 1. PARENT_OFFLINE_FEATURE|FEATURE -// 2. DEFAULT_FEATURE|FEATURE -// 3. ONLINE_FEATURE|FEATURE -// 4. OFFLINE_FEATURE|FEATURE -// 5. 
PARENT_DEFAULT_FEATURE|FEATURE -// 6. PARENT_ONLINE_FEATURE|FEATURE -// 7. MODEL_FEATURE|FEATURE -// 8. CALIBRATION|FEATURE -// 9. RTP_FEATURE|FEATURE -// 10. PARENT_RTP_FEATURE|FEATURE +// and returns the transformed feature and the feature type. +// Note: Internal features (RTP, SEEN Score) are handled by InternalComponentBuilder.ClassifyFeature() func transformFeature(feature string) (string, string, error) { parts := strings.Split(feature, PIPE_DELIMITER) if len(parts) < 2 { @@ -534,8 +494,6 @@ func transformFeature(feature string) (string, string, error) { return newFeature, featureClassOnline, nil case OFFLINE_FEATURE: return newFeature, featureClassOffline, nil - case RTP_FEATURE: - return newFeature, featureClassRtp, nil case PCVR_CALIBRATION: return newFeature, featureClassPCVRCalibration, nil case PCTR_CALIBRATION: @@ -553,8 +511,6 @@ func transformFeature(feature string) (string, string, error) { return featureName, featureClassOnline, nil case OFFLINE_FEATURE: return featureName, featureClassOffline, nil - case RTP_FEATURE: - return featureName, featureClassRtp, nil case PCVR_CALIBRATION: return featureName, featureClassPCVRCalibration, nil case PCTR_CALIBRATION: @@ -570,94 +526,27 @@ func mapOfflineFeatures(offlineFeatureList mapset.Set[string], token string) (ma return GetOnlineFeatureMapping(offlineFeatureList, token) } -func fetchRTPComponentFeaturesWithClassification(rtpFeatures mapset.Set[string], etcdConfig etcd.Manager, rtpRegistry map[string]string) (mapset.Set[string], mapset.Set[string], map[string]string, error) { - componentList := getComponentList(rtpFeatures, nil, nil) - fsFeatures := mapset.NewSet[string]() - newRTPFeatures := mapset.NewSet[string]() - featureToDataType := make(map[string]string) - - for _, component := range componentList.ToSlice() { - componentData := etcdConfig.GetComponentData(component) - if componentData == nil { - return nil, nil, nil, fmt.Errorf("RTP Component: componentData for '%s' not found in registry", 
component) - } - - for _, pair := range componentData.FSIdSchemaToValueColumns { - if strings.Contains(pair.ValueCol, COLON_DELIMITER) { - // Check if this is an RTP feature or FS feature - isRTPFeature := false - - // Check direct match in RTP registry - if _, exists := rtpRegistry[pair.ValueCol]; exists { - isRTPFeature = true - } else { - // Check with prefix removed (for features like "parent:entity:group:feature") - parts := strings.Split(pair.ValueCol, COLON_DELIMITER) - if len(parts) == 4 { - withoutPrefix := strings.Join(parts[1:], COLON_DELIMITER) - if _, exists := rtpRegistry[withoutPrefix]; exists { - isRTPFeature = true - } - } - } - - if isRTPFeature { - newRTPFeatures.Add(pair.ValueCol) - } else { - fsFeatures.Add(pair.ValueCol) - } - featureToDataType[pair.ValueCol] = pair.DataType - } - } - } - - return fsFeatures, newRTPFeatures, featureToDataType, nil -} - -// fetchComponentFeaturesWithClassification fetches the component features from the etcd config -// and classifies them as RTP or FS features, returns both sets and a map of feature to data type -func fetchComponentFeaturesWithClassification(features mapset.Set[string], pctrCalibrationFeatures mapset.Set[string], pcvrCalibrationFeatures mapset.Set[string], etcdConfig etcd.Manager, realEstate string, token string, rtpRegistry map[string]string) (mapset.Set[string], mapset.Set[string], map[string]string, error) { +// fetchComponentFeatures fetches the component features from the etcd config +// and returns the features set and a map of feature to data type +func fetchComponentFeatures(features mapset.Set[string], pctrCalibrationFeatures mapset.Set[string], pcvrCalibrationFeatures mapset.Set[string], etcdConfig etcd.Manager, realEstate string, token string) (mapset.Set[string], map[string]string, error) { componentList := getComponentList(features, pctrCalibrationFeatures, pcvrCalibrationFeatures) fsFeatures := mapset.NewSet[string]() - newRTPFeatures := mapset.NewSet[string]() featureToDataType 
:= make(map[string]string) for _, component := range componentList.ToSlice() { componentData := etcdConfig.GetComponentData(component) if componentData == nil { - return nil, nil, nil, fmt.Errorf("component data: ComponentData for '%s' not found in registry. Please contact MLP team to onboard the component", component) + return nil, nil, fmt.Errorf("component data: ComponentData for '%s' not found in registry. Please contact MLP team to onboard the component", component) } for _, pair := range componentData.FSIdSchemaToValueColumns { if strings.Contains(pair.ValueCol, COLON_DELIMITER) { - // Check if this is an RTP feature or FS feature - isRTPFeature := false - - // Check direct match in RTP registry - if _, exists := rtpRegistry[pair.ValueCol]; exists { - isRTPFeature = true - } else { - // Check with prefix removed (for features like "parent:entity:group:feature") - parts := strings.Split(pair.ValueCol, COLON_DELIMITER) - if len(parts) == 4 { - withoutPrefix := strings.Join(parts[1:], COLON_DELIMITER) - if _, exists := rtpRegistry[withoutPrefix]; exists { - isRTPFeature = true - } - } - } - - if isRTPFeature { - newRTPFeatures.Add(pair.ValueCol) - } else { - fsFeatures.Add(pair.ValueCol) - } + fsFeatures.Add(pair.ValueCol) featureToDataType[pair.ValueCol] = pair.DataType } } if override, hasOverride := componentData.Overridecomponent[realEstate]; hasOverride { - // Override components are always FS features fsFeatures.Add(override.ComponentId) parts := strings.Split(override.ComponentId, COLON_DELIMITER) var label, group string @@ -667,23 +556,23 @@ func fetchComponentFeaturesWithClassification(features mapset.Set[string], pctrC } else if len(parts) == 4 { label, group = parts[1], parts[2] } else { - return nil, nil, nil, fmt.Errorf("component data: invalid override component id: %s", override.ComponentId) + return nil, nil, fmt.Errorf("component data: invalid override component id: %s", override.ComponentId) } featureGroupDataTypeMap, err := 
GetFeatureGroupDataTypeMap(label, token) if err != nil { - return nil, nil, nil, fmt.Errorf("component data: error getting feature group data type map: %w", err) + return nil, nil, fmt.Errorf("component data: error getting feature group data type map: %w", err) } if dataType, exists := featureGroupDataTypeMap[group]; exists { featureToDataType[override.ComponentId] = dataType } else { - return nil, nil, nil, fmt.Errorf("component data: feature group data type not found for %s: %s", override.ComponentId, group) + return nil, nil, fmt.Errorf("component data: feature group data type not found for %s: %s", override.ComponentId, group) } } } - return fsFeatures, newRTPFeatures, featureToDataType, nil + return fsFeatures, featureToDataType, nil } // getComponentList gets the component list from the features @@ -995,165 +884,6 @@ func GetFeatureGroupDataTypeMap(label string, token string) (map[string]string, return featureGroupDataTypeMap, nil } -func GetRTPComponents(request InferflowOnboardRequest, rtpFeatures mapset.Set[string], etcdConfig etcd.Manager, token string) ([]RTPComponent, error) { - rtpComponents := make([]RTPComponent, 0) - - if rtpFeatures.Cardinality() == 0 { - return rtpComponents, nil - } - - featureDataTypeMap, err := GetRTPFeatureGroupDataTypeMap() - if err != nil && inferflow.IsMeeshoEnabled { - return rtpComponents, fmt.Errorf("RTP Components: failed to fetch RTP feature data-type map: %w", err) - } - rtpFeatureComponentsMap := GetRTPFeatureLabelToPrefixToFeatureGroupToFeatureMap(rtpFeatures.ToSlice()) - for label, prefixToFeatureGroupToFeatureMap := range rtpFeatureComponentsMap { - if err != nil { - return nil, err - } - - for prefix, featureGroupToFeatureMap := range prefixToFeatureGroupToFeatureMap { - componentName := label - colNamePrefix := "" - - if prefix != "" { - componentName = prefix + UNDERSCORE_DELIMITER + label - colNamePrefix = prefix + COLON_DELIMITER - } - - featureGroupsByDataType := make(map[string]map[string][]string) // 
[featureGroupName][dataType][]features - - for featureGroupName, featureSet := range featureGroupToFeatureMap { - featureGroupsByDataType[featureGroupName] = make(map[string][]string) - - for _, feature := range featureSet.ToSlice() { - featureDataType := featureDataTypeMap[strings.Join([]string{label, featureGroupName, feature}, COLON_DELIMITER)] - if featureDataType == "" { - return nil, fmt.Errorf("RTP Components: no data type found for feature %s", feature) - } - - if featureGroupsByDataType[featureGroupName][featureDataType] == nil { - featureGroupsByDataType[featureGroupName][featureDataType] = make([]string, 0) - } - featureGroupsByDataType[featureGroupName][featureDataType] = append(featureGroupsByDataType[featureGroupName][featureDataType], feature) - } - } - - featureGroups := make([]FSFeatureGroup, 0) - for featureGroupName, dataTypeToFeatures := range featureGroupsByDataType { - for dataType, features := range dataTypeToFeatures { - sort.Strings(features) - - featureGroupData := FSFeatureGroup{ - Label: featureGroupName, - Features: features, - DataType: dataType, - } - featureGroups = append(featureGroups, featureGroupData) - } - } - - componentData := etcdConfig.GetComponentData(componentName) - if componentData == nil { - return nil, fmt.Errorf("RTP Components: componentData for '%s' not found in registry", componentName) - } - - componentID := componentData.ComponentID - overrideComponentID := "" - if realEstate := request.Payload.RealEstate; realEstate != "" { - if override, exists := componentData.Overridecomponent[realEstate]; exists { - overrideComponentID = override.ComponentId - componentID = override.ComponentId - } - } - - idKeys := make([]string, 0, len(componentData.FSIdSchemaToValueColumns)) - for k := range componentData.FSIdSchemaToValueColumns { - idKeys = append(idKeys, k) - } - sort.Strings(idKeys) - fsKeys := make([]FSKey, 0, len(idKeys)) - for _, k := range idKeys { - pair := componentData.FSIdSchemaToValueColumns[k] - col := 
pair.ValueCol - if overrideComponentID != "" { - col = overrideComponentID - } - fsKeys = append(fsKeys, FSKey{ - Schema: pair.Schema, - Col: col, - }) - } - - rtpComponent := RTPComponent{ - Component: componentName, - ComponentID: componentID, - CompCacheEnabled: false, - FSKeys: fsKeys, - FSFlattenRespKeys: []string{componentData.FSFlattenResKeys[label]}, - FeatureRequest: &FSRequest{ - Label: label, - FeatureGroups: featureGroups, - }, - ColNamePrefix: colNamePrefix, - } - - rtpComponents = append(rtpComponents, rtpComponent) - } - } - - return rtpComponents, nil -} - -func GetRTPFeatureLabelToPrefixToFeatureGroupToFeatureMap(featureStrings []string) map[string]map[string]map[string]mapset.Set[string] { - featuresMap := make(map[string]map[string]map[string]mapset.Set[string]) - - if len(featureStrings) == 0 { - return featuresMap - } - - sort.Strings(featureStrings) - - for _, input := range featureStrings { - parts := strings.Split(input, COLON_DELIMITER) - if len(parts) != 3 && len(parts) != 4 { - continue - } - - var ( - prefix string - label string - group string - feature string - ) - - if len(parts) == 4 { - prefix, label, group, feature = parts[0], parts[1], parts[2], parts[3] - } else { - prefix = "" - label, group, feature = parts[0], parts[1], parts[2] - } - - if _, ok := featuresMap[label]; !ok { - featuresMap[label] = make(map[string]map[string]mapset.Set[string]) - } - if _, ok := featuresMap[label][prefix]; !ok { - featuresMap[label][prefix] = make(map[string]mapset.Set[string]) - } - if _, ok := featuresMap[label][prefix][group]; !ok { - featuresMap[label][prefix][group] = mapset.NewSet[string]() - } - - featuresMap[label][prefix][group].Add(feature) - } - - return featuresMap -} - -func GetRTPFeatureGroupDataTypeMap() (map[string]string, error) { - return externalcall.PricingClient.GetFeatureGroupDataTypeMap() -} - func GetPredatorComponents(request InferflowOnboardRequest, offlineToOnlineMapping map[string]string) ([]PredatorComponent, error) 
{ predatorComponents := make([]PredatorComponent, 0, len(request.Payload.Rankers)) @@ -1297,9 +1027,10 @@ func getNumerixScoreMapping(eqVariables map[string]string, offlineToOnlineMappin } else { return nil, fmt.Errorf("numerix score mapping: offlineToOnlineMapping for '%s' not found", transformedFeature) } - case featureClassOnline, featureClassDefault, featureClassRtp: + case featureClassOnline, featureClassDefault: scoremap[key] = transformedFeature default: + // Includes internal feature types (handled by InternalComponentBuilder) scoremap[key] = transformedFeature } @@ -1320,21 +1051,22 @@ func GetResponseConfigs(request *InferflowOnboardRequest) (*FinalResponseConfig, return responseConfigs, nil } -func GetComponentConfig(featureComponents []FeatureComponent, rtpComponents []RTPComponent, NumerixComponents []NumerixComponent, predatorComponents []PredatorComponent) (*ComponentConfig, error) { +func GetComponentConfig(featureComponents []FeatureComponent, rtpComponents []RTPComponent, seenScoreComponents []SeenScoreComponent, NumerixComponents []NumerixComponent, predatorComponents []PredatorComponent) (*ComponentConfig, error) { componentConfig := &ComponentConfig{ - CacheEnabled: true, - CacheTTL: 300, - CacheVersion: 1, - FeatureComponents: featureComponents, - RTPComponents: rtpComponents, - NumerixComponents: NumerixComponents, - PredatorComponents: predatorComponents, + CacheEnabled: true, + CacheTTL: 300, + CacheVersion: 1, + FeatureComponents: featureComponents, + RTPComponents: rtpComponents, + SeenScoreComponents: seenScoreComponents, + NumerixComponents: NumerixComponents, + PredatorComponents: predatorComponents, } return componentConfig, nil } -func GetDagExecutionConfig(request InferflowOnboardRequest, featureComponents []FeatureComponent, rtpComponents []RTPComponent, NumerixComponents []NumerixComponent, predatorComponents []PredatorComponent, etcdConfig etcd.Manager) (*DagExecutionConfig, error) { +func GetDagExecutionConfig(request 
InferflowOnboardRequest, featureComponents []FeatureComponent, rtpComponents []RTPComponent, seenScoreComponents []SeenScoreComponent, NumerixComponents []NumerixComponent, predatorComponents []PredatorComponent, etcdConfig etcd.Manager) (*DagExecutionConfig, error) { dagExecutionConfig := &DagExecutionConfig{ ComponentDependency: make(map[string][]string), } @@ -1351,17 +1083,9 @@ func GetDagExecutionConfig(request InferflowOnboardRequest, featureComponents [] } } - for _, component := range rtpComponents { - componentName := component.Component - - specificDependencies := findSpecificRTPDependencies(component, rtpComponents, featureComponents) - - if len(specificDependencies) > 0 { - dagExecutionConfig.ComponentDependency[componentName] = append(dagExecutionConfig.ComponentDependency[componentName], specificDependencies...) - } else { - dagExecutionConfig.ComponentDependency[componentName] = append(dagExecutionConfig.ComponentDependency[componentName], FEATURE_INITIALIZER) - } - } + // Add internal component dependencies (RTP, SEEN Score, etc.) 
using internal component builder + // This is a no-op for open-source builds since internal components will be empty + InternalComponentBuilderInstance.AddInternalDependenciesToDAG(rtpComponents, seenScoreComponents, featureComponents, dagExecutionConfig) for _, component := range predatorComponents { componentName := component.Component @@ -1438,50 +1162,6 @@ func findSpecificFeatureDependencies(featureComp FeatureComponent, featureCompon return dependencies } -func findSpecificRTPDependencies(rtpComp RTPComponent, rtpComponents []RTPComponent, featureComponents []FeatureComponent) []string { - var dependencies []string - completedComponents := make(map[string]bool) - requiredInputs := make(map[string]struct{}) - for _, key := range rtpComp.FSKeys { - requiredInputs[key.Col] = struct{}{} - } - - for _, otherComp := range rtpComponents { - if done, ok := completedComponents[otherComp.Component]; ok && done { - continue - } - colNamePrefix := otherComp.ColNamePrefix - for _, featureGroup := range otherComp.FeatureRequest.FeatureGroups { - for _, feature := range featureGroup.Features { - featureKey := colNamePrefix + otherComp.FeatureRequest.Label + COLON_DELIMITER + featureGroup.Label + COLON_DELIMITER + feature - if _, required := requiredInputs[featureKey]; required && !completedComponents[otherComp.Component] { - dependencies = append(dependencies, otherComp.Component) - completedComponents[otherComp.Component] = true - break - } - } - } - } - - for _, featureComp := range featureComponents { - if done, ok := completedComponents[featureComp.Component]; ok && done { - continue - } - colNamePrefix := featureComp.ColNamePrefix - for _, featureGroup := range featureComp.FSRequest.FeatureGroups { - for _, feature := range featureGroup.Features { - featureKey := colNamePrefix + featureComp.FSRequest.Label + COLON_DELIMITER + featureGroup.Label + COLON_DELIMITER + feature - if _, required := requiredInputs[featureKey]; required && 
!completedComponents[featureComp.Component] { - dependencies = append(dependencies, featureComp.Component) - completedComponents[featureComp.Component] = true - break - } - } - } - } - return dependencies -} - func findSpecificPredatorDependencies(predatorComp PredatorComponent, featureComponents []FeatureComponent, rtpComponents []RTPComponent, NumerixComponents []NumerixComponent) []string { var dependencies []string diff --git a/horizon/internal/inferflow/handler/inferflow.go b/horizon/internal/inferflow/handler/inferflow.go index 8df7a05b..063b8b56 100644 --- a/horizon/internal/inferflow/handler/inferflow.go +++ b/horizon/internal/inferflow/handler/inferflow.go @@ -675,7 +675,7 @@ func (m *InferFlow) handleApprovedRequest(request ReviewRequest) (Response, erro return Response{ Error: emptyResponse, - Data: Message{Message: "Mp Config reviewed successfully."}, + Data: Message{Message: "Inferflow Config reviewed successfully."}, }, nil } diff --git a/horizon/internal/inferflow/handler/internal_components.go b/horizon/internal/inferflow/handler/internal_components.go new file mode 100644 index 00000000..c084476d --- /dev/null +++ b/horizon/internal/inferflow/handler/internal_components.go @@ -0,0 +1,93 @@ +package handler + +import ( + etcd "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" + etcdModel "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" + dbModel "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" + mapset "github.com/deckarep/golang-set/v2" +) + +// InternalComponentBuilder defines the interface for building internal-only components. +// This interface abstracts all internal feature processing (RTP, SEEN Score, etc.) +// that are only available in internal builds (meesho build tag). +// +// For open-source builds (!meesho), a stub implementation returns empty/pass-through results. +// For internal builds (meesho), the full implementation provides actual functionality. 
+// +// The external code has NO knowledge of RTP, SEEN Score, or any other internal features. +// All such logic is encapsulated within the internal implementation. +type InternalComponentBuilder interface { + // IsEnabled returns true if internal components are available in this build + IsEnabled() bool + + // ProcessFeatures processes the initial feature set and returns additional internal features. + // It classifies features that are internal-only (RTP, SEEN Score, etc.) and returns: + // - internalFeatures: features that should be handled by internal components + // - featureToDataType: data type mappings for internal features + // The external code will exclude these features from standard processing. + ProcessFeatures( + initialFeatures mapset.Set[string], + featureDataTypes map[string]string, + ) (internalFeatures mapset.Set[string], featureToDataType map[string]string, err error) + + // ClassifyFeature checks if a feature is internal-only. + // Returns the transformed feature name and true if it's an internal feature type. + // Returns empty string and false if it's not an internal feature. + ClassifyFeature(feature string) (transformedFeature string, isInternal bool) + + // GetInternalComponents builds all internal components (RTP, SEEN Score, etc.) + // Returns the components to be added to the config. + GetInternalComponents( + request InferflowOnboardRequest, + internalFeatures mapset.Set[string], + etcdConfig etcd.Manager, + token string, + ) (rtpComponents []RTPComponent, seenScoreComponents []SeenScoreComponent, err error) + + // FetchInternalComponentFeatures fetches features from internal component definitions + // and classifies them. Returns features that should be added to the main feature set + // and features that should remain as internal-only. 
+ FetchInternalComponentFeatures( + internalFeatures mapset.Set[string], + etcdConfig etcd.Manager, + ) (fsFeatures mapset.Set[string], newInternalFeatures mapset.Set[string], featureToDataType map[string]string, err error) + + // FetchMissingInternalDataTypes fetches data types for internal features that are missing them + FetchMissingInternalDataTypes( + featureToDataType map[string]string, + internalFeatures mapset.Set[string], + ) error + + // AddInternalDependenciesToDAG adds dependencies for all internal components to the DAG + AddInternalDependenciesToDAG( + rtpComponents []RTPComponent, + seenScoreComponents []SeenScoreComponent, + featureComponents []FeatureComponent, + dagConfig *DagExecutionConfig, + ) + + // ============= Adaptor Methods ============= + + // AdaptToDBRTPComponent adapts RTP components to DB model format + AdaptToDBRTPComponent(inferflowConfig InferflowConfig) []dbModel.RTPComponent + + // AdaptToDBSeenScoreComponent adapts SeenScore components to DB model format + AdaptToDBSeenScoreComponent(inferflowConfig InferflowConfig) []dbModel.SeenScoreComponent + + // AdaptFromDbToRTPComponent adapts DB model RTP components to handler format + AdaptFromDbToRTPComponent(dbRTPComponents []dbModel.RTPComponent) []RTPComponent + + // AdaptFromDbToSeenScoreComponent adapts DB model SeenScore components to handler format + AdaptFromDbToSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []SeenScoreComponent + + // AdaptToEtcdRTPComponent adapts DB model RTP components to etcd model format + AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcdModel.RTPComponent + + // AdaptToEtcdSeenScoreComponent adapts DB model SeenScore components to etcd model format + AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcdModel.SeenScoreComponent +} + +// InternalComponentBuilderInstance is the global instance of the internal component builder. 
+// This is set by the init() function in either the stub or internal implementation file +// depending on build tags. +var InternalComponentBuilderInstance InternalComponentBuilder From c3762fd304b42eaa931ff3fc83cc5195e5e70be9 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Mon, 2 Feb 2026 18:05:52 +0530 Subject: [PATCH 03/24] revert pre-commit config --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e1fccdbf..c721100c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,6 @@ repos: - id: trufflehog name: TruffleHog description: Detect secrets in your data. - entry: "pre-commit-scripts/runner.sh" + entry: "trufflehog/trufflehog-hook.sh" language: script stages: ["pre-commit", "pre-push"] From bb796ed343dfd40be73025627ca1927bce4e2ac1 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Mon, 2 Feb 2026 18:12:55 +0530 Subject: [PATCH 04/24] removed redundant structs --- .../component_builder_internal_stub.go | 9 +++--- .../inferflow/handler/config_builder.go | 30 ------------------- 2 files changed, 4 insertions(+), 35 deletions(-) diff --git a/horizon/internal/inferflow/handler/component_builder_internal_stub.go b/horizon/internal/inferflow/handler/component_builder_internal_stub.go index 0c656139..8a2053c9 100644 --- a/horizon/internal/inferflow/handler/component_builder_internal_stub.go +++ b/horizon/internal/inferflow/handler/component_builder_internal_stub.go @@ -4,7 +4,6 @@ package handler import ( etcd "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" - etcdModel "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" dbModel "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" mapset "github.com/deckarep/golang-set/v2" ) @@ -95,11 +94,11 @@ func (s *internalComponentBuilderStub) AdaptFromDbToSeenScoreComponent(dbSeenSco } // AdaptToEtcdRTPComponent returns empty slice - no 
RTP components in open-source builds -func (s *internalComponentBuilderStub) AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcdModel.RTPComponent { - return []etcdModel.RTPComponent{} +func (s *internalComponentBuilderStub) AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcd.RTPComponent { + return []etcd.RTPComponent{} } // AdaptToEtcdSeenScoreComponent returns empty slice - no SeenScore components in open-source builds -func (s *internalComponentBuilderStub) AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcdModel.SeenScoreComponent { - return []etcdModel.SeenScoreComponent{} +func (s *internalComponentBuilderStub) AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcd.SeenScoreComponent { + return []etcd.SeenScoreComponent{} } diff --git a/horizon/internal/inferflow/handler/config_builder.go b/horizon/internal/inferflow/handler/config_builder.go index c3cf507c..83b9c875 100644 --- a/horizon/internal/inferflow/handler/config_builder.go +++ b/horizon/internal/inferflow/handler/config_builder.go @@ -36,36 +36,6 @@ const ( FEATURE_INITIALIZER = "feature_initializer" ) -type FeatureLists struct { - allFeatureList mapset.Set[string] - - pcvrCalibrationFeatures, pctrCalibrationFeatures mapset.Set[string] - - featureToDataType, predatorAndIrisOutputsToDataType, offlineToOnlineMapping map[string]string -} - -type ClassifiedFeatures struct { - OfflineFeatures mapset.Set[string] - - OnlineFeatures mapset.Set[string] - - DefaultFeatures mapset.Set[string] - - PCTRCalibrationFeatures mapset.Set[string] - - PCVRCalibrationFeatures mapset.Set[string] - - FeatureToDataType map[string]string -} - -type AllComponents struct { - FeatureComponents []FeatureComponent - - IrisComponents []NumerixComponent - - PredatorComponents []PredatorComponent -} - func (m *InferFlow) GetInferflowConfig(request InferflowOnboardRequest, token string) (InferflowConfig, error) { entityIDs := 
extractEntityIDs(request) From f3806230ccf879b6e4e1549d44daac17bc300203 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 09:49:46 +0530 Subject: [PATCH 05/24] error formatting fixes --- horizon/internal/inferflow/handler/inferflow.go | 8 ++++---- horizon/internal/inferflow/handler/internal_components.go | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/horizon/internal/inferflow/handler/inferflow.go b/horizon/internal/inferflow/handler/inferflow.go index 063b8b56..55275115 100644 --- a/horizon/internal/inferflow/handler/inferflow.go +++ b/horizon/internal/inferflow/handler/inferflow.go @@ -212,7 +212,7 @@ func (m *InferFlow) Promote(request PromoteConfigRequest) (Response, error) { newVersion = latestRequests[0].Version + 1 } if newVersion > maxConfigVersion { - return Response{}, errors.New("This inferflow config has reached its version limit. Please create a clone to make further updates.") + return Response{}, errors.New("this inferflow config has reached its version limit. Please create a clone to make further updates") } request.Payload.ConfigValue.ComponentConfig.CacheVersion = newVersion } else { @@ -290,7 +290,7 @@ func (m *InferFlow) Edit(request EditConfigOrCloneConfigRequest, token string) ( } if newVersion > maxConfigVersion { - return Response{}, errors.New("This inferflow config has reached its version limit. Please create a clone to make further updates.") + return Response{}, errors.New("this inferflow config has reached its version limit. 
Please create a clone to make further updates") } onboardRequest := InferflowOnboardRequest(request) @@ -732,7 +732,7 @@ func (m *InferFlow) rollbackPromoteRequest(tx *gorm.DB, currentRequest *inferflo func (m *InferFlow) rollbackEditRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int) error { approvedRequests, err := m.InferFlowRequestRepo.GetApprovedRequestsByConfigID(currentRequest.ConfigID) if err != nil { - return fmt.Errorf("Failed to retrieve approved requests: %w", err) + return fmt.Errorf("failed to retrieve approved requests: %w", err) } var previousRequest *inferflow_request.Table @@ -1319,7 +1319,7 @@ func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) (Response, er return Response{ Error: "Failed to fetch deployable config for the request", Data: Message{Message: emptyResponse}, - }, errors.New("Failed to fetch deployable config for the request") + }, errors.New("failed to fetch deployable config for the request") } permissibleEndpoints := m.EtcdConfig.GetConfiguredEndpoints(deployableConfig.Name) for _, ranker := range request.Rankers { diff --git a/horizon/internal/inferflow/handler/internal_components.go b/horizon/internal/inferflow/handler/internal_components.go index c084476d..6ca64f5b 100644 --- a/horizon/internal/inferflow/handler/internal_components.go +++ b/horizon/internal/inferflow/handler/internal_components.go @@ -2,7 +2,6 @@ package handler import ( etcd "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" - etcdModel "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" dbModel "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" mapset "github.com/deckarep/golang-set/v2" ) @@ -81,10 +80,10 @@ type InternalComponentBuilder interface { AdaptFromDbToSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []SeenScoreComponent // AdaptToEtcdRTPComponent adapts DB model RTP components to etcd model format - 
AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcdModel.RTPComponent + AdaptToEtcdRTPComponent(dbRTPComponents []dbModel.RTPComponent) []etcd.RTPComponent // AdaptToEtcdSeenScoreComponent adapts DB model SeenScore components to etcd model format - AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcdModel.SeenScoreComponent + AdaptToEtcdSeenScoreComponent(dbSeenScoreComponents []dbModel.SeenScoreComponent) []etcd.SeenScoreComponent } // InternalComponentBuilderInstance is the global instance of the internal component builder. From c7ee96760980fb9ca7013b0ab4afae94ff3f5149 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 11:25:54 +0530 Subject: [PATCH 06/24] predator sync --- horizon/internal/configs/app_config.go | 7 +- horizon/internal/externalcall/gcs_client.go | 438 ++++++++++----- horizon/internal/predator/handler/model.go | 2 + horizon/internal/predator/handler/predator.go | 520 ++++++++++++++++-- .../repositories/sql/predatorconfig/sql.go | 18 + .../repositories/sql/predatorconfig/table.go | 1 + 6 files changed, 779 insertions(+), 207 deletions(-) diff --git a/horizon/internal/configs/app_config.go b/horizon/internal/configs/app_config.go index c561d0fd..34f20755 100644 --- a/horizon/internal/configs/app_config.go +++ b/horizon/internal/configs/app_config.go @@ -55,9 +55,10 @@ type Configs struct { DefaultGpuThreshold string `mapstructure:"default_gpu_threshold"` DefaultModelPath string `mapstructure:"default_model_path"` - GcsModelBucket string `mapstructure:"gcs_model_bucket"` - GcsModelBasePath string `mapstructure:"gcs_model_base_path"` - GcsEnabled bool `mapstructure:"gcs_enabled"` + GcsModelBucket string `mapstructure:"gcs_model_bucket"` + GcsModelBasePath string `mapstructure:"gcs_model_base_path"` + GcsConfigBasePath string `mapstructure:"gcs_config_base_path"` + GcsConfigBucket string `mapstructure:"gcs_config_bucket"` GrafanaBaseUrl string `mapstructure:"grafana_base_url"` 
diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go
index 7fe8bc5a..47ac4948 100644
--- a/horizon/internal/externalcall/gcs_client.go
+++ b/horizon/internal/externalcall/gcs_client.go
@@ -3,11 +3,13 @@ package externalcall
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"os"
 	"path"
 	"path/filepath"
+	"regexp"
 	"strings"
 	"sync"
 	"time"
@@ -21,6 +23,7 @@ type GCSClientInterface interface {
 	ReadFile(bucket, objectPath string) ([]byte, error)
 	TransferFolder(srcBucket, srcPath, srcModelName, destBucket, destPath, destModelName string) error
 	TransferAndDeleteFolder(srcBucket, srcPath, srcModelName, destBucket, destPath, destModelName string) error
+	TransferFolderWithSplitSources(modelBucket, modelPath, configBucket, configPath, srcModelName, destBucket, destPath, destModelName string) error
 	DeleteFolder(bucket, modelPath, modelName string) error
 	ListFolders(bucket, prefix string) ([]string, error)
 	UploadFile(bucket, objectPath string, data []byte) error
@@ -52,14 +55,7 @@ type GCSClient struct {
 	ctx context.Context
 }
 
-func CreateGCSClient(isGcsEnabled bool) GCSClientInterface {
-	if !isGcsEnabled {
-		log.Warn().Msg("GCS client is disabled")
-		return &GCSClient{
-			client: nil,
-			ctx:    nil,
-		}
-	}
+func CreateGCSClient() GCSClientInterface {
 	ctx := context.Background()
 
 	// Check for Application Default Credentials path
@@ -139,30 +135,17 @@ func (g *GCSClient) TransferFolder(srcBucket, srcPath, srcModelName, destBucket,
 	var regularFiles []storage.ObjectAttrs
 	var configFiles []storage.ObjectAttrs
 
-	it := g.client.Bucket(srcBucket).Objects(g.ctx, &storage.Query{Prefix: prefix})
-	for {
-		objAttrs, err := it.Next()
-		if err == iterator.Done {
-			break
-		}
-		if err != nil {
-			return fmt.Errorf("failed to list source bucket: %w", err)
-		}
-
-		if strings.HasSuffix(objAttrs.Name, "/") {
-			continue
-		}
+	isConfigFile := func(attrs *storage.ObjectAttrs) bool {
+		return strings.HasSuffix(attrs.Name, "config.pbtxt")
+	}
 
-		if 
strings.HasSuffix(objAttrs.Name, "config.pbtxt") { - configFiles = append(configFiles, *objAttrs) - } else { - regularFiles = append(regularFiles, *objAttrs) - } + configFiles, regularFiles, err := g.partitionObjects(srcBucket, prefix, isConfigFile) + if err != nil { + return fmt.Errorf("failed to list source bucket: %w", err) } if len(regularFiles) == 0 && len(configFiles) == 0 { - log.Info().Msg("No objects found to transfer") - return nil + return fmt.Errorf("no files found at source location: gs://%s/%s", srcBucket, prefix) } log.Info().Msgf("Starting two-phase transfer: %d regular files, %d config files", len(regularFiles), len(configFiles)) @@ -291,6 +274,115 @@ func (g *GCSClient) transferSingleConfigFile(objAttrs storage.ObjectAttrs, srcBu return nil } +func (g *GCSClient) TransferFolderWithSplitSources(modelBucket, modelPath, configBucket, configPath, srcModelName, destBucket, destPath, destModelName string) error { + modelPrefix := path.Join(modelPath, srcModelName) + if !strings.HasSuffix(modelPrefix, "/") { + modelPrefix += "/" + } + + regularFiles, err := g.listObjects(modelBucket, modelPrefix, func(attrs *storage.ObjectAttrs) bool { + return !strings.HasSuffix(attrs.Name, "config.pbtxt") + }) + if err != nil { + return fmt.Errorf("failed to read regular files from model source: %w", err) + } + + log.Info().Msgf("TransferFolderWithSplitSources: Found %d regular files in model source gs://%s/%s", + len(regularFiles), modelBucket, modelPrefix) + + configPrefix := path.Join(configPath, srcModelName) + if !strings.HasSuffix(configPrefix, "/") { + configPrefix += "/" + } + + configFiles, err := g.listObjects(configBucket, configPrefix, func(attrs *storage.ObjectAttrs) bool { + return strings.HasSuffix(attrs.Name, "config.pbtxt") + }) + if err != nil { + return fmt.Errorf("failed to read config files from config source: %w", err) + } + + log.Info().Msgf("TransferFolderWithSplitSources: Found %d config files in config source gs://%s/%s", + len(configFiles), 
configBucket, configPrefix) + + if len(regularFiles) == 0 && len(configFiles) == 0 { + log.Warn().Msg("TransferFolderWithSplitSources: No objects found to transfer") + return nil + } + + regularFilesTransferred := false + if len(regularFiles) > 0 { + if err := g.transferRegularFilesFromSource(regularFiles, modelBucket, destBucket, destPath, destModelName, modelPrefix); err != nil { + return fmt.Errorf("failed to transfer regular files from model source: %w", err) + } + regularFilesTransferred = true + } + + if len(configFiles) > 0 { + if err := g.transferConfigFilesFromSource(configFiles, configBucket, destBucket, destPath, destModelName, configPrefix); err != nil { + if regularFilesTransferred { + log.Warn().Err(err).Msgf("Config file transfer failed, reverting transfer by deleting destination folder gs://%s/%s/%s", + destBucket, destPath, destModelName) + if revertErr := g.DeleteFolder(destBucket, destPath, destModelName); revertErr != nil { + log.Error().Err(revertErr).Msgf("Failed to revert transfer by deleting destination folder gs://%s/%s/%s", + destBucket, destPath, destModelName) + return fmt.Errorf("failed to transfer config files from config source: %w; revert also failed: %w", err, revertErr) + } + log.Info().Msgf("Successfully reverted transfer by deleting destination folder") + } + return fmt.Errorf("failed to transfer config files from config source: %w", err) + } + } + + log.Info().Msgf("TransferFolderWithSplitSources: Successfully completed split-source transfer for model %s -> %s", + srcModelName, destModelName) + return nil +} + +func (g *GCSClient) transferRegularFilesFromSource(files []storage.ObjectAttrs, srcBucket, destBucket, destPath, destModelName, prefix string) error { + log.Info().Msgf("Transferring %d regular files from model source", len(files)) + + semaphore := make(chan struct{}, maxConcurrentFiles) + var wg sync.WaitGroup + var mu sync.Mutex + var transferErrors []error + + for _, objAttrs := range files { + wg.Add(1) + go func(obj 
storage.ObjectAttrs) { + defer wg.Done() + semaphore <- struct{}{} + defer func() { <-semaphore }() + + if err := g.transferSingleRegularFile(obj, srcBucket, destBucket, destPath, destModelName, prefix); err != nil { + mu.Lock() + transferErrors = append(transferErrors, fmt.Errorf("failed to transfer %s: %w", obj.Name, err)) + mu.Unlock() + } + }(objAttrs) + } + + wg.Wait() + + if len(transferErrors) > 0 { + return fmt.Errorf("regular file transfer completed with %d errors: %v", len(transferErrors), transferErrors[0]) + } + + return nil +} + +func (g *GCSClient) transferConfigFilesFromSource(files []storage.ObjectAttrs, srcBucket, destBucket, destPath, destModelName, prefix string) error { + log.Info().Msgf("Transferring %d config files from config source", len(files)) + + for _, objAttrs := range files { + if err := g.transferSingleConfigFile(objAttrs, srcBucket, destBucket, destPath, destModelName, prefix); err != nil { + return fmt.Errorf("failed to transfer config file %s: %w", objAttrs.Name, err) + } + } + + return nil +} + func (g *GCSClient) DeleteFolder(bucket, modelPath, modelName string) error { // Ensure the prefix ends with "/" to avoid matching partial directory names prefix := path.Join(modelPath, modelName) @@ -353,19 +445,50 @@ func (g *GCSClient) TransferAndDeleteFolder(srcBucket, srcPath, srcModelName, de return nil } -// replaceModelNameInConfig modifies the `name:` field in config.pbtxt content +// replaceModelNameInConfig modifies only the top-level `name:` field in config.pbtxt content +// It replaces only the first occurrence to avoid modifying nested names in inputs/outputs/instance_groups func replaceModelNameInConfig(data []byte, destModelName string) []byte { - lines := strings.Split(string(data), "\n") - originalName := "" + content := string(data) + lines := strings.Split(content, "\n") + for i, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "name:") { - originalName = line - lines[i] = 
fmt.Sprintf(`name: "%s"`, destModelName) - log.Info().Msgf("Replacing model name in config.pbtxt: '%s' -> 'name: \"%s\"'", originalName, destModelName) - break + trimmed := strings.TrimSpace(line) + // Match top-level "name:" field - should be at the start of line (or minimal indentation) + // Skip nested names which are typically indented with 2+ spaces + if strings.HasPrefix(trimmed, "name:") { + // Check indentation: top-level fields have minimal/no indentation + leadingWhitespace := len(line) - len(strings.TrimLeft(line, " \t")) + // Skip if heavily indented (nested field) + if leadingWhitespace >= 2 { + continue + } + + // Match the first occurrence of name: "value" pattern + namePattern := regexp.MustCompile(`name\s*:\s*"([^"]+)"`) + matches := namePattern.FindStringSubmatch(line) + if len(matches) > 1 { + oldModelName := matches[1] + // Replace only the FIRST occurrence to avoid replacing nested names + loc := namePattern.FindStringIndex(line) + if loc != nil { + // Replace only the matched portion (first occurrence) + before := line[:loc[0]] + matched := line[loc[0]:loc[1]] + after := line[loc[1]:] + // Replace the value inside quotes while preserving the "name:" format + valuePattern := regexp.MustCompile(`"([^"]+)"`) + valueReplaced := valuePattern.ReplaceAllString(matched, fmt.Sprintf(`"%s"`, destModelName)) + lines[i] = before + valueReplaced + after + } else { + // Fallback: replace all (shouldn't happen with valid input) + lines[i] = namePattern.ReplaceAllString(line, fmt.Sprintf(`name: "%s"`, destModelName)) + } + log.Info().Msgf("Replacing top-level model name in config.pbtxt: '%s' -> '%s'", oldModelName, destModelName) + break + } } } + return []byte(strings.Join(lines, "\n")) } @@ -384,32 +507,23 @@ func (g *GCSClient) ListFolders(bucket, prefix string) ([]string, error) { log.Info().Msgf("Listing folders in GCS bucket %s with prefix %s", bucket, prefix) - it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{ - Prefix: prefix, - // Do NOT 
set Delimiter here - }) - - for { - attrs, err := it.Next() - if err == iterator.Done { - break + err := g.forEachObject(bucket, prefix, func(attrs *storage.ObjectAttrs) error { + if attrs.Name == "" { + return nil } - if err != nil { - return nil, fmt.Errorf("failed to list objects: %w", err) - } - - // Extract folder name after the prefix - if attrs.Name != "" { - trimmed := strings.TrimPrefix(attrs.Name, prefix) - parts := strings.SplitN(trimmed, "/", 2) - if len(parts) > 1 { - folderName := parts[0] - if !seenFolders[folderName] { - folders = append(folders, folderName) - seenFolders[folderName] = true - } + trimmed := strings.TrimPrefix(attrs.Name, prefix) + parts := strings.SplitN(trimmed, "/", 2) + if len(parts) > 1 { + folderName := parts[0] + if !seenFolders[folderName] { + folders = append(folders, folderName) + seenFolders[folderName] = true } } + return nil + }) + if err != nil { + return nil, err } return folders, nil @@ -458,20 +572,15 @@ func (g *GCSClient) CheckFolderExists(bucket, folderPrefix string) (bool, error) folderPrefix += "/" } - it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{ - Prefix: folderPrefix, + var exists bool + err := g.forEachObject(bucket, folderPrefix, func(attrs *storage.ObjectAttrs) error { + exists = true + return ErrStopIteration // Found one, stop iteration }) - - // Check if at least one object exists with this prefix - _, err := it.Next() - if err == iterator.Done { - return false, nil // No objects found with this prefix - } if err != nil { return false, fmt.Errorf("failed to check folder existence: %w", err) } - - return true, nil + return exists, nil } func (g *GCSClient) UploadFolderFromLocal(srcFolderPath, bucket, destPath string) error { @@ -526,38 +635,26 @@ func (g *GCSClient) GetFolderInfo(bucket, folderPrefix string) (*GCSFolderInfo, folderPrefix += "/" } - it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{ - Prefix: folderPrefix, - }) - var folderInfo GCSFolderInfo folderInfo.Name = 
strings.TrimSuffix(path.Base(folderPrefix), "/") folderInfo.Path = fmt.Sprintf("gs://%s/%s", bucket, strings.TrimSuffix(folderPrefix, "/")) - folderInfo.Created = time.Now() // Will be updated with actual earliest file - folderInfo.Updated = time.Time{} // Will be updated with actual latest file + folderInfo.Created = time.Now() + folderInfo.Updated = time.Time{} - for { - attrs, err := it.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("failed to list objects: %w", err) - } - - // Update folder stats + err := g.forEachObject(bucket, folderPrefix, func(attrs *storage.ObjectAttrs) error { folderInfo.FileCount++ folderInfo.Size += attrs.Size - // Track earliest creation time if attrs.Created.Before(folderInfo.Created) { folderInfo.Created = attrs.Created } - - // Track latest update time if attrs.Updated.After(folderInfo.Updated) { folderInfo.Updated = attrs.Updated } + return nil + }) + if err != nil { + return nil, err } if folderInfo.FileCount == 0 { @@ -581,54 +678,39 @@ func (g *GCSClient) ListFoldersWithTimestamp(bucket, prefix string) ([]GCSFolder } log.Info().Msgf("Listing folders with timestamps in GCS bucket %s with prefix %s", bucket, prefix) - - it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{ - Prefix: prefix, - }) - - for { - attrs, err := it.Next() - if err == iterator.Done { - break + err := g.forEachObject(bucket, prefix, func(attrs *storage.ObjectAttrs) error { + if attrs.Name == "" { + return nil } - if err != nil { - return nil, fmt.Errorf("failed to list objects: %w", err) - } - - // Extract folder name after the prefix - if attrs.Name != "" { - trimmed := strings.TrimPrefix(attrs.Name, prefix) - parts := strings.SplitN(trimmed, "/", 2) - if len(parts) > 1 { - folderName := parts[0] - - // Initialize or update folder info - if folderInfo, exists := seenFolders[folderName]; !exists { - seenFolders[folderName] = &GCSFolderInfo{ - Name: folderName, - Path: fmt.Sprintf("gs://%s/%s%s", bucket, 
prefix, folderName), - Created: attrs.Created, - Updated: attrs.Updated, - Size: attrs.Size, - FileCount: 1, - } - } else { - // Update existing folder info - folderInfo.FileCount++ - folderInfo.Size += attrs.Size - - // Track earliest creation time - if attrs.Created.Before(folderInfo.Created) { - folderInfo.Created = attrs.Created - } - - // Track latest update time - if attrs.Updated.After(folderInfo.Updated) { - folderInfo.Updated = attrs.Updated - } + trimmed := strings.TrimPrefix(attrs.Name, prefix) + parts := strings.SplitN(trimmed, "/", 2) + if len(parts) > 1 { + folderName := parts[0] + + if folderInfo, exists := seenFolders[folderName]; !exists { + seenFolders[folderName] = &GCSFolderInfo{ + Name: folderName, + Path: fmt.Sprintf("gs://%s/%s%s", bucket, prefix, folderName), + Created: attrs.Created, + Updated: attrs.Updated, + Size: attrs.Size, + FileCount: 1, + } + } else { + folderInfo.FileCount++ + folderInfo.Size += attrs.Size + if attrs.Created.Before(folderInfo.Created) { + folderInfo.Created = attrs.Created + } + if attrs.Updated.After(folderInfo.Updated) { + folderInfo.Updated = attrs.Updated } } } + return nil + }) + if err != nil { + return nil, err } // Convert map to slice @@ -653,29 +735,95 @@ func (g *GCSClient) FindFileWithSuffix(bucket, folderPath, suffix string) (bool, log.Info().Msgf("Searching for files with suffix '%s' in GCS bucket %s with prefix %s", suffix, bucket, folderPath) - it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{ - Prefix: folderPath, + var foundFile string + err := g.forEachObject(bucket, folderPath, func(attrs *storage.ObjectAttrs) error { + fileName := path.Base(attrs.Name) + if strings.HasSuffix(fileName, suffix) { + log.Info().Msgf("Found file with suffix '%s': %s", suffix, fileName) + foundFile = fileName + return ErrStopIteration + } + return nil }) + if err != nil { + return false, "", fmt.Errorf("failed to list objects: %w", err) + } + return foundFile != "", foundFile, nil + log.Info().Msgf("No 
file found with suffix '%s' in %s/%s", suffix, bucket, folderPath) + return false, "", nil +} + +// ObjectVisitor is called for each object. Return an error to stop iteration. +// Return a special sentinel error like ErrStopIteration to stop without error. +type ObjectVisitor func(attrs *storage.ObjectAttrs) error + +var ErrStopIteration = errors.New("stop iteration") + +// forEachObject iterates over all objects with the given prefix and calls the visitor for each. +func (g *GCSClient) forEachObject(bucket, prefix string, visitor ObjectVisitor) error { + it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{Prefix: prefix}) for { - attrs, err := it.Next() + objAttrs, err := it.Next() if err == iterator.Done { break } if err != nil { - return false, "", fmt.Errorf("failed to list objects: %w", err) + return fmt.Errorf("failed to list objects: %w", err) } - // Get the filename from the full object path - fileName := path.Base(attrs.Name) + if err := visitor(objAttrs); err != nil { + if errors.Is(err, ErrStopIteration) { + return nil + } + return err + } + } + return nil +} - // Check if the file ends with the specified suffix - if strings.HasSuffix(fileName, suffix) { - log.Info().Msgf("Found file with suffix '%s': %s", suffix, fileName) - return true, fileName, nil +// ObjectFilter returns true if the object should be included. +type ObjectFilter func(attrs *storage.ObjectAttrs) bool + +// listObjects returns all objects matching the prefix, optionally filtered. +// Pass nil for filter to include all objects (except directory markers). 
+func (g *GCSClient) listObjects(bucket, prefix string, filter ObjectFilter) ([]storage.ObjectAttrs, error) { + var objects []storage.ObjectAttrs + + err := g.forEachObject(bucket, prefix, func(attrs *storage.ObjectAttrs) error { + // Skip directory markers by default + if strings.HasSuffix(attrs.Name, "/") { + return nil } + + if filter == nil || filter(attrs) { + objects = append(objects, *attrs) + } + return nil + }) + + if err != nil { + return nil, err } + return objects, nil +} - log.Info().Msgf("No file found with suffix '%s' in %s/%s", suffix, bucket, folderPath) - return false, "", nil +// partitionObjects separates objects into two groups based on a predicate. +// Objects matching the predicate go into the first slice, others into the second. +func (g *GCSClient) partitionObjects(bucket, prefix string, predicate ObjectFilter) (matching, notMatching []storage.ObjectAttrs, err error) { + err = g.forEachObject(bucket, prefix, func(attrs *storage.ObjectAttrs) error { + // Skip directory markers + if strings.HasSuffix(attrs.Name, "/") { + return nil + } + + if predicate(attrs) { + matching = append(matching, *attrs) + } else { + notMatching = append(notMatching, *attrs) + } + return nil + }) + + return matching, notMatching, err } diff --git a/horizon/internal/predator/handler/model.go b/horizon/internal/predator/handler/model.go index 6e618373..c05cb329 100644 --- a/horizon/internal/predator/handler/model.go +++ b/horizon/internal/predator/handler/model.go @@ -45,6 +45,7 @@ type IOField struct { type ConfigMapping struct { ServiceDeployableID uint `json:"service_deployable_id"` + SourceModelName string `json:"source_model_name,omitempty"` } type FetchModelConfigRequest struct { @@ -99,6 +100,7 @@ type ModelResponse struct { DeployableRunningStatus string `json:"deployable_running_status"` TestResults json.RawMessage `json:"test_results"` HasNilData bool `json:"has_nil_data"` + SourceModelName string `json:"source_model_name,omitempty"` } type 
PredatorRequestResponse struct { diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index 97a9ca3f..a4af4a3e 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -9,6 +9,7 @@ import ( "io" "math" "os" + "regexp" "sync" "github.com/Meesho/BharatMLStack/horizon/internal/constant" @@ -163,6 +164,9 @@ const ( failedToParseServiceConfig = "Failed to parse service config" failedToCreateServiceDiscoveryAndConfig = "Failed to create service discovery and config" predatorInferMethod = "inference.GRPCInferenceService/ModelInfer" + deployableTagDelimiter = "_" + scaleupTag = "scaleup" + ) func InitV1ConfigHandler() (Config, error) { @@ -237,7 +241,7 @@ func InitV1ConfigHandler() (Config, error) { workingEnv := viper.GetString("WORKING_ENV") predator = &Predator{ - GcsClient: externalcall.CreateGCSClient(pred.IsGcsEnabled), + GcsClient: externalcall.CreateGCSClient(), ServiceDeployableRepo: serviceDeployableRepo, Repo: repo, PredatorConfigRepo: predatorConfigRepo, @@ -272,7 +276,32 @@ func (p *Predator) HandleModelRequest(req ModelRequest, requestType string) (str modelNameList = append(modelNameList, modelName) } - exist, err := p.Repo.ActiveModelRequestExistForRequestType(modelNameList, requestType) + var payloadObjects []Payload + derivedModelNames := make([]string, len(modelNameList)) + + for i, payload := range req.Payload { + payloadBytes, err := json.Marshal(payload) + if err != nil { + return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf(errMsgProcessPayload) + } + + var payloadObject Payload + if err := json.Unmarshal(payloadBytes, &payloadObject); err != nil { + return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf(errMsgProcessPayload) + } + derivedModelName, err := p.GetDerivedModelName(payloadObject, requestType) + if err != nil { + return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf("failed 
to fetch derived model name: %w", err) + } + if requestType == ScaleUpRequestType { + payloadObject.ConfigMapping.SourceModelName = payloadObject.ModelName + } + payloadObject.ModelName = derivedModelName + derivedModelNames[i] = derivedModelName + payloadObjects = append(payloadObjects, payloadObject) + } + + exist, err := p.Repo.ActiveModelRequestExistForRequestType(derivedModelNames, requestType) if err != nil { return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf("failed to check existing models: %w", err) } @@ -280,7 +309,7 @@ func (p *Predator) HandleModelRequest(req ModelRequest, requestType string) (str return constant.EmptyString, http.StatusConflict, fmt.Errorf("active model request already exists for one or more requested models") } - predatorConfigList, err := p.PredatorConfigRepo.GetActiveModelByModelNameList(modelNameList) + predatorConfigList, err := p.PredatorConfigRepo.GetActiveModelByModelNameList(derivedModelNames) if err != nil { log.Error().Err(err).Msg(fmt.Sprintf("failed to fetch predator configs: %v", err)) @@ -296,12 +325,22 @@ func (p *Predator) HandleModelRequest(req ModelRequest, requestType string) (str return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf("failed to get group id: %w", err) } - for _, payload := range req.Payload { - payloadBytes, err := json.Marshal(payload) + for i := range len(req.Payload) { + payloadObject := payloadObjects[i] + payloadBytes, err := json.Marshal(payloadObject) if err != nil { - return constant.EmptyString, http.StatusInternalServerError, errors.New(errMsgProcessPayload) + return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf("failed to marshal payload: %w", err) } - modelName, _ := payload[fieldModelName].(string) + + if payloadObject.ConfigMapping.ServiceDeployableID == 0 { + return constant.EmptyString, http.StatusBadRequest, fmt.Errorf("service deployable id is required") + } + + if requestType == OnboardRequestType && 
payloadObject.MetaData.InstanceCount > 1 { + return constant.EmptyString, http.StatusBadRequest, fmt.Errorf("instance count should be 1 for onboard environment") + } + + modelName := payloadObject.ModelName newRequests = append(newRequests, predatorrequest.PredatorRequest{ ModelName: modelName, GroupId: groupID, @@ -514,18 +553,27 @@ func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsRes return ModelParamsResponse{}, http.StatusBadRequest, err } - bucket, objectPath := parseModelPath(req.ModelPath) - configPath := path.Join(objectPath, configFile) - metaDataPath := path.Join(objectPath, "metadata.json") + intBucket, intObjectPath := parseModelPath(req.ModelPath) + metaDataPath := path.Join(intObjectPath, "metadata.json") + _, modelName := parseModelPath(intObjectPath) + intConfigPath := path.Join(intObjectPath, configFile) + // Read config.pbtxt - configData, err := p.GcsClient.ReadFile(bucket, configPath) + var configData []byte + var err error + if p.isNonProductionEnvironment() { + configData, err = p.GcsClient.ReadFile(intBucket, intConfigPath) + } else { + prodConfigPath := path.Join(pred.GcsConfigBasePath, modelName, configFile) + configData, err = p.GcsClient.ReadFile(pred.GcsConfigBucket, prodConfigPath) + } if err != nil { return ModelParamsResponse{}, http.StatusInternalServerError, fmt.Errorf(errReadConfigFileFormat, err) } // Read feature_meta.json - metaData, err := p.GcsClient.ReadFile(bucket, metaDataPath) + metaData, err := p.GcsClient.ReadFile(intBucket, metaDataPath) var featureMeta *FeatureMetadata if err == nil && metaData != nil { if err := json.Unmarshal(metaData, &featureMeta); err != nil { @@ -561,7 +609,7 @@ func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsRes outputs = []IO{} } - return createModelParamsResponse(&modelConfig, objectPath, inputs, outputs), http.StatusOK, nil + return createModelParamsResponse(&modelConfig, intObjectPath, inputs, outputs), http.StatusOK, nil } func 
validateModelPath(modelPath string) error { @@ -858,6 +906,8 @@ func (p *Predator) processEditGCSCopyStage(requestIdPayloadMap map[uint]*Payload return transferredGcsModelData, nil } + isNotProd := p.isNonProductionEnvironment() + for _, requestModel := range predatorRequestList { payload := requestIdPayloadMap[requestModel.RequestID] if payload == nil { @@ -896,9 +946,22 @@ func (p *Predator) processEditGCSCopyStage(requestIdPayloadMap map[uint]*Payload sourceModelName := pathSegments[len(pathSegments)-1] sourceBasePath := strings.TrimSuffix(sourcePath, "/"+sourceModelName) - if err := p.GcsClient.TransferFolder(sourceBucket, sourceBasePath, sourceModelName, targetBucket, targetPath, modelName); err != nil { - log.Error().Err(err).Msgf("Failed to copy model %s for edit approval", modelName) - return transferredGcsModelData, fmt.Errorf("failed to copy model %s: %w", modelName, err) + if isNotProd { + if err := p.GcsClient.TransferFolder( + sourceBucket, sourceBasePath, sourceModelName, + targetBucket, targetPath, modelName, + ); err != nil { + return transferredGcsModelData, err + } + } else { + configBucket := pred.GcsConfigBucket + configPath := pred.GcsConfigBasePath + if err := p.GcsClient.TransferFolderWithSplitSources( + sourceBucket, sourceBasePath, configBucket, configPath, + sourceModelName, targetBucket, targetPath, modelName, + ); err != nil { + return transferredGcsModelData, err + } } // Track transferred data for potential rollback @@ -1150,6 +1213,7 @@ func (p *Predator) processPayload(predatorRequest predatorrequest.PredatorReques func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) ([]GcsModelData, error) { var transferredGcsModelData []GcsModelData if predatorRequestList[0].RequestStage == predatorStagePending || predatorRequestList[0].RequestStage == predatorStageCloneToBucket { + isNotProd := p.isNonProductionEnvironment() for _, requestModel := 
range predatorRequestList { serviceDeployable, err := p.ServiceDeployableRepo.GetById(int(requestIdPayloadMap[requestModel.RequestID].ConfigMapping.ServiceDeployableID)) @@ -1165,18 +1229,45 @@ func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, p return transferredGcsModelData, err } - srcBucket, srcPath, srcModelName := extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) destBucket, destPath := extractGCSPath(strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*")) + destModelName := requestIdPayloadMap[requestModel.RequestID].ModelName - if deployableConfig.GCSBucketPath != "NA" { - log.Info().Msgf("srcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", srcBucket, srcPath, srcModelName, destBucket, destPath) - if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || srcModelName == constant.EmptyString || destBucket == constant.EmptyString || destPath == constant.EmptyString || requestIdPayloadMap[requestModel.RequestID].ModelName == constant.EmptyString { - log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) - return transferredGcsModelData, errors.New(errModelPathFormat) - } + var srcBucket, srcPath, srcModelName string + + srcBucket = pred.GcsModelBucket + srcPath = pred.GcsModelBasePath + if requestModel.RequestType == ScaleUpRequestType { + srcModelName = destModelName + log.Info().Msgf("Scale-up: Source from model-source gs://%s/%s/%s", + srcBucket, srcPath, srcModelName) + } else { + _, _, srcModelName = extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) + log.Info().Msgf("Onboard/Promote: Source from payload gs://%s/%s/%s", + srcBucket, srcPath, srcModelName) + } + + log.Info().Msgf("Copying to target deployable - src: %s/%s/%s, dest: %s/%s/%s", + srcBucket, srcPath, srcModelName, destBucket, destPath, destModelName) - if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, destBucket, destPath, - 
requestIdPayloadMap[requestModel.RequestID].ModelName); err != nil { + + if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || + srcModelName == constant.EmptyString || destBucket == constant.EmptyString || + destPath == constant.EmptyString || destModelName == constant.EmptyString { + log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) + return transferredGcsModelData, errors.New(errModelPathFormat) + } + + if isNotProd { + if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, + destBucket, destPath, destModelName); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return transferredGcsModelData, err + } + } else { + if err := p.GcsClient.TransferFolderWithSplitSources( + srcBucket, srcPath, pred.GcsConfigBucket, pred.GcsConfigBasePath, + srcModelName, destBucket, destPath, destModelName, + ); err != nil { log.Error().Err(err).Msg(errGCSCopyFailed) return transferredGcsModelData, err } @@ -1187,6 +1278,8 @@ func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, p Path: destPath, Name: requestIdPayloadMap[requestModel.RequestID].ModelName, }) + + log.Info().Msgf("Successfully copied model to target deployable: %s", destModelName) } p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageDBPopulation) } @@ -1195,30 +1288,59 @@ func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, p func (p *Predator) processGCSCloneStageIndefaultFolder(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) ([]GcsModelData, error) { var transferredGcsModelData []GcsModelData - if predatorRequestList[0].RequestStage == predatorStagePending || predatorRequestList[0].RequestStage == predatorStageCloneToBucket { - for _, requestModel := range predatorRequestList { - srcBucket, srcPath, srcModelName := 
extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) - destBucket := pred.GcsModelBucket - destPath := pred.GcsModelBasePath - log.Info().Msgf("srcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", srcBucket, srcPath, srcModelName, destBucket, destPath) - if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || srcModelName == constant.EmptyString || destBucket == constant.EmptyString || destPath == constant.EmptyString || requestIdPayloadMap[requestModel.RequestID].ModelName == constant.EmptyString { - log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) - return transferredGcsModelData, errors.New(errModelPathFormat) - } + if predatorRequestList[0].RequestStage != predatorStagePending && + predatorRequestList[0].RequestStage != predatorStageCloneToBucket { + return transferredGcsModelData, nil + } - if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, pred.GcsModelBucket, pred.GcsModelBasePath, - requestIdPayloadMap[requestModel.RequestID].ModelName); err != nil { - log.Error().Err(err).Msg(errGCSCopyFailed) + isNotProd := p.isNonProductionEnvironment() + + for _, requestModel := range predatorRequestList { + payload := requestIdPayloadMap[requestModel.RequestID] + + destBucket := pred.GcsModelBucket + destPath := pred.GcsModelBasePath + destModelName := payload.ModelName + + _, _, originalModelName := extractGCSDetails(payload.ModelSource) + srcBucket := pred.GcsModelBucket + srcPath := pred.GcsModelBasePath + srcModelName := originalModelName + + log.Info().Msgf("Scale-up: Copying within model-source %s → %s", srcModelName, destModelName) + log.Info().Msgf("srcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", + srcBucket, srcPath, srcModelName, destBucket, destPath) + + if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || + srcModelName == constant.EmptyString || destBucket == constant.EmptyString || + destPath 
== constant.EmptyString || destModelName == constant.EmptyString { + log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) + return transferredGcsModelData, errors.New(errModelPathFormat) + } + + if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, + destBucket, destPath, destModelName); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return transferredGcsModelData, err + } + + log.Info().Msgf("Successfully copied model in model-source: %s → %s", srcModelName, destModelName) + + if !isNotProd && srcModelName != destModelName { + if err := p.copyConfigToNewNameInConfigSource(srcModelName, destModelName); err != nil { + log.Error().Err(err).Msgf("Failed to copy config to config-source: %s → %s", + srcModelName, destModelName) return transferredGcsModelData, err } - - transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ - Bucket: pred.GcsModelBucket, - Path: pred.GcsModelBasePath, - Name: requestIdPayloadMap[requestModel.RequestID].ModelName, - }) } + + transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ + Bucket: destBucket, + Path: destPath, + Name: destModelName, + }) } + return transferredGcsModelData, nil } @@ -1405,7 +1527,9 @@ func (p *Predator) updateRequestStatusAndStage(approvedBy string, predatorReques if stage != constant.EmptyString { predatorRequestList[i].RequestStage = stage } - if predatorRequestList[i].Status == statusApproved { + if predatorRequestList[i].Status == statusApproved || + predatorRequestList[i].Status == statusFailed || + predatorRequestList[i].Status == statusRejected { predatorRequestList[i].Active = false } predatorRequestList[i].UpdatedAt = time.Now() @@ -1415,6 +1539,7 @@ func (p *Predator) updateRequestStatusAndStage(approvedBy string, predatorReques log.Printf(errFailedToUpdateRequestStatusAndStage, err) } } + func (p *Predator) createDiscoveryAndPredatorConfigTx(tx *gorm.DB, requestModel predatorrequest.PredatorRequest, payload Payload, 
approvedBy string) error { discoveryConfig, err := p.createDiscoveryConfigTx(tx, &requestModel, payload) if err != nil { @@ -1449,6 +1574,13 @@ func (p *Predator) createPredatorConfigTx(tx *gorm.DB, requestModel *predatorreq return err } + serviceDeployableID := int(payload.ConfigMapping.ServiceDeployableID) + serviceDeployable, err := p.ServiceDeployableRepo.GetById(serviceDeployableID) + if err != nil { + log.Error().Err(err).Msgf("Failed to get service deployable config for ID %d", serviceDeployableID) + return fmt.Errorf("failed to get service deployable config: %w", err) + } + config := predatorconfig.PredatorConfig{ DiscoveryConfigID: discoveryConfigID, ModelName: payload.ModelName, @@ -1458,6 +1590,15 @@ func (p *Predator) createPredatorConfigTx(tx *gorm.DB, requestModel *predatorreq CreatedAt: time.Now(), UpdatedAt: time.Now(), Active: true, + SourceModelName: payload.ConfigMapping.SourceModelName, + } + + if serviceDeployable.OverrideTesting { + log.Info().Msgf("OverrideTesting is enabled for deployable %s. 
Setting test_results for model %s", + serviceDeployable.Name, payload.ModelName) + + config.TestResults = json.RawMessage(`{"is_functionally_tested": true}`) + config.HasNilData = false } if err := tx.Create(&config).Error; err != nil { @@ -1687,6 +1828,7 @@ func (p *Predator) buildModelResponses( DeployableRunningStatus: infraConfig.RunningStatus, TestResults: config.TestResults, HasNilData: config.HasNilData, + SourceModelName: config.SourceModelName, } results = append(results, modelResponse) @@ -2202,22 +2344,49 @@ func (p *Predator) copyRequestModelsToTemporary(requests []predatorrequest.Preda return fmt.Errorf("failed to parse temporary deployable config: %w", err) } - if tempDeployableConfig.GCSBucketPath != "NA" { - tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) + tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) - // Copy each requested model from default GCS location to temporary deployable - for _, request := range requests { - modelName := request.ModelName - sourceBucket := pred.GcsModelBucket - sourcePath := pred.GcsModelBasePath + isNotProd := p.isNonProductionEnvironment() - log.Info().Msgf("Copying requested model %s from gs://%s/%s to temporary deployable gs://%s/%s", - modelName, sourceBucket, sourcePath, tempBucket, tempPath) + // Copy each requested model from default GCS location to temporary deployable + for _, request := range requests { + modelName := request.ModelName + payload, err := p.processPayload(request) + if err != nil { + log.Error().Err(err).Msgf("Failed to parse payload for request %d", request.RequestID) + return fmt.Errorf("failed to parse payload for request %d: %w", request.RequestID, err) + } - if err := p.GcsClient.TransferFolder(sourceBucket, sourcePath, modelName, tempBucket, tempPath, modelName); err != nil { + var sourceBucket, sourcePath, sourceModelName string + if payload.ModelSource != "" { + sourceBucket, 
sourcePath, sourceModelName = extractGCSDetails(payload.ModelSource) + log.Info().Msgf("Using ModelSource from payload for validation: gs://%s/%s/%s", + sourceBucket, sourcePath, sourceModelName) + } else { + sourceBucket = pred.GcsModelBucket + sourcePath = pred.GcsModelBasePath + sourceModelName = modelName + log.Info().Msgf("Using default model source for validation: gs://%s/%s/%s", + sourceBucket, sourcePath, sourceModelName) + } + log.Info().Msgf("Copying model %s from gs://%s/%s/%s to temporary deployable gs://%s/%s", + modelName, sourceBucket, sourcePath, sourceModelName, tempBucket, tempPath) + + if isNotProd { + if err := p.GcsClient.TransferFolder(sourceBucket, sourcePath, sourceModelName, + tempBucket, tempPath, modelName); err != nil { + return fmt.Errorf("failed to copy requested model %s to temporary deployable: %w", modelName, err) + } + } else { + if err := p.GcsClient.TransferFolderWithSplitSources( + sourceBucket, sourcePath, pred.GcsConfigBucket, pred.GcsConfigBasePath, + sourceModelName, tempBucket, tempPath, modelName, + ); err != nil { return fmt.Errorf("failed to copy requested model %s to temporary deployable: %w", modelName, err) } } + + log.Info().Msgf("Successfully copied model %s to temporary deployable", modelName) } return nil @@ -2849,10 +3018,54 @@ func (p *Predator) ExecuteFunctionalTestRequest(req ExecuteRequestFunctionalRequ elementsPerBatch *= dim } + normalizedOutputDT := strings.ToUpper(strings.TrimPrefix(outputMeta.DataType, "TYPE_")) + isStringType := normalizedOutputDT == "STRING" || normalizedOutputDT == "BYTES" + elementSize := getElementSize(outputMeta.DataType) bytesPerBatch := int(elementsPerBatch * int64(elementSize)) - if elementSize > 0 && len(outputBytes) >= bytesPerBatch { + if isStringType { + var allBatches [][]interface{} + offset := 0 + for offset < len(outputBytes) { + var batchSlice []interface{} + for j := int64(0); j < elementsPerBatch && offset < len(outputBytes); j++ { + if offset+4 > len(outputBytes) { + 
modelConfig.HasNilData = true + p.PredatorConfigRepo.Update(modelConfig) + return ExecuteRequestFunctionalResponse{}, fmt.Errorf("functional test failed: insufficient bytes for string length at offset %d", offset) + } + + length := binary.LittleEndian.Uint32(outputBytes[offset : offset+4]) + offset += 4 + + if offset+int(length) > len(outputBytes) { + modelConfig.HasNilData = true + p.PredatorConfigRepo.Update(modelConfig) + return ExecuteRequestFunctionalResponse{}, fmt.Errorf("functional test failed: insufficient bytes for string content at offset %d, expected %d bytes", offset, length) + } + + stringContent := outputBytes[offset : offset+int(length)] + offset += int(length) + batchSlice = append(batchSlice, string(stringContent)) + } + + if len(batchSlice) > 0 { + allBatches = append(allBatches, batchSlice) + } + + if offset >= len(outputBytes) { + break + } + } + + convertedOutputs = append(convertedOutputs, Output{ + Name: outputMeta.Name, + Dims: dims, + DataType: outputMeta.DataType, + Data: allBatches, + }) + } else if elementSize > 0 && len(outputBytes) >= bytesPerBatch { // Calculate number of batches from total bytes numBatches := len(outputBytes) / bytesPerBatch @@ -2961,14 +3174,15 @@ func (p *Predator) ExecuteFunctionalTestRequest(req ExecuteRequestFunctionalRequ } } } - modelConfig.HasNilData = false - p.PredatorConfigRepo.Update(modelConfig) } else { modelConfig.HasNilData = true p.PredatorConfigRepo.Update(modelConfig) - return ExecuteRequestFunctionalResponse{}, fmt.Errorf("no output contents received") + return ExecuteRequestFunctionalResponse{}, fmt.Errorf("no raw output contents received from helix") } + modelConfig.HasNilData = false + p.PredatorConfigRepo.Update(modelConfig) + // Return converted response return ExecuteRequestFunctionalResponse{ ModelName: req.ModelName, @@ -3186,6 +3400,16 @@ func (p *Predator) HandleEditModel(req ModelRequest, createdBy string) (string, if err != nil { return constant.EmptyString, 
http.StatusInternalServerError, errors.New(errMsgProcessPayload)
 		}
+
+		var payloadObject Payload
+		if err := json.Unmarshal(payloadBytes, &payloadObject); err != nil {
+			return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf(errMsgProcessPayload)
+		}
+
+		if payloadObject.MetaData.InstanceCount > 1 && p.isNonProductionEnvironment() {
+			return constant.EmptyString, http.StatusBadRequest, fmt.Errorf("instance count should be 1 for non-production environment")
+		}
+
 		modelName, _ := payload[fieldModelName].(string)
 		newRequests = append(newRequests, predatorrequest.PredatorRequest{
 			ModelName: modelName,
@@ -3306,6 +3530,11 @@ func (p *Predator) uploadSingleModel(modelItem ModelUploadItem, bucket, basePath
 		return p.createErrorResult(modelName, "Model file sync failed", err)
 	}
 
+	// Step 7: Copy config.pbtxt to the prod config source (runs in all environments; no-op when config source is not configured)
+	if err := p.copyConfigToProdConfigSource(modelItem.GCSPath, modelName); err != nil {
+		return p.createErrorResult(modelName, "Failed to copy config to prod config source", err)
+	}
+
 	// Upload processed metadata.json (always done regardless of partial/full)
 	metadataPath, err := p.uploadModelMetadata(modelItem.Metadata, bucket, destPath)
 	if err != nil {
@@ -3322,6 +3551,42 @@
 	}
 }
 
+// copyConfigToProdConfigSource copies config.pbtxt to the prod config source path
+// This is done in both int and prd environments so config is available for prod deployments
+func (p *Predator) copyConfigToProdConfigSource(gcsPath, modelName string) error {
+	// Check if config source is configured
+	if pred.GcsConfigBucket == "" || pred.GcsConfigBasePath == "" {
+		log.Warn().Msg("Config source not configured, skipping config.pbtxt copy to config source")
+		return nil
+	}
+
+	// Parse source GCS path
+	srcBucket, srcPath := extractGCSPath(gcsPath)
+	if srcBucket == "" || srcPath == "" {
+		return fmt.Errorf("invalid GCS path format: %s", gcsPath)
+	}
+	
+ // Read config.pbtxt from source + srcConfigPath := path.Join(srcPath, configFile) + configData, err := p.GcsClient.ReadFile(srcBucket, srcConfigPath) + if err != nil { + return fmt.Errorf("failed to read config.pbtxt from source: %w", err) + } + + // Update model name while preserving formatting + updatedConfigData := p.replaceModelNameInConfigPreservingFormat(configData, modelName) + + // Upload to prod config source path with updated model name + destConfigPath := path.Join(pred.GcsConfigBasePath, modelName, configFile) + if err := p.GcsClient.UploadFile(pred.GcsConfigBucket, destConfigPath, updatedConfigData); err != nil { + return fmt.Errorf("failed to upload config.pbtxt to config source: %w", err) + } + + log.Info().Msgf("Successfully copied config.pbtxt to config source with model name %s: gs://%s/%s", + modelName, pred.GcsConfigBucket, destConfigPath) + return nil +} + // Helper functions for simplified upload flow // createErrorResult creates a standardized error result @@ -3739,3 +4004,140 @@ func (p *Predator) cleanEnsembleScheduling(metadata MetaData) MetaData { } return metadata } + +// Returns the derived model name with deployable tag +func (p *Predator) GetDerivedModelName(payloadObject Payload, requestType string) (string, error) { + if requestType != ScaleUpRequestType { + return payloadObject.ModelName, nil + } + serviceDeployableID := payloadObject.ConfigMapping.ServiceDeployableID + serviceDeployable, err := p.ServiceDeployableRepo.GetById(int(serviceDeployableID)) + if err != nil { + return constant.EmptyString, fmt.Errorf("%s: %w", errFetchDeployableConfig, err) + } + + deployableTag := serviceDeployable.DeployableTag + if deployableTag == "" { + return payloadObject.ModelName, nil + } + + derivedModelName := payloadObject.ModelName + deployableTagDelimiter + deployableTag + derivedModelName = derivedModelName + deployableTagDelimiter + scaleupTag + return derivedModelName, nil +} + +// Returns the original model name if no tag is found 
(backward compatibility). +func (p *Predator) GetOriginalModelName(derivedModelName string, serviceDeployableID int) (string, error) { + serviceDeployable, err := p.ServiceDeployableRepo.GetById(serviceDeployableID) + if err != nil { + return constant.EmptyString, fmt.Errorf("%s: %w", errFetchDeployableConfig, err) + } + + deployableTag := serviceDeployable.DeployableTag + if deployableTag == "" { + return derivedModelName, nil + } + + scaleupSuffix := deployableTagDelimiter + scaleupTag + derivedModelName = strings.TrimSuffix(derivedModelName, scaleupSuffix) + + deployableTagSuffix := deployableTagDelimiter + deployableTag + if originalName, foundSuffix := strings.CutSuffix(derivedModelName, deployableTagSuffix); foundSuffix { + return originalName, nil + } + + return derivedModelName, nil +} + +func (p *Predator) isNonProductionEnvironment() bool { + env := strings.ToLower(strings.TrimSpace(pred.AppEnv)) + if env == "prd" || env == "prod" { + return false + } + return true +} + +func (p *Predator) copyConfigToNewNameInConfigSource(oldModelName, newModelName string) error { + if oldModelName == newModelName { + return nil + } + + if pred.GcsConfigBucket == "" || pred.GcsConfigBasePath == "" { + log.Warn().Msg("Config source not configured, skipping config.pbtxt copy in config source") + return nil + } + + destConfigPath := path.Join(pred.GcsConfigBasePath, newModelName, configFile) + exists, err := p.GcsClient.CheckFileExists(pred.GcsConfigBucket, destConfigPath) + if err != nil { + log.Warn().Err(err).Msgf("Failed to check if config exists for %s, will attempt copy anyway", newModelName) + } else if exists { + log.Info().Msgf("Config already exists for %s in config source, skipping copy", newModelName) + return nil + } + + srcConfigPath := path.Join(pred.GcsConfigBasePath, oldModelName, configFile) + + configData, err := p.GcsClient.ReadFile(pred.GcsConfigBucket, srcConfigPath) + if err != nil { + return fmt.Errorf("failed to read config.pbtxt from %s: %w", 
srcConfigPath, err) + } + + // Use formatting-preserving function instead of marshal/unmarshal + updatedConfigData := p.replaceModelNameInConfigPreservingFormat(configData, newModelName) + + if err := p.GcsClient.UploadFile(pred.GcsConfigBucket, destConfigPath, updatedConfigData); err != nil { + return fmt.Errorf("failed to upload config.pbtxt to %s: %w", destConfigPath, err) + } + + log.Info().Msgf("Successfully copied config.pbtxt from %s to %s in config source", + oldModelName, newModelName) + return nil +} + +// replaceModelNameInConfigPreservingFormat updates only the top-level model name while preserving formatting +// It replaces only the first occurrence to avoid modifying nested names in inputs/outputs/instance_groups +func (p *Predator) replaceModelNameInConfigPreservingFormat(data []byte, destModelName string) []byte { + content := string(data) + lines := strings.Split(content, "\n") + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + // Match top-level "name:" field - should be at the start of line (or minimal indentation) + // Skip nested names which are typically indented with 2+ spaces + if strings.HasPrefix(trimmed, "name:") { + // Check indentation: top-level fields have minimal/no indentation + leadingWhitespace := len(line) - len(strings.TrimLeft(line, " \t")) + // Skip if heavily indented (nested field) + if leadingWhitespace >= 2 { + continue + } + + // Match the first occurrence of name: "value" pattern + namePattern := regexp.MustCompile(`name\s*:\s*"([^"]+)"`) + matches := namePattern.FindStringSubmatch(line) + if len(matches) > 1 { + oldModelName := matches[1] + // Replace only the FIRST occurrence to avoid replacing nested names + loc := namePattern.FindStringIndex(line) + if loc != nil { + // Replace only the matched portion (first occurrence) + before := line[:loc[0]] + matched := line[loc[0]:loc[1]] + after := line[loc[1]:] + // Replace the value inside quotes while preserving the "name:" format + valuePattern := 
regexp.MustCompile(`"([^"]+)"`) + valueReplaced := valuePattern.ReplaceAllString(matched, fmt.Sprintf(`"%s"`, destModelName)) + lines[i] = before + valueReplaced + after + } else { + // Fallback: replace all (shouldn't happen with valid input) + lines[i] = namePattern.ReplaceAllString(line, fmt.Sprintf(`name: "%s"`, destModelName)) + } + log.Info().Msgf("Replacing top-level model name in config.pbtxt: '%s' -> '%s'", oldModelName, destModelName) + break + } + } + } + + return []byte(strings.Join(lines, "\n")) +} \ No newline at end of file diff --git a/horizon/internal/repositories/sql/predatorconfig/sql.go b/horizon/internal/repositories/sql/predatorconfig/sql.go index 6776bb67..7a421639 100644 --- a/horizon/internal/repositories/sql/predatorconfig/sql.go +++ b/horizon/internal/repositories/sql/predatorconfig/sql.go @@ -23,6 +23,7 @@ type PredatorConfigRepository interface { GetByModelName(modelName string) (*PredatorConfig, error) GetActiveModelByModelName(modelName string) (*PredatorConfig, error) GetActiveModelByModelNameList(modelNames []string) ([]PredatorConfig, error) + FindByDiscoveryIDsAndAge(discoveryConfigIds []int, daysAgo int) ([]PredatorConfig, error) } type predatorConfigRepo struct { @@ -76,6 +77,7 @@ func (r *predatorConfigRepo) Update(config *PredatorConfig) error { "updated_at": config.UpdatedAt, "test_results": config.TestResults, "has_nil_data": config.HasNilData, + "source_model_name": config.SourceModelName, }).Error } @@ -149,3 +151,19 @@ func (r *predatorConfigRepo) GetByModelName(modelName string) (*PredatorConfig, err := r.db.Where("model_name = ?", modelName).First(&config).Error return &config, err } + +// FindByDiscoveryIDsAndAge returns active predator configs for given discovery IDs created before (now - daysAgo). 
+func (r *predatorConfigRepo) FindByDiscoveryIDsAndAge(discoveryConfigIds []int, daysAgo int) ([]PredatorConfig, error) { + var configs []PredatorConfig + if daysAgo < 0 { + return nil, errors.New("daysAgo must be >= 0") + } + cutoffDate := time.Now().AddDate(0, 0, -daysAgo) + + err := r.db.Where("discovery_config_id IN ? AND created_at < ? AND active = ?", + discoveryConfigIds, cutoffDate, true). + Find(&configs).Error + + return configs, err +} + diff --git a/horizon/internal/repositories/sql/predatorconfig/table.go b/horizon/internal/repositories/sql/predatorconfig/table.go index 448d0370..f12c20eb 100644 --- a/horizon/internal/repositories/sql/predatorconfig/table.go +++ b/horizon/internal/repositories/sql/predatorconfig/table.go @@ -22,6 +22,7 @@ type PredatorConfig struct { UpdatedAt time.Time TestResults json.RawMessage HasNilData bool `gorm:"default:false"` // Tracks if model has nil data issues + SourceModelName string `gorm:"column:source_model_name"` } func (PredatorConfig) TableName() string { From 1aa8d3280f5113611432e3d7032d2130e9b4816a Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 13:11:46 +0530 Subject: [PATCH 07/24] predator init fixes --- horizon/internal/predator/init.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/horizon/internal/predator/init.go b/horizon/internal/predator/init.go index 07ef8bee..c90cb110 100644 --- a/horizon/internal/predator/init.go +++ b/horizon/internal/predator/init.go @@ -15,7 +15,9 @@ var ( TestGpuDeployableID int initOnce sync.Once IsMeeshoEnabled bool - IsGcsEnabled bool + AppEnv string + GcsConfigBucket string + GcsConfigBasePath string ) func Init(config configs.Configs) { @@ -27,7 +29,9 @@ func Init(config configs.Configs) { TestDeployableID = config.TestDeployableID TestGpuDeployableID = config.TestGpuDeployableID IsMeeshoEnabled = config.IsMeeshoEnabled - IsGcsEnabled = config.GcsEnabled + AppEnv = config.AppEnv + GcsConfigBasePath = config.GcsConfigBasePath 
+ GcsConfigBucket = config.GcsConfigBucket }) } From 19c0d774008dc7d21f71dd42b9cee42ee1964c14 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 13:45:23 +0530 Subject: [PATCH 08/24] code rabbit issues --- horizon/internal/inferflow/handler/inferflow.go | 14 +++++++++++--- .../repositories/sql/discoveryconfig/sql.go | 8 ++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/horizon/internal/inferflow/handler/inferflow.go b/horizon/internal/inferflow/handler/inferflow.go index 55275115..6e9e183d 100644 --- a/horizon/internal/inferflow/handler/inferflow.go +++ b/horizon/internal/inferflow/handler/inferflow.go @@ -424,7 +424,7 @@ func (m *InferFlow) ScaleUp(request ScaleUpConfigRequest) (Response, error) { var latestSourceRequest GetLatestRequestResponse latestSourceRequest, err = m.GetLatestRequest(sourceConfigID) - if err != nil { + if err != nil || latestSourceRequest.Error != emptyResponse { return Response{}, errors.New("failed to get latest request for the source configID: " + sourceConfigID + ": " + err.Error()) } request.Payload.ConfigMapping.SourceConfigID = sourceConfigID @@ -624,7 +624,10 @@ func (m *InferFlow) handleApprovedRequest(request ReviewRequest) (Response, erro var configExistedBeforeTx bool if tempRequest.RequestType == promoteRequestType { - existingConfig, _ := m.InferFlowConfigRepo.GetByID(tempRequest.ConfigID) + existingConfig, err := m.InferFlowConfigRepo.GetByID(tempRequest.ConfigID) + if err != nil { + return Response{}, fmt.Errorf("failed to check existing config for promote: %w", err) + } configExistedBeforeTx = existingConfig != nil } @@ -1171,7 +1174,7 @@ func (m *InferFlow) batchFetchDiscoveryConfigs(discoveryIDs []int) ( return emptyDiscoveryMap, emptyServiceDeployableMap, nil } - discoveryConfigs, err := m.DiscoveryConfigRepo.GetByServiceDeployableIDs(discoveryIDs) + discoveryConfigs, err := m.DiscoveryConfigRepo.GetByDiscoveryIDs(discoveryIDs) if err != nil { return nil, nil, fmt.Errorf("failed 
to get discovery configs: %w", err) } @@ -1678,6 +1681,11 @@ func (m *InferFlow) GetFeatureSchema(request FeatureSchemaRequest) (FeatureSchem Data: []inferflow.SchemaComponents{}, }, err } + if len(inferflowRequests) == 0 { + return FeatureSchemaResponse{ + Data: []inferflow.SchemaComponents{}, + }, errors.New("no inferflow config found for model_config_id=" + request.ModelConfigId + " version=" + strconv.Itoa(version)) + } inferflowConfig := inferflowRequests[0].Payload componentConfig := &inferflowConfig.ConfigValue.ComponentConfig responseConfig := &inferflowConfig.ConfigValue.ResponseConfig diff --git a/horizon/internal/repositories/sql/discoveryconfig/sql.go b/horizon/internal/repositories/sql/discoveryconfig/sql.go index 39577085..f5c3b755 100644 --- a/horizon/internal/repositories/sql/discoveryconfig/sql.go +++ b/horizon/internal/repositories/sql/discoveryconfig/sql.go @@ -17,7 +17,7 @@ type DiscoveryConfigRepository interface { GetByToken(token string) ([]DiscoveryConfig, error) GetById(configId int) (*DiscoveryConfig, error) GetByServiceDeployableID(serviceDeployableID int) ([]DiscoveryConfig, error) - GetByServiceDeployableIDs(serviceDeployableIDs []int) ([]DiscoveryConfig, error) + GetByDiscoveryIDs(discoveryIDs []int) ([]DiscoveryConfig, error) DB() *gorm.DB WithTx(tx *gorm.DB) DiscoveryConfigRepository DeleteByIDTx(tx *gorm.DB, id int) error @@ -108,12 +108,12 @@ func (r *discoveryConfigRepo) WithTx(tx *gorm.DB) DiscoveryConfigRepository { } } -func (r *discoveryConfigRepo) GetByServiceDeployableIDs(serviceDeployableIDs []int) ([]DiscoveryConfig, error) { - if len(serviceDeployableIDs) == 0 { +func (r *discoveryConfigRepo) GetByDiscoveryIDs(discoveryIDs []int) ([]DiscoveryConfig, error) { + if len(discoveryIDs) == 0 { return []DiscoveryConfig{}, nil } var configs []DiscoveryConfig - err := r.db.Where("id IN ?", serviceDeployableIDs).Find(&configs).Error + err := r.db.Where("id IN ?", discoveryIDs).Find(&configs).Error return configs, err } From 
0423a9e5cfa390a88db3968d1f4271eccd57c72f Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 14:14:33 +0530 Subject: [PATCH 09/24] bulk delete changes --- horizon/internal/configs/app_config.go | 10 +- .../externalcall/prometheus_client.go | 97 +++++ .../bulkdeletestrategy/inferflow_service.go | 2 +- .../bulkdeletestrategy/numerix_service.go | 2 +- .../bulkdeletestrategy/predator_service.go | 348 ++++++++++++++++-- .../bulkdeletestrategy/strategy_selector.go | 54 ++- 6 files changed, 460 insertions(+), 53 deletions(-) diff --git a/horizon/internal/configs/app_config.go b/horizon/internal/configs/app_config.go index 34f20755..431a92d5 100644 --- a/horizon/internal/configs/app_config.go +++ b/horizon/internal/configs/app_config.go @@ -67,9 +67,13 @@ type Configs struct { NumerixAppName string `mapstructure:"numerix_app_name"` NumerixMonitoringUrl string `mapstructure:"numerix_monitoring_url"` - MaxNumerixInactiveAge int `mapstructure:"max_numerix_inactive_age"` - MaxInferflowInactiveAge int `mapstructure:"max_inferflow_inactive_age"` - MaxPredatorInactiveAge int `mapstructure:"max_predator_inactive_age"` + BulkDeletePredatorEnabled bool `mapstructure:"bulk_delete_predator_enabled"` + BulkDeleteInferflowEnabled bool `mapstructure:"bulk_delete_inferflow_enabled"` + BulkDeleteNumerixEnabled bool `mapstructure:"bulk_delete_numerix_enabled"` + BulkDeletePredatorMaxInactiveDays int `mapstructure:"bulk_delete_predator_max_inactive_days"` + BulkDeleteInferflowMaxInactiveDays int `mapstructure:"bulk_delete_inferflow_max_inactive_days"` + BulkDeleteNumerixMaxInactiveDays int `mapstructure:"bulk_delete_numerix_max_inactive_days"` + BulkDeletePredatorRequestSubmissionEnabled bool `mapstructure:"bulk_delete_predator_request_submission_enabled"` InferflowAppName string `mapstructure:"inferflow_app_name"` diff --git a/horizon/internal/externalcall/prometheus_client.go b/horizon/internal/externalcall/prometheus_client.go index c1bc9b9c..a89a7c32 100644 --- 
a/horizon/internal/externalcall/prometheus_client.go +++ b/horizon/internal/externalcall/prometheus_client.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/url" + "strconv" "sync" "time" ) @@ -14,6 +15,7 @@ type PrometheusClient interface { GetModelNames(serviceName string) ([]string, error) GetInferflowConfigNames(serviceName string) ([]string, error) GetNumerixConfigNames() ([]string, error) + GetPredatorModelTraffic(serviceName string, daysAgo int) (map[string]PredatorModelTraffic, error) } type prometheusClientImpl struct { @@ -91,6 +93,27 @@ type prometheusNumerixConfigResponse struct { } `json:"data"` } +type PredatorModelResponse struct { + Status string `json:"status"` + IsPartial bool `json:"isPartial"` + Data struct { + ResultType string `json:"resultType"` + Result []struct { + Metric struct { + Model string `json:"model"` + } `json:"metric"` + Values [][]interface{} `json:"values"` // [[timestamp, "value"], ...] + } `json:"result"` + } `json:"data"` +} + +// PredatorModelTraffic holds model name and its traffic data +type PredatorModelTraffic struct { + ModelName string + TotalTraffic float64 // Sum of all values + DataPoints int // Number of data points +} + func (p *prometheusClientImpl) GetModelNames(serviceName string) ([]string, error) { end := time.Now().Unix() daysAgo := vmselectStartDaysAgo @@ -141,6 +164,80 @@ func (p *prometheusClientImpl) GetModelNames(serviceName string) ([]string, erro return modelNames, nil } +// GetPredatorModelTraffic returns models with their traffic data for the past N days +func (p *prometheusClientImpl) GetPredatorModelTraffic(serviceName string, daysAgo int) (map[string]PredatorModelTraffic, error) { + end := time.Now().Unix() + start := end - int64(daysAgo*24*60*60) + step := "1m" + + query := fmt.Sprintf( + "sum by (model)(increase(nv_inference_count{service=\"%s\"}[1m]))", + serviceName, + ) + + url := fmt.Sprintf("%s/prometheus/api/v1/query_range?query=%s&start=%d&end=%d&step=%s", + p.BaseURL, + 
escapePrometheusQuery(query), + start, + end, + step, + ) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", p.APIKey) + + resp, err := p.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call Prometheus: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("Prometheus call failed, status: %d, body: %s", resp.StatusCode, string(bodyBytes)) + } + + var pr PredatorModelResponse + if err := json.NewDecoder(resp.Body).Decode(&pr); err != nil { + return nil, fmt.Errorf("failed to decode Prometheus response: %w", err) + } + + // Parse results into map + result := make(map[string]PredatorModelTraffic) + for _, item := range pr.Data.Result { + modelName := item.Metric.Model + if modelName == "" { + continue + } + + var totalTraffic float64 + dataPoints := 0 + + for _, valueArr := range item.Values { + if len(valueArr) >= 2 { + if valueStr, ok := valueArr[1].(string); ok { + if val, err := strconv.ParseFloat(valueStr, 64); err == nil { + totalTraffic += val + dataPoints++ + } + } + } + } + + result[modelName] = PredatorModelTraffic{ + ModelName: modelName, + TotalTraffic: totalTraffic, + DataPoints: dataPoints, + } + } + + return result, nil +} + func (p *prometheusClientImpl) GetInferflowConfigNames(serviceName string) ([]string, error) { end := time.Now().Unix() daysAgo := vmselectStartDaysAgo diff --git a/horizon/internal/jobs/bulkdeletestrategy/inferflow_service.go b/horizon/internal/jobs/bulkdeletestrategy/inferflow_service.go index f2c8efc7..f330c86f 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/inferflow_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/inferflow_service.go @@ -87,7 +87,7 @@ func (m *InferflowService) 
fetchNonActiveInferflowConfigList(serviceDeployable s discoveryConfigId = append(discoveryConfigId, discoveryConfigEntity.ID) } - inferflowConfigList, err := inferflowConfigRepo.FindByDiscoveryIDsAndCreatedBefore(discoveryConfigId, maxInferflowInactiveAge) + inferflowConfigList, err := inferflowConfigRepo.FindByDiscoveryIDsAndCreatedBefore(discoveryConfigId, bulkDeleteInferflowMaxInactiveDays) if err != nil { return nil, err } diff --git a/horizon/internal/jobs/bulkdeletestrategy/numerix_service.go b/horizon/internal/jobs/bulkdeletestrategy/numerix_service.go index e8c2a55c..84d48791 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/numerix_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/numerix_service.go @@ -66,7 +66,7 @@ func (i *NumerixService) fetchNonActiveNumerixConfigList(numerixConfigRepo numer var allNumerixConfigList []string - numerixConfigList, err := numerixConfigRepo.FindByCreatedBefore(maxNumerixInactiveAge) + numerixConfigList, err := numerixConfigRepo.FindByCreatedBefore(bulkDeleteNumerixMaxInactiveDays) if err != nil { return nil, err } diff --git a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go index 89f1390a..7a9e582b 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go @@ -4,13 +4,16 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/Meesho/BharatMLStack/horizon/internal/constant" "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" infrastructurehandler "github.com/Meesho/BharatMLStack/horizon/internal/infrastructure/handler" "github.com/Meesho/BharatMLStack/horizon/internal/predator/handler" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/counter" "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/discoveryconfig" "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorconfig" + 
"github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorrequest" "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" "github.com/Meesho/BharatMLStack/horizon/pkg/infra" "github.com/rs/zerolog/log" @@ -28,22 +31,40 @@ type PredatorService struct { const ( slashConstant = "/" gcsPrefix = "gs://" + bulkDeleteCreatedBy = "horizon-bulk-delete" ) +type ModelInfo struct { + ModelName string + DiscoveryConfigID int +} + +type PredatorBulkDeleteRepos struct { + discoveryConfigRepo discoveryconfig.DiscoveryConfigRepository + predatorConfigRepo predatorconfig.PredatorConfigRepository + predatorRequestRepo predatorrequest.PredatorRequestRepository + groupCounterRepo counter.GroupIdCounterRepository +} + + func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployableconfig.ServiceDeployableConfig) error { - discoveryConfigRepo, predatorConfigRepo, err := p.initializeRepositories() + predatorBulkDeleteRepos, err := p.initializeRepositories() if err != nil { log.Error().Err(err).Msg("Error initializing repositories") return err } - discoveryConfigList, err := discoveryConfigRepo.GetByServiceDeployableID(serviceDeployable.ID) + discoveryConfigList, err := predatorBulkDeleteRepos.discoveryConfigRepo.GetByServiceDeployableID(serviceDeployable.ID) if err != nil { log.Error().Err(err).Msg("Error fetching discovery config list") return err } - activeModelNameList, parentModelNameList, parentToChildMapping, discoveryConfigId, err := p.fetchModelNames(serviceDeployable, discoveryConfigList, predatorConfigRepo) + _, zeroTrafficModelList, parentToChildMapping, trafficData, err := p.fetchModelNames( + serviceDeployable, + discoveryConfigList, + &predatorBulkDeleteRepos, + ) if err != nil { log.Error().Err(err).Msg("Error fetching model names") return err @@ -55,24 +76,57 @@ func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployablec return err } - inactiveModelNameList := 
difference(parentModelNameList, activeModelNameList) - childModelNameList := make([]string, 0) - for _, inactiveModel := range inactiveModelNameList { - if _, found := parentToChildMapping[inactiveModel]; !found { - continue + // Get child models for zero traffic parents + var modelsToDelete []ModelInfo + for _, parentModel := range zeroTrafficModelList { + modelsToDelete = append(modelsToDelete, parentModel) + + // Add child models if any + if children, found := parentToChildMapping[parentModel.ModelName]; found { + for _, childName := range children { + // Find child's discovery config ID from predator config + childTraffic, existsInPrometheus := trafficData[childName] + if existsInPrometheus && childTraffic.TotalTraffic > 0 { + log.Info().Msgf("[SKIP CHILD] Model: %s (child of %s) has traffic (%.2f), skipping deletion", + childName, parentModel.ModelName, childTraffic.TotalTraffic) + continue + } + + childConfig, _ := predatorBulkDeleteRepos.predatorConfigRepo.GetActiveModelByModelName(childName) + if childConfig != nil { + modelsToDelete = append(modelsToDelete, ModelInfo{ + ModelName: childName, + DiscoveryConfigID: childConfig.DiscoveryConfigID, + }) + log.Info().Msgf("[DELETE CHILD] Model: %s (child of %s) has zero traffic, adding to delete list", + childName, parentModel.ModelName) + } + } } - childModelNameList = append(childModelNameList, parentToChildMapping[inactiveModel]...) 
} - deleteModelNameList := p.processGCSAndDeleteModels(strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*"), append(inactiveModelNameList, childModelNameList...)) + if len(modelsToDelete) == 0 { + log.Info().Msg("No models to delete") + return nil + } - err = p.deactivateModelsAndRestartDeployable(deleteModelNameList, serviceDeployable, predatorConfigRepo, discoveryConfigId) - if err != nil { - log.Error().Err(err).Msg("Error deactivating models and restarting deployable") - return err + // Process deletion: GCS delete, deactivate predator_config, deactivate discovery_config, create request + deletedModels := p.processDeleteModels( + strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*"), + modelsToDelete, + serviceDeployable, + &predatorBulkDeleteRepos, + ) + + if len(deletedModels) > 0 { + // Restart deployable after deletion + err = p.infrastructureHandler.RestartDeployment(serviceDeployable.Name, p.workingEnv, false) + if err != nil { + log.Error().Err(err).Msg("Error restarting deployable") + } } - err = p.sendSlackNotification(serviceDeployable.Name, deleteModelNameList) + err = p.sendSlackNotification(serviceDeployable.Name, deletedModels) if err != nil { log.Error().Err(err).Msg("Error sending Slack notification") return err @@ -81,62 +135,284 @@ func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployablec return nil } -func (p *PredatorService) initializeRepositories() (discoveryconfig.DiscoveryConfigRepository, predatorconfig.PredatorConfigRepository, error) { +func (p *PredatorService) initializeRepositories() (PredatorBulkDeleteRepos, error) { discoveryConfigRepo, err := discoveryconfig.NewRepository(p.sqlConn) if err != nil { log.Err(err).Msg("Error initializing discovery config repository") - return nil, nil, err + return PredatorBulkDeleteRepos{}, err } predatorConfigRepo, err := predatorconfig.NewRepository(p.sqlConn) if err != nil { log.Err(err).Msg("Error initializing predator config repository") - return nil, nil, err + 
return PredatorBulkDeleteRepos{}, err + } + + predatorRequestRepo, err := predatorrequest.NewRepository(p.sqlConn) + if err != nil { + log.Err(err).Msg("Error initializing predator request repository") + return PredatorBulkDeleteRepos{}, err } - return discoveryConfigRepo, predatorConfigRepo, nil + groupCounterRepo, err := counter.NewCounterRepository(p.sqlConn) + if err != nil { + log.Err(err).Msg("Error initializing group counter repository") + return PredatorBulkDeleteRepos{}, err + } + + return PredatorBulkDeleteRepos{ + discoveryConfigRepo: discoveryConfigRepo, + predatorConfigRepo: predatorConfigRepo, + predatorRequestRepo: predatorRequestRepo, + groupCounterRepo: groupCounterRepo, + }, nil } -func (p *PredatorService) fetchModelNames(serviceDeployable servicedeployableconfig.ServiceDeployableConfig, discoveryConfigList []discoveryconfig.DiscoveryConfig, predatorConfigRepo predatorconfig.PredatorConfigRepository) ([]string, []string, map[string][]string, []int, error) { - activeModelNameList, err := p.prometheusClient.GetModelNames(serviceDeployable.Name) +func (p *PredatorService) fetchModelNames( + serviceDeployable servicedeployableconfig.ServiceDeployableConfig, + discoveryConfigList []discoveryconfig.DiscoveryConfig, + predatorBulkDeleteRepos *PredatorBulkDeleteRepos, +) ([]ModelInfo, []ModelInfo, map[string][]string, map[string]externalcall.PredatorModelTraffic, error) { + + zeroTrafficDays := bulkDeletePredatorMaxInactiveDays + + trafficData, err := p.prometheusClient.GetPredatorModelTraffic(serviceDeployable.Name, zeroTrafficDays) if err != nil { - log.Err(err).Msg("Error fetching active model names from Prometheus") + log.Err(err).Msg("Error fetching predator model traffic from Prometheus") return nil, nil, nil, nil, err } - var allModelNameList []string - parentToChildMapping := make(map[string][]string) - childModelNameList := make([]string, 0) - var discoveryConfigId []int + // Get ALL models from DB + var discoveryConfigIds []int for _, 
discoveryConfigEntity := range discoveryConfigList { - discoveryConfigId = append(discoveryConfigId, discoveryConfigEntity.ID) + if discoveryConfigEntity.Active { + discoveryConfigIds = append(discoveryConfigIds, discoveryConfigEntity.ID) + } } - predatorConfigList, err := predatorConfigRepo.FindByDiscoveryIDsAndCreatedBefore(discoveryConfigId, maxPredatorInactiveAge) + predatorConfigList, err := predatorBulkDeleteRepos.predatorConfigRepo.FindByDiscoveryIDsAndAge(discoveryConfigIds, zeroTrafficDays) if err != nil { + log.Err(err).Msg("Error fetching predator configs from DB") return nil, nil, nil, nil, err } - for _, predatorConfigEntity := range predatorConfigList { + modelInfoMap := make(map[string]ModelInfo) // deduplicate + parentToChildMapping := make(map[string][]string) + childModelNames := make(map[string]bool) + + for _, pc := range predatorConfigList { + if !pc.Active { + continue + } + + if _, exists := modelInfoMap[pc.ModelName]; exists { + continue + } + + modelInfoMap[pc.ModelName] = ModelInfo{ + ModelName: pc.ModelName, + DiscoveryConfigID: pc.DiscoveryConfigID, + } + var metaData handler.MetaData - if err := json.Unmarshal(predatorConfigEntity.MetaData, &metaData); err != nil { + if err := json.Unmarshal(pc.MetaData, &metaData); err != nil { + log.Err(err).Msg("Could not unmarshall model metadata for: " + pc.ModelName + " for scheduled deletion") continue } - if metaData.Ensembling.Step != nil { + if (&metaData.Ensembling) != nil && metaData.Ensembling.Step != nil { for _, step := range metaData.Ensembling.Step { if step.ModelName != "" { - parentToChildMapping[predatorConfigEntity.ModelName] = append(parentToChildMapping[predatorConfigEntity.ModelName], step.ModelName) - childModelNameList = append(childModelNameList, step.ModelName) + parentToChildMapping[pc.ModelName] = append(parentToChildMapping[pc.ModelName], step.ModelName) + childModelNames[step.ModelName] = true } } } + } - allModelNameList = append(allModelNameList, 
predatorConfigEntity.ModelName) + var allParentModels []ModelInfo + for modelName, info := range modelInfoMap { + if !childModelNames[modelName] { + allParentModels = append(allParentModels, info) + } } - parentModelNameList := difference(allModelNameList, childModelNameList) - return activeModelNameList, parentModelNameList, parentToChildMapping, discoveryConfigId, nil + // Separate active vs zero-traffic + var activeModels []ModelInfo + var zeroTrafficModels []ModelInfo + + log.Info().Msgf("=== Traffic check for %s (past %d days) ===", serviceDeployable.Name, zeroTrafficDays) + + for _, modelInfo := range allParentModels { + traffic, existsInPrometheus := trafficData[modelInfo.ModelName] + + if existsInPrometheus && traffic.TotalTraffic > 0 { + activeModels = append(activeModels, modelInfo) + log.Info().Msgf("[ACTIVE] Model: %s | Traffic: %.2f", modelInfo.ModelName, traffic.TotalTraffic) + } else { + zeroTrafficModels = append(zeroTrafficModels, modelInfo) + log.Warn().Msgf("[ZERO TRAFFIC - DELETE CANDIDATE] Model: %s | 0 traffic for %d days", modelInfo.ModelName, zeroTrafficDays) + } + } + + log.Info().Msgf("Summary: Total: %d | Active: %d | Zero traffic (to delete): %d", + len(allParentModels), len(activeModels), len(zeroTrafficModels)) + + return activeModels, zeroTrafficModels, parentToChildMapping, trafficData, nil +} + +// processDeleteModels - NEW: Delete from GCS, DB, and create approved delete request +func (p *PredatorService) processDeleteModels( + basePath string, + modelInfoList []ModelInfo, + serviceDeployableConfig servicedeployableconfig.ServiceDeployableConfig, + predatorBulkDeleteRepos *PredatorBulkDeleteRepos, +) []string { + srcBucket, srcPath := extractGCSPath(basePath) + + var deletedModels []string + for _, modelInfo := range modelInfoList { + modelName := modelInfo.ModelName + discoveryConfigID := modelInfo.DiscoveryConfigID + modelGCSPath := srcPath + "/" + modelName + existsInGCS, err := p.gcsClient.CheckFolderExists(srcBucket, 
modelGCSPath) + if err != nil { + log.Warn().Err(err).Msgf("Failed to check GCS existence for model %s, assuming it exists", modelName) + existsInGCS = true + } + + err = p.processModelDeletion( + srcBucket, srcPath, modelName, discoveryConfigID, serviceDeployableConfig.ID, + predatorBulkDeleteRepos, existsInGCS, + ) + + if err != nil { + log.Error().Err(err).Msgf("Failed to process scheduled deletion for model: %s, skipping", modelName) + continue + } + + deletedModels = append(deletedModels, modelInfo.ModelName) + } + + return deletedModels +} + +func (p *PredatorService) processModelDeletion( + srcBucket, srcPath, modelName string, + discoveryConfigID int, serviceDeployableID int, + predatorBulkDeleteRepos *PredatorBulkDeleteRepos, + existsInGCS bool, +) (err error) { + db := predatorBulkDeleteRepos.predatorConfigRepo.DB() + + tx := db.Begin() + if tx.Error != nil { + return fmt.Errorf("failed to start transaction: %w", tx.Error) + } + + defer func() { + if r := recover(); r != nil { + tx.Rollback() + log.Error().Msgf("Panic recovered, transaction rolled back for model: %s", modelName) + err = fmt.Errorf("panic during deletion of model %s: %v", modelName, r) + } + }() + + predatorConfig, err := predatorBulkDeleteRepos.predatorConfigRepo.WithTx(tx).GetActiveModelByModelName(modelName) + if err != nil { + tx.Rollback() + return fmt.Errorf("failed to fetch predator config for model %s: %w", modelName, err) + } + if predatorConfig == nil { + tx.Rollback() + return fmt.Errorf("no active predator config found for model %s", modelName) + } + + // 1. 
Deactivate predator_config (in transaction) + predatorConfig.Active = false + predatorConfig.UpdatedAt = time.Now() + predatorConfig.UpdatedBy = bulkDeleteCreatedBy + err = predatorBulkDeleteRepos.predatorConfigRepo.WithTx(tx).Update(predatorConfig) + if err != nil { + tx.Rollback() + return fmt.Errorf("failed to deactivate predator_config: %w", err) + } + log.Info().Msgf("Deactivated predator_config: %s", modelName) + + // 2. Deactivate discovery_config (in transaction) + err = predatorBulkDeleteRepos.discoveryConfigRepo.WithTx(tx).DeactivateByID(discoveryConfigID, bulkDeleteCreatedBy) + if err != nil { + tx.Rollback() + return fmt.Errorf("failed to deactivate discovery_config ID %d: %w", discoveryConfigID, err) + } + log.Info().Msgf("Deactivated discovery_config ID: %d", discoveryConfigID) + + // 3. Create APPROVED delete request (in transaction) if flag enabled + if enablePredatorRequestSubmission { + // create predator payload for creating deletion request + payload := map[string]interface{}{ + "model_name": modelName, + "model_source_path": fmt.Sprintf("gs://%s/%s/%s", srcBucket, srcPath, modelName), + "meta_data": predatorConfig.MetaData, + "discovery_config_id": discoveryConfigID, + "config_mapping": map[string]interface{}{ + "service_deployable_id": serviceDeployableID, + }, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + tx.Rollback() + return fmt.Errorf("failed to marshall payload for model %s: %w", modelName, err) + } + + groupID, err := predatorBulkDeleteRepos.groupCounterRepo.GetAndIncrementCounter(1) + if err != nil { + tx.Rollback() + return fmt.Errorf("failed to get new groupID for model deletion: %w", err) + } + + deleteRequest := predatorrequest.PredatorRequest{ + ModelName: modelName, + GroupId: groupID, + Payload: string(payloadBytes), + CreatedBy: bulkDeleteCreatedBy, + UpdatedBy: bulkDeleteCreatedBy, + Reviewer: bulkDeleteCreatedBy, + RequestType: "Delete", + Status: "Approved", + RequestStage: "DB Population", + Active: 
false, + IsValid: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err = predatorBulkDeleteRepos.predatorRequestRepo.WithTx(tx).Create(&deleteRequest) + if err != nil { + tx.Rollback() + return fmt.Errorf("failed to create delete request: %w", err) + } + log.Info().Msgf("Created APPROVED delete request: %s", modelName) + } + + // 4. Commit DB transaction + if err := tx.Commit().Error; err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + log.Info().Msgf("DB transaction committed for model: %s", modelName) + + // Step 5: Delete from GCS (AFTER successful commit) + if existsInGCS { + if err := p.gcsClient.DeleteFolder(srcBucket, srcPath, modelName); err != nil { + log.Error().Err(err).Msgf("GCS deletion failed for model %s after DB commit - orphaned data may need manual cleanup", modelName) + } else { + log.Info().Msgf("Deleted model from GCS: %s", modelName) + } + } else { + log.Info().Msgf("Model %s not found in GCS, skipping GCS deletion (DB cleanup only)", modelName) + } + + return nil } func (p *PredatorService) deserializeDeployableConfig(serviceDeployable servicedeployableconfig.ServiceDeployableConfig) (handler.PredatorDeployableConfig, error) { diff --git a/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go b/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go index 63315f73..1a672991 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go +++ b/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go @@ -29,20 +29,39 @@ type StrategySelectorImpl struct { } var ( - strategySelectorOnce sync.Once - maxPredatorInactiveAge int - defaultModelPath string - maxInferflowInactiveAge int - maxNumerixInactiveAge int + strategySelectorOnce sync.Once + defaultModelPath string + bulkDeletePredatorEnabled bool + bulkDeletePredatorMaxInactiveDays int + bulkDeleteInferflowEnabled bool + bulkDeleteInferflowMaxInactiveDays int + bulkDeleteNumerixEnabled bool + 
bulkDeleteNumerixMaxInactiveDays int + enablePredatorRequestSubmission bool +) + +const ( + inferflowService = "inferflow" + predatorService = "predator" + numerixService = "numerix" ) func Init(config configs.Configs) StrategySelectorImpl { var strategySelectorImpl StrategySelectorImpl strategySelectorOnce.Do(func() { - maxPredatorInactiveAge = config.MaxPredatorInactiveAge defaultModelPath = config.DefaultModelPath - maxInferflowInactiveAge = config.MaxInferflowInactiveAge - maxNumerixInactiveAge = config.MaxNumerixInactiveAge + + bulkDeletePredatorEnabled = config.BulkDeletePredatorEnabled + bulkDeletePredatorMaxInactiveDays = config.BulkDeletePredatorMaxInactiveDays + + bulkDeleteInferflowEnabled = config.BulkDeleteInferflowEnabled + bulkDeleteInferflowMaxInactiveDays = config.BulkDeleteInferflowMaxInactiveDays + + bulkDeleteNumerixEnabled = config.BulkDeleteNumerixEnabled + bulkDeleteNumerixMaxInactiveDays = config.BulkDeleteNumerixMaxInactiveDays + + enablePredatorRequestSubmission = config.BulkDeletePredatorRequestSubmissionEnabled + connection, err := infra.SQL.GetConnection() if err != nil { @@ -62,7 +81,7 @@ func Init(config configs.Configs) StrategySelectorImpl { sqlConn: sqlConn, prometheusClient: externalcall.GetPrometheusClient(), slackClient: externalcall.GetSlackClient(), - gcsClient: externalcall.CreateGCSClient(config.GcsEnabled), + gcsClient: externalcall.CreateGCSClient(), InferflowEtcdClient: inferflowEtcdClient, NumerixEtcdClient: numerixEtcdClient, infrastructureHandler: infrastructureHandler, @@ -71,16 +90,27 @@ func Init(config configs.Configs) StrategySelectorImpl { }) return strategySelectorImpl } + func (ss *StrategySelectorImpl) GetBulkDeleteStrategy(service string) (BulkDeleteStrategy, error) { switch service { - case "INFERFLOW": + case inferflowService: + if !bulkDeleteInferflowEnabled { + return nil, errors.New("inferflow bulk delete is disabled for this environment") + } return &InferflowService{ss.sqlConn, ss.prometheusClient, 
ss.slackClient, ss.InferflowEtcdClient}, nil - case "PREDATOR": + case predatorService: + if !bulkDeletePredatorEnabled { + return nil, errors.New("predator bulk delete is disabled for this environment") + } return &PredatorService{ss.sqlConn, ss.prometheusClient, ss.infrastructureHandler, ss.workingEnv, ss.slackClient, ss.gcsClient}, nil - case "NUMERIX": + case numerixService: + if !bulkDeleteNumerixEnabled { + return nil, errors.New("numerix bulk delete is disabled for this environment") + } return &NumerixService{ss.sqlConn, ss.prometheusClient, ss.slackClient, ss.NumerixEtcdClient}, nil default: log.Warn().Msg("Unknown service type: " + service) return nil, errors.New("unknown service type: " + service) } } + From 9754fcb1d771bf7304259df69a59460445a88943 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 15:15:14 +0530 Subject: [PATCH 10/24] schema client separation and fixes --- horizon/internal/externalcall/gcs_client.go | 8 ++-- .../sql/inferflow/config/repository.go | 8 ++-- horizon/pkg/configschemaclient/client.go | 40 +++++-------------- .../configschemaclient/internal_processor.go | 19 +++++++++ .../internal_processor_stub.go | 21 ++++++++++ horizon/pkg/configschemaclient/types.go | 24 +++++++---- 6 files changed, 76 insertions(+), 44 deletions(-) create mode 100644 horizon/pkg/configschemaclient/internal_processor.go create mode 100644 horizon/pkg/configschemaclient/internal_processor_stub.go diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go index 47ac4948..99ec2a71 100644 --- a/horizon/internal/externalcall/gcs_client.go +++ b/horizon/internal/externalcall/gcs_client.go @@ -748,10 +748,10 @@ func (g *GCSClient) FindFileWithSuffix(bucket, folderPath, suffix string) (bool, if err != nil { return false, "", fmt.Errorf("failed to list objects: %w", err) } - return foundFile != "", foundFile, nil - - log.Info().Msgf("No file found with suffix '%s' in %s/%s", suffix, bucket, 
folderPath) - return false, "", nil + if foundFile == "" { + return false, "", fmt.Errorf("no file found with suffix '%s' in %s/%s", suffix, bucket, folderPath) + } + return true, foundFile, nil } // ObjectVisitor is called for each object. Return an error to stop iteration. diff --git a/horizon/internal/repositories/sql/inferflow/config/repository.go b/horizon/internal/repositories/sql/inferflow/config/repository.go index de615420..d5f5588c 100644 --- a/horizon/internal/repositories/sql/inferflow/config/repository.go +++ b/horizon/internal/repositories/sql/inferflow/config/repository.go @@ -86,10 +86,11 @@ func (g *InferflowConfig) GetAll() ([]Table, error) { return configs, result.Error } -func (g *InferflowConfig) GetByID(configID string) (table *Table, err error) { +func (g *InferflowConfig) GetByID(configID string) (*Table, error) { + var table Table result := g.db.Where("config_id = ? and active = ?", configID, true). Order("updated_at DESC"). - First(table) + First(&table) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -97,7 +98,8 @@ func (g *InferflowConfig) GetByID(configID string) (table *Table, err error) { } return nil, result.Error } - return table, nil + + return &table, nil } func (g *InferflowConfig) DoesConfigIDExist(configID string) (bool, error) { diff --git a/horizon/pkg/configschemaclient/client.go b/horizon/pkg/configschemaclient/client.go index d372f616..cc2f054c 100644 --- a/horizon/pkg/configschemaclient/client.go +++ b/horizon/pkg/configschemaclient/client.go @@ -5,7 +5,7 @@ import ( ) // BuildFeatureSchema builds a feature schema from the component and response configs. 
-// It processes components in order: FS → RTP → Numerix Output → Predator Output → Numerix Input → Predator Input +// It processes components in order: FS → RTP → SeenScore → Numerix Output → Predator Output → Numerix Input → Predator Input func BuildFeatureSchema(componentConfig *ComponentConfig, responseConfig *ResponseConfig) []SchemaComponents { if componentConfig == nil { return nil @@ -36,19 +36,22 @@ func BuildFeatureSchema(componentConfig *ComponentConfig, responseConfig *Respon // 1. FS (Feature Store) addUniqueComponents(processFS(componentConfig.FeatureComponents)) - // 2. RTP (Real Time Pricing) - addUniqueComponents(processRTP(componentConfig.RTPComponents)) + // 2. RTP (Real Time Pricing) - Internal component, uses interface + addUniqueComponents(InternalSchemaProcessorInstance.ProcessRTP(componentConfig.RTPComponents)) - // 3. Numerix Output + // 3. SeenScore - Internal component, uses interface + addUniqueComponents(InternalSchemaProcessorInstance.ProcessSeenScore(componentConfig.SeenScoreComponents)) + + // 4. Numerix Output addUniqueComponents(processNumerixOutput(componentConfig.NumerixComponents)) - // 4. Predator Output + // 5. Predator Output addUniqueComponents(processPredatorOutput(componentConfig.PredatorComponents)) - // 5. Numerix Input (only add if not already present) + // 6. Numerix Input (only add if not already present) addOrUpdateComponents(processNumerixInput(componentConfig.NumerixComponents)) - // 6. Predator Input (only add if not already present) + // 7. 
Predator Input (only add if not already present) addOrUpdateComponents(processPredatorInput(componentConfig.PredatorComponents)) return response @@ -126,29 +129,6 @@ func processFS(featureComponents []FeatureComponent) []SchemaComponents { return response } -func processRTP(rtpComponents []RTPComponent) []SchemaComponents { - if len(rtpComponents) == 0 { - return nil - } - - var response []SchemaComponents - for _, rtpComponent := range rtpComponents { - if rtpComponent.FSRequest == nil { - continue - } - for _, featureGroup := range rtpComponent.FSRequest.FeatureGroups { - for _, feature := range featureGroup.Features { - response = append(response, SchemaComponents{ - FeatureName: getFeatureName(rtpComponent.ColNamePrefix, rtpComponent.FSRequest.Label, featureGroup.Label, feature), - FeatureType: featureGroup.DataType, - FeatureSize: 1, - }) - } - } - } - return response -} - func processPredatorOutput(predatorComponents []PredatorComponent) []SchemaComponents { if len(predatorComponents) == 0 { return nil diff --git a/horizon/pkg/configschemaclient/internal_processor.go b/horizon/pkg/configschemaclient/internal_processor.go new file mode 100644 index 00000000..8249b479 --- /dev/null +++ b/horizon/pkg/configschemaclient/internal_processor.go @@ -0,0 +1,19 @@ +package configschemaclient + +// InternalSchemaProcessor defines the interface for processing internal-only components +// (RTP, SeenScore) in the schema builder. +// +// For open-source builds (!meesho), a stub implementation returns empty results. +// For internal builds (meesho), the full implementation provides actual processing. 
+type InternalSchemaProcessor interface { + // ProcessRTP processes RTP components and returns schema components + ProcessRTP(rtpComponents []RTPComponent) []SchemaComponents + + // ProcessSeenScore processes SeenScore components and returns schema components + ProcessSeenScore(seenScoreComponents []SeenScoreComponent) []SchemaComponents +} + +// InternalSchemaProcessorInstance is the global instance of the internal schema processor. +// This is set by the init() function in either the stub or internal implementation file +// depending on build tags. +var InternalSchemaProcessorInstance InternalSchemaProcessor diff --git a/horizon/pkg/configschemaclient/internal_processor_stub.go b/horizon/pkg/configschemaclient/internal_processor_stub.go new file mode 100644 index 00000000..7675a697 --- /dev/null +++ b/horizon/pkg/configschemaclient/internal_processor_stub.go @@ -0,0 +1,21 @@ +//go:build !meesho + +package configschemaclient + +// internalSchemaProcessorStub is the stub implementation for open-source builds. +// It returns empty results for all internal component processing. 
+type internalSchemaProcessorStub struct{} + +func init() { + InternalSchemaProcessorInstance = &internalSchemaProcessorStub{} +} + +// ProcessRTP returns empty slice - no RTP processing in open-source builds +func (s *internalSchemaProcessorStub) ProcessRTP(rtpComponents []RTPComponent) []SchemaComponents { + return nil +} + +// ProcessSeenScore returns empty slice - no SeenScore processing in open-source builds +func (s *internalSchemaProcessorStub) ProcessSeenScore(seenScoreComponents []SeenScoreComponent) []SchemaComponents { + return nil +} diff --git a/horizon/pkg/configschemaclient/types.go b/horizon/pkg/configschemaclient/types.go index 6f63b7be..6be4fd43 100644 --- a/horizon/pkg/configschemaclient/types.go +++ b/horizon/pkg/configschemaclient/types.go @@ -9,13 +9,14 @@ type SchemaComponents struct { // ComponentConfig contains all component configurations type ComponentConfig struct { - CacheEnabled bool `json:"cache_enabled"` - CacheTTL int `json:"cache_ttl"` - CacheVersion int `json:"cache_version"` - FeatureComponents []FeatureComponent `json:"feature_components"` - RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` - PredatorComponents []PredatorComponent `json:"predator_components"` - NumerixComponents []NumerixComponent `json:"numerix_components"` + CacheEnabled bool `json:"cache_enabled"` + CacheTTL int `json:"cache_ttl"` + CacheVersion int `json:"cache_version"` + FeatureComponents []FeatureComponent `json:"feature_components"` + RTPComponents []RTPComponent `json:"real_time_pricing_feature_components,omitempty"` + SeenScoreComponents []SeenScoreComponent `json:"seen_score_components,omitempty"` + PredatorComponents []PredatorComponent `json:"predator_components"` + NumerixComponents []NumerixComponent `json:"numerix_components"` } // ResponseConfig contains response configuration @@ -62,6 +63,15 @@ type RTPComponent struct { CompCacheEnabled bool `json:"comp_cache_enabled"` } +// SeenScoreComponent represents a 
seen score component +type SeenScoreComponent struct { + Component string `json:"component"` + ComponentID string `json:"component_id"` + ColNamePrefix string `json:"col_name_prefix,omitempty"` + FSKeys []FSKey `json:"fs_keys"` + FSRequest *FSRequest `json:"fs_request"` +} + // PredatorComponent represents a Predator model component type PredatorComponent struct { Component string `json:"component"` From b845141eb076dc46cee2bd404116ef6c31fdb22d Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 15:26:44 +0530 Subject: [PATCH 11/24] removing redundant functions and error formatting --- .../externalcall/prometheus_client.go | 2 +- .../bulkdeletestrategy/predator_service.go | 42 +------------------ .../bulkdeletestrategy/strategy_selector.go | 3 -- horizon/internal/predator/handler/predator.go | 6 +-- 4 files changed, 5 insertions(+), 48 deletions(-) diff --git a/horizon/internal/externalcall/prometheus_client.go b/horizon/internal/externalcall/prometheus_client.go index a89a7c32..890ed3b3 100644 --- a/horizon/internal/externalcall/prometheus_client.go +++ b/horizon/internal/externalcall/prometheus_client.go @@ -198,7 +198,7 @@ func (p *prometheusClientImpl) GetPredatorModelTraffic(serviceName string, daysA if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("Prometheus call failed, status: %d, body: %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf("prometheus call failed, status: %d, body: %s", resp.StatusCode, string(bodyBytes)) } var pr PredatorModelResponse diff --git a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go index 7a9e582b..25dd921c 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go @@ -216,7 +216,7 @@ func (p *PredatorService) fetchModelNames( var metaData handler.MetaData if err := 
json.Unmarshal(pc.MetaData, &metaData); err != nil { - log.Err(err).Msg("Could not unmarshall model metadata for: " + pc.ModelName + " for scheduled deletion") + log.Err(err).Msg("could not unmarshall model metadata for: " + pc.ModelName + " for scheduled deletion") continue } @@ -425,46 +425,6 @@ func (p *PredatorService) deserializeDeployableConfig(serviceDeployable serviced return deployableConfig, nil } -func (p *PredatorService) processGCSAndDeleteModels(basePath string, inactiveModelNameList []string) []string { - srcBucket, srcPath := extractGCSPath(basePath) - destBucket, destPath := extractGCSPath(defaultModelPath) - - var deleteModelNameList []string - for _, inactiveModelName := range inactiveModelNameList { - err := p.gcsClient.TransferAndDeleteFolder(srcBucket, srcPath, inactiveModelName, destBucket, destPath, inactiveModelName) - if err != nil { - log.Error().Err(err).Msg("Error transferring and deleting folder in GCS") - continue - } - deleteModelNameList = append(deleteModelNameList, inactiveModelName) - } - - return deleteModelNameList -} - -func (p *PredatorService) deactivateModelsAndRestartDeployable(deleteModelNameList []string, serviceDeployable servicedeployableconfig.ServiceDeployableConfig, predatorConfig predatorconfig.PredatorConfigRepository, discoveryConfigId []int) error { - err := predatorConfig.BulkDeactivateByModelNames(deleteModelNameList, serviceDeployable.UpdatedBy, discoveryConfigId) - if err != nil { - log.Error().Err(err).Msg("Error deactivating models in predator config") - return err - } - - // Extract isCanary from deployable config - var deployableConfig map[string]interface{} - isCanary := false - if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err == nil { - if strategy, ok := deployableConfig["deploymentStrategy"].(string); ok && strategy == "canary" { - isCanary = true - } - } - if err := p.infrastructureHandler.RestartDeployment(serviceDeployable.Name, p.workingEnv, isCanary); err != nil { - 
log.Error().Err(err).Msg("Error restarting deployable") - return fmt.Errorf("failed to restart deployable: %w", err) - } - - return nil -} - func (p *PredatorService) sendSlackNotification(serviceDeployableName string, deleteModelNameList []string) error { err := p.slackClient.SendCleanupNotification(serviceDeployableName, deleteModelNameList) if err != nil { diff --git a/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go b/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go index 1a672991..8e0bcc5d 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go +++ b/horizon/internal/jobs/bulkdeletestrategy/strategy_selector.go @@ -30,7 +30,6 @@ type StrategySelectorImpl struct { var ( strategySelectorOnce sync.Once - defaultModelPath string bulkDeletePredatorEnabled bool bulkDeletePredatorMaxInactiveDays int bulkDeleteInferflowEnabled bool @@ -49,8 +48,6 @@ const ( func Init(config configs.Configs) StrategySelectorImpl { var strategySelectorImpl StrategySelectorImpl strategySelectorOnce.Do(func() { - defaultModelPath = config.DefaultModelPath - bulkDeletePredatorEnabled = config.BulkDeletePredatorEnabled bulkDeletePredatorMaxInactiveDays = config.BulkDeletePredatorMaxInactiveDays diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index a4af4a3e..93e3bcf2 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -282,12 +282,12 @@ func (p *Predator) HandleModelRequest(req ModelRequest, requestType string) (str for i, payload := range req.Payload { payloadBytes, err := json.Marshal(payload) if err != nil { - return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf(errMsgProcessPayload) + return constant.EmptyString, http.StatusInternalServerError, errors.New(errMsgProcessPayload) } var payloadObject Payload if err := json.Unmarshal(payloadBytes, &payloadObject); err != nil { - return constant.EmptyString, 
http.StatusInternalServerError, fmt.Errorf(errMsgProcessPayload) + return constant.EmptyString, http.StatusInternalServerError, errors.New(errMsgProcessPayload) } derivedModelName, err := p.GetDerivedModelName(payloadObject, requestType) if err != nil { @@ -3403,7 +3403,7 @@ func (p *Predator) HandleEditModel(req ModelRequest, createdBy string) (string, var payloadObject Payload if err := json.Unmarshal(payloadBytes, &payloadObject); err != nil { - return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf(errMsgProcessPayload) + return constant.EmptyString, http.StatusInternalServerError, errors.New(errMsgProcessPayload) } if payloadObject.MetaData.InstanceCount > 1 && p.isNonProductionEnvironment() { From 2199ed25621e0167ff3afc6b9efd2410d2eb4a26 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 15:30:08 +0530 Subject: [PATCH 12/24] minor bug fix --- horizon/internal/jobs/bulkdeletestrategy/predator_service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go index 25dd921c..c47ee920 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go @@ -220,7 +220,7 @@ func (p *PredatorService) fetchModelNames( continue } - if (&metaData.Ensembling) != nil && metaData.Ensembling.Step != nil { + if len(metaData.Ensembling.Step) > 0 { for _, step := range metaData.Ensembling.Step { if step.ModelName != "" { parentToChildMapping[pc.ModelName] = append(parentToChildMapping[pc.ModelName], step.ModelName) From 8a8157a4bc49576cba6d31e5dc933fb9cbef9895 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 3 Feb 2026 15:50:48 +0530 Subject: [PATCH 13/24] further refractoring for coderabbit changes --- .../internal/inferflow/handler/inferflow.go | 6 +++- .../inferflow/handler/schema_adapter.go | 33 +++++++++++++++---- 
.../bulkdeletestrategy/predator_service.go | 15 +++++++-- quick-start/db-init/scripts/init-mysql.sh | 4 +++ 4 files changed, 48 insertions(+), 10 deletions(-) diff --git a/horizon/internal/inferflow/handler/inferflow.go b/horizon/internal/inferflow/handler/inferflow.go index 6e9e183d..3f7ba534 100644 --- a/horizon/internal/inferflow/handler/inferflow.go +++ b/horizon/internal/inferflow/handler/inferflow.go @@ -724,7 +724,11 @@ func (m *InferFlow) rollbackApprovedRequest(request ReviewRequest, fullTable *in } func (m *InferFlow) rollbackPromoteRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { - if !configExistedBeforeTx { + if configExistedBeforeTx { + if err := m.rollbackEditRequest(tx, currentRequest, discoveryID); err != nil { + return err + } + } else { if err := m.rollbackCreatedConfigs(tx, currentRequest.ConfigID, discoveryID); err != nil { return err } diff --git a/horizon/internal/inferflow/handler/schema_adapter.go b/horizon/internal/inferflow/handler/schema_adapter.go index c0fbfa57..512de169 100644 --- a/horizon/internal/inferflow/handler/schema_adapter.go +++ b/horizon/internal/inferflow/handler/schema_adapter.go @@ -49,13 +49,14 @@ func toClientComponentConfig(config *inferflow.ComponentConfig) *configschemacli } return &configschemaclient.ComponentConfig{ - CacheEnabled: config.CacheEnabled, - CacheTTL: config.CacheTTL, - CacheVersion: config.CacheVersion, - FeatureComponents: toClientFeatureComponents(config.FeatureComponents), - RTPComponents: toClientRTPComponents(config.RTPComponents), - PredatorComponents: toClientPredatorComponents(config.PredatorComponents), - NumerixComponents: toClientNumerixComponents(config.NumerixComponents), + CacheEnabled: config.CacheEnabled, + CacheTTL: config.CacheTTL, + CacheVersion: config.CacheVersion, + FeatureComponents: toClientFeatureComponents(config.FeatureComponents), + RTPComponents: toClientRTPComponents(config.RTPComponents), + 
SeenScoreComponents: toClientSeenScoreComponents(config.SeenScoreComponents), + PredatorComponents: toClientPredatorComponents(config.PredatorComponents), + NumerixComponents: toClientNumerixComponents(config.NumerixComponents), } } @@ -135,6 +136,24 @@ func toClientRTPComponents(components []inferflow.RTPComponent) []configschemacl return result } +func toClientSeenScoreComponents(components []inferflow.SeenScoreComponent) []configschemaclient.SeenScoreComponent { + if len(components) == 0 { + return nil + } + + result := make([]configschemaclient.SeenScoreComponent, len(components)) + for i, c := range components { + result[i] = configschemaclient.SeenScoreComponent{ + Component: c.Component, + ComponentID: c.ComponentID, + ColNamePrefix: c.ColNamePrefix, + FSKeys: toClientFSKeys(c.FSKeys), + FSRequest: toClientFSRequest(c.FSRequest), + } + } + return result +} + func toClientPredatorComponents(components []inferflow.PredatorComponent) []configschemaclient.PredatorComponent { if len(components) == 0 { return nil diff --git a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go index c47ee920..47c236f6 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go @@ -92,7 +92,11 @@ func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployablec continue } - childConfig, _ := predatorBulkDeleteRepos.predatorConfigRepo.GetActiveModelByModelName(childName) + childConfig, err := predatorBulkDeleteRepos.predatorConfigRepo.GetActiveModelByModelName(childName) + if err != nil { + log.Warn().Err(err).Msgf("Failed to fetch child model config for %s", childName) + continue + } if childConfig != nil { modelsToDelete = append(modelsToDelete, ModelInfo{ ModelName: childName, @@ -120,7 +124,14 @@ func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployablec if len(deletedModels) > 0 { // Restart 
deployable after deletion - err = p.infrastructureHandler.RestartDeployment(serviceDeployable.Name, p.workingEnv, false) + isCanary := false + var deployableConfigMap map[string]interface{} + if err := json.Unmarshal(serviceDeployable.Config, &deployableConfigMap); err == nil { + if strategy, ok := deployableConfigMap["deploymentStrategy"].(string); ok && strategy == "canary" { + isCanary = true + } + } + err = p.infrastructureHandler.RestartDeployment(serviceDeployable.Name, p.workingEnv, isCanary) if err != nil { log.Error().Err(err).Msg("Error restarting deployable") } diff --git a/quick-start/db-init/scripts/init-mysql.sh b/quick-start/db-init/scripts/init-mysql.sh index bacd08d6..e20e335a 100644 --- a/quick-start/db-init/scripts/init-mysql.sh +++ b/quick-start/db-init/scripts/init-mysql.sh @@ -228,6 +228,7 @@ mysql -hmysql -uroot -proot --skip-ssl -e " created_at datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, test_results json, + source_config_id varchar(255) NULL, PRIMARY KEY (id), UNIQUE KEY config_id (config_id) ); @@ -308,6 +309,7 @@ mysql -hmysql -uroot -proot --skip-ssl -e " updated_at datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, test_results json, has_nil_data boolean DEFAULT false, + source_model_name varchar(255) NULL, PRIMARY KEY (id) ); @@ -387,6 +389,8 @@ mysql -hmysql -uroot -proot --skip-ssl -e " deployment_run_id varchar(255), deployable_health enum('DEPLOYMENT_REASON_ARGO_APP_HEALTH_DEGRADED', 'DEPLOYMENT_REASON_ARGO_APP_HEALTHY'), work_flow_status enum('WORKFLOW_COMPLETED','WORKFLOW_NOT_FOUND','WORKFLOW_RUNNING','WORKFLOW_FAILED','WORKFLOW_NOT_STARTED'), + override_testing TINYINT(1) DEFAULT 0, + deployable_tag varchar(255) NULL, PRIMARY KEY (id), UNIQUE KEY host (host) ); From 340983e1e3b7e134dfec75065bbeb9f27a34180c Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Wed, 4 Feb 2026 15:25:24 +0530 Subject: [PATCH 14/24] model name extraction fix 
--- horizon/internal/predator/handler/predator.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index 93e3bcf2..707c810f 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -164,9 +164,8 @@ const ( failedToParseServiceConfig = "Failed to parse service config" failedToCreateServiceDiscoveryAndConfig = "Failed to create service discovery and config" predatorInferMethod = "inference.GRPCInferenceService/ModelInfer" - deployableTagDelimiter = "_" - scaleupTag = "scaleup" - + deployableTagDelimiter = "_" + scaleupTag = "scaleup" ) func InitV1ConfigHandler() (Config, error) { @@ -555,10 +554,9 @@ func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsRes intBucket, intObjectPath := parseModelPath(req.ModelPath) metaDataPath := path.Join(intObjectPath, "metadata.json") - _, modelName := parseModelPath(intObjectPath) + modelName := path.Base(intObjectPath) intConfigPath := path.Join(intObjectPath, configFile) - // Read config.pbtxt var configData []byte var err error @@ -1249,8 +1247,7 @@ func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, p log.Info().Msgf("Copying to target deployable - src: %s/%s/%s, dest: %s/%s/%s", srcBucket, srcPath, srcModelName, destBucket, destPath, destModelName) - - if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || + if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || srcModelName == constant.EmptyString || destBucket == constant.EmptyString || destPath == constant.EmptyString || destModelName == constant.EmptyString { log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) @@ -4140,4 +4137,4 @@ func (p *Predator) replaceModelNameInConfigPreservingFormat(data []byte, destMod } return []byte(strings.Join(lines, "\n")) -} \ No newline at end of 
file +} From 8ecb0eaac72d50c29fd42a593c4ea7aa606d74f7 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Wed, 4 Feb 2026 15:44:50 +0530 Subject: [PATCH 15/24] return error on no model files found --- horizon/internal/externalcall/gcs_client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go index 99ec2a71..f697a57a 100644 --- a/horizon/internal/externalcall/gcs_client.go +++ b/horizon/internal/externalcall/gcs_client.go @@ -306,8 +306,7 @@ func (g *GCSClient) TransferFolderWithSplitSources(modelBucket, modelPath, confi len(configFiles), configBucket, configPrefix) if len(regularFiles) == 0 && len(configFiles) == 0 { - log.Warn().Msg("TransferFolderWithSplitSources: No objects found to transfer") - return nil + return fmt.Errorf("TransferFolderWithSplitSources: No objects found to transfer") } regularFilesTransferred := false From 649e89d61772ae4c7bcd15dedfcc803960cc409d Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Wed, 4 Feb 2026 15:46:05 +0530 Subject: [PATCH 16/24] capitilization --- horizon/internal/externalcall/gcs_client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go index f697a57a..677c4f9b 100644 --- a/horizon/internal/externalcall/gcs_client.go +++ b/horizon/internal/externalcall/gcs_client.go @@ -306,7 +306,7 @@ func (g *GCSClient) TransferFolderWithSplitSources(modelBucket, modelPath, confi len(configFiles), configBucket, configPrefix) if len(regularFiles) == 0 && len(configFiles) == 0 { - return fmt.Errorf("TransferFolderWithSplitSources: No objects found to transfer") + return fmt.Errorf("transferFolderWithSplitSources: No objects found to transfer") } regularFilesTransferred := false From 4ec6b12fc21772d78ee692ff7da59248fc5f6244 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Thu, 5 Feb 2026 13:39:41 +0530 
Subject: [PATCH 17/24] schema client refractor --- horizon/pkg/configschemaclient/client.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/horizon/pkg/configschemaclient/client.go b/horizon/pkg/configschemaclient/client.go index cc2f054c..53a11a53 100644 --- a/horizon/pkg/configschemaclient/client.go +++ b/horizon/pkg/configschemaclient/client.go @@ -4,6 +4,11 @@ import ( "strings" ) +const ( + DataTypeString = "String" + DefaultFeatureSize = 1 +) + // BuildFeatureSchema builds a feature schema from the component and response configs. // It processes components in order: FS → RTP → SeenScore → Numerix Output → Predator Output → Numerix Input → Predator Input func BuildFeatureSchema(componentConfig *ComponentConfig, responseConfig *ResponseConfig) []SchemaComponents { @@ -26,7 +31,7 @@ func BuildFeatureSchema(componentConfig *ComponentConfig, responseConfig *Respon addOrUpdateComponents := func(components []SchemaComponents) { for _, component := range components { if !existingFeatures[component.FeatureName] { - component.FeatureType = "String" + component.FeatureType = DataTypeString response = append(response, component) existingFeatures[component.FeatureName] = true } @@ -175,15 +180,15 @@ func processPredatorInput(predatorComponents []PredatorComponent) []SchemaCompon return response } -func getPredatorFeatureTypeAndSize(dataType string, shape []int) (int, string) { - if len(shape) == 1 && shape[0] == 1 { +func getPredatorFeatureTypeAndSize(dataType string, dims []int) (int, string) { + if len(dims) == 1 && dims[0] == 1 { return 1, dataType } - if len(shape) == 2 && shape[0] == -1 { - return shape[1], dataType + "Vector" + if len(dims) == 2 && dims[0] == -1 { + return dims[1], dataType + "Vector" } - if len(shape) > 0 { - return shape[0], dataType + "Vector" + if len(dims) > 0 { + return dims[0], dataType + "Vector" } return 1, dataType } @@ -212,8 +217,8 @@ func ProcessResponseConfig(responseConfig *ResponseConfig, 
schemaComponents []Sc } else { response = append(response, SchemaComponents{ FeatureName: feature, - FeatureType: "String", - FeatureSize: 1, + FeatureType: DataTypeString, + FeatureSize: DefaultFeatureSize, }) } } From 1303975dc36c1aaed991501b29de9390c3885c61 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Thu, 5 Feb 2026 17:00:52 +0530 Subject: [PATCH 18/24] int to preprod name convention --- horizon/internal/predator/handler/predator.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index 707c810f..824983c4 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -552,16 +552,16 @@ func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsRes return ModelParamsResponse{}, http.StatusBadRequest, err } - intBucket, intObjectPath := parseModelPath(req.ModelPath) - metaDataPath := path.Join(intObjectPath, "metadata.json") - modelName := path.Base(intObjectPath) - intConfigPath := path.Join(intObjectPath, configFile) + preProdBucket, preProdObjectPath := parseModelPath(req.ModelPath) + metaDataPath := path.Join(preProdObjectPath, "metadata.json") + modelName := path.Base(preProdObjectPath) + preProdConfigPath := path.Join(preProdObjectPath, configFile) // Read config.pbtxt var configData []byte var err error if p.isNonProductionEnvironment() { - configData, err = p.GcsClient.ReadFile(intBucket, intConfigPath) + configData, err = p.GcsClient.ReadFile(preProdBucket, preProdConfigPath) } else { prodConfigPath := path.Join(pred.GcsConfigBasePath, modelName, configFile) configData, err = p.GcsClient.ReadFile(pred.GcsConfigBucket, prodConfigPath) @@ -571,7 +571,7 @@ func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsRes } // Read feature_meta.json - metaData, err := p.GcsClient.ReadFile(intBucket, metaDataPath) + metaData, err := 
p.GcsClient.ReadFile(preProdBucket, metaDataPath) var featureMeta *FeatureMetadata if err == nil && metaData != nil { if err := json.Unmarshal(metaData, &featureMeta); err != nil { @@ -607,7 +607,7 @@ func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsRes outputs = []IO{} } - return createModelParamsResponse(&modelConfig, intObjectPath, inputs, outputs), http.StatusOK, nil + return createModelParamsResponse(&modelConfig, preProdObjectPath, inputs, outputs), http.StatusOK, nil } func validateModelPath(modelPath string) error { From 7135c9249d9259f29f5a57c56b5bb5e529b1eae3 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Mon, 9 Feb 2026 10:40:23 +0530 Subject: [PATCH 19/24] etcd name fixes and refractors --- horizon/internal/inferflow/etcd/etcd.go | 11 ++++------- horizon/internal/inferflow/etcd/models.go | 2 +- .../bulkdeletestrategy/predator_service.go | 11 +++++------ .../repositories/sql/predatorconfig/sql.go | 19 +++++++++---------- 4 files changed, 19 insertions(+), 24 deletions(-) diff --git a/horizon/internal/inferflow/etcd/etcd.go b/horizon/internal/inferflow/etcd/etcd.go index 026f7c6c..e3768fc4 100644 --- a/horizon/internal/inferflow/etcd/etcd.go +++ b/horizon/internal/inferflow/etcd/etcd.go @@ -94,7 +94,7 @@ func (e *Etcd) GetConfiguredEndpoints(serviceDeployableName string) mapset.Set[s return validEndpoints } - inferflowConfig, exists := instance.InferflowConfig[serviceDeployableName] + inferflowConfig, exists := instance.Services[serviceDeployableName] if !exists { log.Warn().Msgf("service '%s' not found in etcd registry", serviceDeployableName) return validEndpoints @@ -105,13 +105,10 @@ func (e *Etcd) GetConfiguredEndpoints(serviceDeployableName string) mapset.Set[s return validEndpoints } - endpoints := strings.Split(predatorHosts, commaDelimiter) - for i := range len(endpoints) { - cleanedEndpoint := strings.TrimSpace(endpoints[i]) - if cleanedEndpoint == "" { - continue + for _, endpoint := range 
strings.Split(predatorHosts, commaDelimiter) { + if cleanedEndpoint := strings.TrimSpace(endpoint); cleanedEndpoint != "" { + validEndpoints.Add(cleanedEndpoint) } - validEndpoints.Add(cleanedEndpoint) } return validEndpoints } diff --git a/horizon/internal/inferflow/etcd/models.go b/horizon/internal/inferflow/etcd/models.go index 0c2c5e50..46ab2606 100644 --- a/horizon/internal/inferflow/etcd/models.go +++ b/horizon/internal/inferflow/etcd/models.go @@ -1,7 +1,7 @@ package etcd type ModelConfigRegistery struct { - InferflowConfig map[string]InferflowConfigs `json:"services"` + Services map[string]InferflowConfigs `json:"services"` } type HorizonRegistry struct { diff --git a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go index 47c236f6..8176bbf6 100644 --- a/horizon/internal/jobs/bulkdeletestrategy/predator_service.go +++ b/horizon/internal/jobs/bulkdeletestrategy/predator_service.go @@ -29,9 +29,9 @@ type PredatorService struct { } const ( - slashConstant = "/" - gcsPrefix = "gs://" - bulkDeleteCreatedBy = "horizon-bulk-delete" + slashConstant = "/" + gcsPrefix = "gs://" + bulkDeleteCreatedBy = "horizon-bulk-delete" ) type ModelInfo struct { @@ -46,7 +46,6 @@ type PredatorBulkDeleteRepos struct { groupCounterRepo counter.GroupIdCounterRepository } - func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployableconfig.ServiceDeployableConfig) error { predatorBulkDeleteRepos, err := p.initializeRepositories() if err != nil { @@ -96,7 +95,7 @@ func (p *PredatorService) ProcessBulkDelete(serviceDeployable servicedeployablec if err != nil { log.Warn().Err(err).Msgf("Failed to fetch child model config for %s", childName) continue - } + } if childConfig != nil { modelsToDelete = append(modelsToDelete, ModelInfo{ ModelName: childName, @@ -201,7 +200,7 @@ func (p *PredatorService) fetchModelNames( } } - predatorConfigList, err := 
predatorBulkDeleteRepos.predatorConfigRepo.FindByDiscoveryIDsAndAge(discoveryConfigIds, zeroTrafficDays) + predatorConfigList, err := predatorBulkDeleteRepos.predatorConfigRepo.FindByDiscoveryConfigIdsAndAge(discoveryConfigIds, zeroTrafficDays) if err != nil { log.Err(err).Msg("Error fetching predator configs from DB") return nil, nil, nil, nil, err diff --git a/horizon/internal/repositories/sql/predatorconfig/sql.go b/horizon/internal/repositories/sql/predatorconfig/sql.go index 7a421639..5832759c 100644 --- a/horizon/internal/repositories/sql/predatorconfig/sql.go +++ b/horizon/internal/repositories/sql/predatorconfig/sql.go @@ -23,7 +23,7 @@ type PredatorConfigRepository interface { GetByModelName(modelName string) (*PredatorConfig, error) GetActiveModelByModelName(modelName string) (*PredatorConfig, error) GetActiveModelByModelNameList(modelNames []string) ([]PredatorConfig, error) - FindByDiscoveryIDsAndAge(discoveryConfigIds []int, daysAgo int) ([]PredatorConfig, error) + FindByDiscoveryConfigIdsAndAge(discoveryConfigIds []int, daysAgo int) ([]PredatorConfig, error) } type predatorConfigRepo struct { @@ -71,12 +71,12 @@ func (r *predatorConfigRepo) Update(config *PredatorConfig) error { return r.db.Model(config). Where("id = ?", config.ID). 
UpdateColumns(map[string]interface{}{ - "active": config.Active, - "meta_data": config.MetaData, - "updated_by": config.UpdatedBy, - "updated_at": config.UpdatedAt, - "test_results": config.TestResults, - "has_nil_data": config.HasNilData, + "active": config.Active, + "meta_data": config.MetaData, + "updated_by": config.UpdatedBy, + "updated_at": config.UpdatedAt, + "test_results": config.TestResults, + "has_nil_data": config.HasNilData, "source_model_name": config.SourceModelName, }).Error } @@ -152,8 +152,8 @@ func (r *predatorConfigRepo) GetByModelName(modelName string) (*PredatorConfig, return &config, err } -// FindByDiscoveryIDsAndAge returns active predator configs for given discovery IDs created before (now - daysAgo). -func (r *predatorConfigRepo) FindByDiscoveryIDsAndAge(discoveryConfigIds []int, daysAgo int) ([]PredatorConfig, error) { +// FindByDiscoveryConfigIdsAndAge returns active predator configs for given discovery IDs created before (now - daysAgo). +func (r *predatorConfigRepo) FindByDiscoveryConfigIdsAndAge(discoveryConfigIds []int, daysAgo int) ([]PredatorConfig, error) { var configs []PredatorConfig if daysAgo < 0 { return nil, errors.New("daysAgo must be >= 0") @@ -166,4 +166,3 @@ func (r *predatorConfigRepo) FindByDiscoveryIDsAndAge(discoveryConfigIds []int, return configs, err } - From a07f6f20938d190334f9fbd3336cfa6ae7f6f3a1 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Mon, 9 Feb 2026 19:20:20 +0530 Subject: [PATCH 20/24] refractors and dev toggle script fix --- dev-toggle-go.sh | 19 +- horizon/internal/externalcall/gcs_client.go | 41 ++- .../internal/externalcall/gcs_client_test.go | 127 +++++++ .../inferflow/handler/config_builder_test.go | 318 ++++++++++++++++++ .../inferflow/handler/inferflow_test.go | 153 +++++++++ horizon/internal/predator/handler/predator.go | 5 +- .../predator/handler/predator_test.go | 190 +++++++++++ 7 files changed, 829 insertions(+), 24 deletions(-) create mode 100644 
horizon/internal/externalcall/gcs_client_test.go create mode 100644 horizon/internal/inferflow/handler/config_builder_test.go create mode 100644 horizon/internal/inferflow/handler/inferflow_test.go create mode 100644 horizon/internal/predator/handler/predator_test.go diff --git a/dev-toggle-go.sh b/dev-toggle-go.sh index 621120c5..2ba9ecf9 100755 --- a/dev-toggle-go.sh +++ b/dev-toggle-go.sh @@ -459,13 +459,16 @@ enable_dev_mode() { echo "# Folder: $FOLDER_NAME" >> "$STATE_FILE" log_debug "Created state file: $STATE_FILE" - # Find and copy all .go files from internal repo - log_info "Step 4: Searching for .go files to copy..." + # Find and copy all .go files and config files from internal repo + # Copy: .go sources and common config types (.yaml, .yml, .json, .env, .pbtxt) + log_info "Step 4: Searching for .go and config files to copy..." log_debug "Scanning directory: $INTERNAL_FOLDER_DIR" local copied_count=0 - local file_count=$(find "$INTERNAL_FOLDER_DIR" -name "*.go" -type f | wc -l) - log_info "Found $file_count .go file(s) to copy" + local find_patterns=(-name "*.go" -o -name "*.yaml" -o -name "*.yml" -o -name "*.json" -o -name "*.env" -o -name "*.pbtxt") + local file_count + file_count=$(find "$INTERNAL_FOLDER_DIR" -type f \( "${find_patterns[@]}" \) | wc -l) + log_info "Found $file_count file(s) to copy (.go and configs)" while IFS= read -r -d '' src_file; do # Get relative path from INTERNAL_FOLDER_DIR @@ -492,7 +495,7 @@ enable_dev_mode() { # Record in state file echo "FILE:$rel_path" >> "$STATE_FILE" ((copied_count++)) - done < <(find "$INTERNAL_FOLDER_DIR" -name "*.go" -type f -print0) + done < <(find "$INTERNAL_FOLDER_DIR" -type f \( "${find_patterns[@]}" \) -print0) log_info "Successfully copied $copied_count file(s)" @@ -582,6 +585,12 @@ enable_dev_mode() { echo " Files copied: $copied_count" echo " go.mod updated: $([ -f "$GO_MOD_APPEND_FILE" ] && echo "YES" || echo "NO")" echo " State file: $STATE_FILE" + if [ "$FOLDER_NAME" = "horizon" ]; then + 
echo "" + echo "To run tests including internal (meesho) config tests, use:" + echo " cd $TARGET_DIR && go test -tags=meesho ./..." + echo "Without -tags=meesho, only the standard tests run (internal test files are skipped)." + fi } disable_dev_mode() { diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go index 677c4f9b..cec8a89b 100644 --- a/horizon/internal/externalcall/gcs_client.go +++ b/horizon/internal/externalcall/gcs_client.go @@ -111,6 +111,15 @@ func CreateGCSClient() GCSClientInterface { } } +// ObjectVisitor is called for each object. Return an error to stop iteration. +// Return a special sentinel error like ErrStopIteration to stop without error. +type ObjectVisitor func(attrs *storage.ObjectAttrs) error + +var ErrStopIteration = errors.New("stop iteration") + +// ObjectFilter returns true if the object should be included. +type ObjectFilter func(attrs *storage.ObjectAttrs) bool + func (g *GCSClient) ReadFile(bucket, objectPath string) ([]byte, error) { rc, err := g.client.Bucket(bucket).Object(objectPath).NewReader(g.ctx) if err != nil { @@ -343,8 +352,7 @@ func (g *GCSClient) transferRegularFilesFromSource(files []storage.ObjectAttrs, semaphore := make(chan struct{}, maxConcurrentFiles) var wg sync.WaitGroup - var mu sync.Mutex - var transferErrors []error + errChan := make(chan error, len(files)) for _, objAttrs := range files { wg.Add(1) @@ -354,17 +362,27 @@ func (g *GCSClient) transferRegularFilesFromSource(files []storage.ObjectAttrs, defer func() { <-semaphore }() if err := g.transferSingleRegularFile(obj, srcBucket, destBucket, destPath, destModelName, prefix); err != nil { - mu.Lock() - transferErrors = append(transferErrors, fmt.Errorf("failed to transfer %s: %w", obj.Name, err)) - mu.Unlock() + errChan <- fmt.Errorf("failed to transfer %s: %w", obj.Name, err) } }(objAttrs) } wg.Wait() - if len(transferErrors) > 0 { - return fmt.Errorf("regular file transfer completed with %d errors: %v", 
len(transferErrors), transferErrors[0]) + if errCount := len(errChan); errCount > 0 { + errs := make([]error, 0, errCount) + for i := 0; i < errCount; i++ { + errs = append(errs, <-errChan) + } + var b strings.Builder + b.WriteString(fmt.Sprintf("regular file transfer completed with %d errors:\n", len(errs))) + for i, e := range errs { + if i > 0 { + b.WriteString("\n") + } + b.WriteString(e.Error()) + } + return fmt.Errorf("%s", b.String()) } return nil @@ -753,12 +771,6 @@ func (g *GCSClient) FindFileWithSuffix(bucket, folderPath, suffix string) (bool, return true, foundFile, nil } -// ObjectVisitor is called for each object. Return an error to stop iteration. -// Return a special sentinel error like ErrStopIteration to stop without error. -type ObjectVisitor func(attrs *storage.ObjectAttrs) error - -var ErrStopIteration = errors.New("stop iteration") - // forEachObject iterates over all objects with the given prefix and calls the visitor for each. func (g *GCSClient) forEachObject(bucket, prefix string, visitor ObjectVisitor) error { it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{Prefix: prefix}) @@ -781,9 +793,6 @@ func (g *GCSClient) forEachObject(bucket, prefix string, visitor ObjectVisitor) return nil } -// ObjectFilter returns true if the object should be included. -type ObjectFilter func(attrs *storage.ObjectAttrs) bool - // listObjects returns all objects matching the prefix, optionally filtered. // Pass nil for filter to include all objects (except directory markers). 
func (g *GCSClient) listObjects(bucket, prefix string, filter ObjectFilter) ([]storage.ObjectAttrs, error) { diff --git a/horizon/internal/externalcall/gcs_client_test.go b/horizon/internal/externalcall/gcs_client_test.go new file mode 100644 index 00000000..f92ed913 --- /dev/null +++ b/horizon/internal/externalcall/gcs_client_test.go @@ -0,0 +1,127 @@ +package externalcall + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReplaceModelNameInConfig(t *testing.T) { + tests := []struct { + name string + data []byte + destModelName string + expectContains string + }{ + { + name: "replaces top-level name only", + data: []byte(`name: "old_model" +instance_group { + name: "old_model" +} +`), + destModelName: "new_model", + expectContains: `name: "new_model"`, + }, + { + name: "preserves nested name with indentation", + data: []byte(`name: "top_level" + instance_group { + name: "nested_name" + } +`), + destModelName: "replaced", + expectContains: `name: "replaced"`, + }, + { + name: "single line config", + data: []byte(`name: "single_model"` + "\n"), + destModelName: "replaced_model", + expectContains: `name: "replaced_model"`, + }, + { + name: "no name field returns unchanged", + data: []byte(`platform: "tensorflow" +version: 1 +`), + destModelName: "any", + expectContains: `platform: "tensorflow"`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := replaceModelNameInConfig(tt.data, tt.destModelName) + assert.Contains(t, string(got), tt.expectContains) + }) + } +} + +func TestErrStopIteration(t *testing.T) { + assert.Error(t, ErrStopIteration) +} + +func TestGCSClient_NilClient_ListFolders(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + folders, err := g.ListFolders("bucket", "prefix/") + require.Error(t, err) + assert.Nil(t, folders) + assert.Contains(t, err.Error(), "not initialized") +} + +func 
TestGCSClient_NilClient_UploadFile(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + err := g.UploadFile("bucket", "path/obj", []byte("data")) + require.Error(t, err) + assert.Contains(t, err.Error(), "not initialized") +} + +func TestGCSClient_NilClient_CheckFileExists(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + exists, err := g.CheckFileExists("bucket", "path/obj") + require.Error(t, err) + assert.False(t, exists) + assert.Contains(t, err.Error(), "not initialized") +} + +func TestGCSClient_NilClient_CheckFolderExists(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + exists, err := g.CheckFolderExists("bucket", "folder/") + require.Error(t, err) + assert.False(t, exists) + assert.Contains(t, err.Error(), "not initialized") +} + +func TestGCSClient_NilClient_GetFolderInfo(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + info, err := g.GetFolderInfo("bucket", "folder/") + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "not initialized") +} + +func TestGCSClient_NilClient_ListFoldersWithTimestamp(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + folders, err := g.ListFoldersWithTimestamp("bucket", "prefix/") + require.Error(t, err) + assert.Nil(t, folders) + assert.Contains(t, err.Error(), "not initialized") +} + +func TestGCSClient_NilClient_FindFileWithSuffix(t *testing.T) { + g := &GCSClient{client: nil, ctx: context.Background()} + exists, name, err := g.FindFileWithSuffix("bucket", "folder/", ".pbtxt") + require.Error(t, err) + assert.False(t, exists) + assert.Empty(t, name) + assert.Contains(t, err.Error(), "not initialized") +} + +func TestGCSFolderInfo_ZeroValue(t *testing.T) { + var info GCSFolderInfo + assert.Empty(t, info.Name) + assert.Empty(t, info.Path) + assert.Zero(t, info.FileCount) + assert.Zero(t, info.Size) +} diff --git a/horizon/internal/inferflow/handler/config_builder_test.go 
b/horizon/internal/inferflow/handler/config_builder_test.go new file mode 100644 index 00000000..9ddaa8d1 --- /dev/null +++ b/horizon/internal/inferflow/handler/config_builder_test.go @@ -0,0 +1,318 @@ +package handler + +import ( + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractEntityIDs(t *testing.T) { + request := InferflowOnboardRequest{ + Payload: OnboardPayload{ + Rankers: []Ranker{ + {EntityID: []string{"user", "item"}}, + {EntityID: []string{"user"}}, + }, + ReRankers: []ReRanker{ + {EntityID: []string{"session"}}, + }, + }, + } + got := extractEntityIDs(request) + // extractEntityIDs sets each entityID to false; keys indicate which entity IDs were seen + assert.Contains(t, got, "user:item") + assert.Contains(t, got, "user") + assert.Contains(t, got, "session") + assert.Len(t, got, 3) + assert.False(t, got["user:item"]) + assert.False(t, got["user"]) + assert.False(t, got["session"]) +} + +func TestExtractEntityIDs_Empty(t *testing.T) { + request := InferflowOnboardRequest{Payload: OnboardPayload{}} + got := extractEntityIDs(request) + assert.Empty(t, got) +} + +func TestTransformFeature(t *testing.T) { + tests := []struct { + name string + feature string + wantFeature string + wantType string + wantErr bool + }{ + {"invalid - single part", "onlyone", "", featureClassInvalid, true}, + {"default feature", "DEFAULT|foo", "foo", featureClassDefault, false}, + {"model feature", "MODEL|bar", "bar", featureClassModel, false}, + {"online feature", "ONLINE|baz", "baz", featureClassOnline, false}, + {"offline feature", "OFFLINE|qux", "qux", featureClassOffline, false}, + {"pctr_calibration", "PCTR_CALIBRATION|pctr", "pctr_calibration:pctr", featureClassPCTRCalibration, false}, + {"pcvr_calibration", "PCVR_CALIBRATION|pcvr", "pcvr_calibration:pcvr", featureClassPCVRCalibration, false}, + {"parent default", "PARENT_DEFAULT_FEATURE|pf", "parent:pf", 
featureClassDefault, false}, + {"parent online", "PARENT_ONLINE_FEATURE|po", "parent:po", featureClassOnline, false}, + {"fallback default", "UNKNOWN|x", "x", featureClassDefault, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotFeature, gotType, err := transformFeature(tt.feature) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantFeature, gotFeature) + assert.Equal(t, tt.wantType, gotType) + }) + } +} + +func TestAddFeatureToSet(t *testing.T) { + defaultFeatures := mapset.NewSet[string]() + modelFeatures := mapset.NewSet[string]() + onlineFeatures := mapset.NewSet[string]() + offlineFeatures := mapset.NewSet[string]() + pctrCalibrationFeatures := mapset.NewSet[string]() + pcvrCalibrationFeatures := mapset.NewSet[string]() + + err := AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, "f1", featureClassDefault) + require.NoError(t, err) + assert.True(t, defaultFeatures.Contains("f1")) + + err = AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, "f2", featureClassOnline) + require.NoError(t, err) + assert.True(t, onlineFeatures.Contains("f2")) + + // duplicate in same set is allowed (Add is idempotent); adding to different set with same feature name should error + err = AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, "f1", featureClassModel) + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + + // invalid feature type + err = AddFeatureToSet(&defaultFeatures, &modelFeatures, &onlineFeatures, &offlineFeatures, &pctrCalibrationFeatures, &pcvrCalibrationFeatures, "x", "invalid_type") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid feature type") +} + +func TestGetComponentList(t 
*testing.T) { + features := mapset.NewSet[string]() + features.Add("parent_label:group1:feat1") + features.Add("online:group2:feat2") + pctr := mapset.NewSet[string]() + pctr.Add("pctr_calibration:label1:g1") + pcvr := mapset.NewSet[string]() + pcvr.Add("pcvr_calibration:label2:g2") + + got := getComponentList(features, pctr, pcvr) + assert.True(t, got.Contains("parent_label")) + assert.True(t, got.Contains("online")) + assert.True(t, got.Contains("pctr_calibration_label1")) + assert.True(t, got.Contains("pcvr_calibration_label2")) +} + +func TestGetComponentList_Empty(t *testing.T) { + got := getComponentList(mapset.NewSet[string](), nil, nil) + assert.True(t, got.IsEmpty()) +} + +func TestGetResponseConfigs(t *testing.T) { + request := &InferflowOnboardRequest{ + Payload: OnboardPayload{ + Response: ResponseConfig{ + PrismLoggingPerc: 10, + RankerSchemaFeaturesInResponsePerc: 20, + ResponseFeatures: []string{"f1"}, + LogSelectiveFeatures: true, + LogBatchSize: 100, + LoggingTTL: 30, + }, + }, + } + got, err := GetResponseConfigs(request) + require.NoError(t, err) + assert.Equal(t, 10, got.LoggingPerc) + assert.Equal(t, 20, got.ModelSchemaPerc) + assert.Equal(t, []string{"f1"}, got.Features) + assert.True(t, got.LogSelectiveFeatures) + assert.Equal(t, 100, got.LogBatchSize) + assert.Equal(t, 30, got.LoggingTTL) +} + +func TestGetComponentConfig(t *testing.T) { + featureComponents := []FeatureComponent{{Component: "fc1"}} + rtpComponents := []RTPComponent{} + seenScoreComponents := []SeenScoreComponent{} + numerixComponents := []NumerixComponent{{Component: "i1"}} + predatorComponents := []PredatorComponent{{Component: "p1"}} + + got, err := GetComponentConfig(featureComponents, rtpComponents, seenScoreComponents, numerixComponents, predatorComponents) + require.NoError(t, err) + require.NotNil(t, got) + assert.True(t, got.CacheEnabled) + assert.Equal(t, 300, got.CacheTTL) + assert.Equal(t, 1, got.CacheVersion) + assert.Len(t, got.FeatureComponents, 1) + 
assert.Len(t, got.PredatorComponents, 1) + assert.Len(t, got.NumerixComponents, 1) +} + +func TestGetPredatorComponents_Simple(t *testing.T) { + request := InferflowOnboardRequest{ + Payload: OnboardPayload{ + Rankers: []Ranker{ + { + ModelName: "m1", + EndPoint: "ep1", + EntityID: []string{"user"}, + Inputs: []Input{{Name: "in1", Features: []string{"ONLINE|f1"}}}, + Outputs: []Output{{Name: "out1", ModelScores: []string{"score1"}, ModelScoresDims: [][]int{{1}}, DataType: "Float"}}, + }, + }, + }, + } + offlineToOnlineMapping := map[string]string{} + + got, err := GetPredatorComponents(request, offlineToOnlineMapping) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, "p1", got[0].Component) + assert.Equal(t, "user", got[0].ComponentID) + assert.Equal(t, "m1", got[0].ModelName) + assert.Equal(t, "ep1", got[0].ModelEndPoint) + assert.Len(t, got[0].Inputs, 1) + assert.Len(t, got[0].Outputs, 1) +} + +func TestGetPredatorComponents_RoutingConfig_Invalid(t *testing.T) { + request := InferflowOnboardRequest{ + Payload: OnboardPayload{ + Rankers: []Ranker{ + { + ModelName: "m1", + EndPoint: "ep1", + EntityID: []string{"user"}, + RoutingConfig: []RoutingConfig{{ModelName: "", ModelEndpoint: "ep"}}, + }, + }, + }, + } + _, err := GetPredatorComponents(request, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "routing config") +} + +func TestGetFeatureLabelToPrefixToFeatureGroupToFeatureMap(t *testing.T) { + features := []string{ + "label1:group1:feat1", + "label1:group1:feat2", + "parent_label:group2:feat3", + } + got := GetFeatureLabelToPrefixToFeatureGroupToFeatureMap(features) + require.Contains(t, got, "label1") + require.Contains(t, got["label1"], "") + assert.True(t, got["label1"][""]["group1"].Contains("feat1")) + assert.True(t, got["label1"][""]["group1"].Contains("feat2")) + require.Contains(t, got, "parent_label") + assert.True(t, got["parent_label"][""]["group2"].Contains("feat3")) +} + +func 
TestGetFeatureLabelToPrefixToFeatureGroupToFeatureMap_Empty(t *testing.T) { + got := GetFeatureLabelToPrefixToFeatureGroupToFeatureMap(nil) + assert.Empty(t, got) + got = GetFeatureLabelToPrefixToFeatureGroupToFeatureMap([]string{}) + assert.Empty(t, got) +} + +func TestGetFeatureLabelToPrefixToFeatureGroupToFeatureMap_SkipsInvalidParts(t *testing.T) { + features := []string{ + "onlytwo", + "a:b:c:d", // 4 parts -> prefix, label, group, feature + } + got := GetFeatureLabelToPrefixToFeatureGroupToFeatureMap(features) + // "onlytwo" has 1 part after split by :, skipped (len != 3 && != 4) + assert.Contains(t, got, "b") // 4 parts: prefix=a, label=b, group=c, feature=d + assert.True(t, got["b"]["a"]["c"].Contains("d")) +} + +func TestGetNumerixComponents_Simple(t *testing.T) { + request := InferflowOnboardRequest{ + Payload: OnboardPayload{ + ReRankers: []ReRanker{ + { + EqID: 1, + Score: "score1", + DataType: "Float", + EntityID: []string{"user"}, + EqVariables: map[string]string{ + "x": "ONLINE|feat1", + }, + }, + }, + }, + } + offlineMapping := map[string]string{} + predatorOutputsToDataType := map[string]string{} + featureToDataType := map[string]string{"feat1": "Float"} + + got, err := GetNumerixComponents(request, offlineMapping, predatorOutputsToDataType, featureToDataType) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, "i1", got[0].Component) + assert.Equal(t, "score1", got[0].ScoreCol) + assert.Equal(t, "1", got[0].ComputeID) + assert.Contains(t, got[0].ScoreMapping, "x@DataTypeFloat") +} + +func TestGetNumerixScoreMapping_OfflineNotFound(t *testing.T) { + eqVariables := map[string]string{"k": "OFFLINE|off_feat"} + offlineMapping := map[string]string{} // no mapping for off_feat + featureToDataType := map[string]string{"off_feat": "Float"} // set so we reach offline-mapping check + _, err := getNumerixScoreMapping(eqVariables, offlineMapping, nil, featureToDataType) + require.Error(t, err) + assert.Contains(t, err.Error(), 
"offlineToOnlineMapping") +} + +func TestGetNumerixScoreMapping_DataTypeNotFound(t *testing.T) { + eqVariables := map[string]string{"k": "ONLINE|feat1"} + featureToDataType := map[string]string{} // no dtype for feat1 + predatorOutputsToDataType := map[string]string{} + _, err := getNumerixScoreMapping(eqVariables, nil, predatorOutputsToDataType, featureToDataType) + require.Error(t, err) + assert.Contains(t, err.Error(), "data type") +} + +func TestGetPredatorInputFeaturesList_InvalidFeature(t *testing.T) { + _, err := getPredatorInputFeaturesList([]string{"invalid_single_part"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "transforming feature") +} + +func TestGetPredatorInputFeaturesList_OfflineMapping(t *testing.T) { + features := []string{"OFFLINE|off_feat"} + mapping := map[string]string{"off_feat": "online_feat"} + got, err := getPredatorInputFeaturesList(features, mapping) + require.NoError(t, err) + assert.Equal(t, []string{"online_feat"}, got) +} + +func TestExtractFeatures(t *testing.T) { + entityIDs := map[string]bool{"user:item": true} + request := InferflowOnboardRequest{ + Payload: OnboardPayload{ + Rankers: []Ranker{ + { + Inputs: []Input{{Features: []string{"ONLINE|f1", "DEFAULT|f2"}}}, + Outputs: []Output{{Name: "o1", ModelScores: []string{"ms1"}, DataType: "Float"}}, + }, + }, + }, + } + features, featureToDataType, predatorOutputsToDataType := extractFeatures(request, entityIDs) + assert.True(t, features.Contains("ONLINE|f1")) + assert.True(t, features.Contains("DEFAULT|f2")) + assert.Contains(t, featureToDataType, "ONLINE|f1") + assert.Contains(t, predatorOutputsToDataType, "ms1") +} diff --git a/horizon/internal/inferflow/handler/inferflow_test.go b/horizon/internal/inferflow/handler/inferflow_test.go new file mode 100644 index 00000000..8e5fc9f3 --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_test.go @@ -0,0 +1,153 @@ +package handler + +import ( + "errors" + "testing" + + service_deployable_config 
"github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInferFlow_GetLoggingTTL(t *testing.T) { + m := &InferFlow{} + got, err := m.GetLoggingTTL() + require.NoError(t, err) + require.NotNil(t, got.Data) + assert.Equal(t, []int{30, 60, 90}, got.Data) +} + +func TestInferFlow_GenerateFunctionalTestRequest_InvalidBatchSize(t *testing.T) { + m := &InferFlow{} + req := GenerateRequestFunctionalTestingRequest{ + Entity: "user", + BatchSize: "not_a_number", + ModelConfigID: "cfg-1", + DefaultFeatures: map[string]string{}, + } + _, err := m.GenerateFunctionalTestRequest(req) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid batch size") +} + +func TestInferFlow_GenerateFunctionalTestRequest_Valid(t *testing.T) { + m := &InferFlow{} + req := GenerateRequestFunctionalTestingRequest{ + Entity: "user", + BatchSize: "5", + ModelConfigID: "cfg-1", + DefaultFeatures: map[string]string{"f1": "v1"}, + MetaData: map[string]string{"k": "v"}, + } + got, err := m.GenerateFunctionalTestRequest(req) + require.NoError(t, err) + assert.Equal(t, "cfg-1", got.RequestBody.ModelConfigID) + assert.Len(t, got.RequestBody.Entities, 1) + assert.Equal(t, "user_id", got.RequestBody.Entities[0].Entity) + assert.Len(t, got.RequestBody.Entities[0].Ids, 5) + assert.Equal(t, "v", got.MetaData["k"]) + // Default feature f1=v1 should produce one FeatureValue with 5 elements + var found bool + for _, fv := range got.RequestBody.Entities[0].Features { + if fv.Name == "f1" { + found = true + assert.Len(t, fv.IdsFeatureValue, 5) + for _, v := range fv.IdsFeatureValue { + assert.Equal(t, "v1", v) + } + break + } + } + assert.True(t, found) +} + +func TestInferFlow_batchFetchDiscoveryConfigs_Empty(t *testing.T) { + m := &InferFlow{} + discoveryMap, serviceDeployableMap, err := m.batchFetchDiscoveryConfigs(nil) + require.NoError(t, err) + assert.Empty(t, discoveryMap) 
+ assert.Empty(t, serviceDeployableMap) +} + +func TestInferFlow_GetDerivedConfigID_EmptyDeployableTag(t *testing.T) { + mockRepo := &mockServiceDeployableRepo{ + getById: func(id int) (*service_deployable_config.ServiceDeployableConfig, error) { + return &service_deployable_config.ServiceDeployableConfig{ + ID: id, + Name: "svc", + DeployableTag: "", + }, nil + }, + } + m := &InferFlow{ServiceDeployableConfigRepo: mockRepo} + got, err := m.GetDerivedConfigID("base-config", 1) + require.NoError(t, err) + assert.Equal(t, "base-config", got) +} + +func TestInferFlow_GetDerivedConfigID_WithDeployableTag(t *testing.T) { + mockRepo := &mockServiceDeployableRepo{ + getById: func(id int) (*service_deployable_config.ServiceDeployableConfig, error) { + return &service_deployable_config.ServiceDeployableConfig{ + ID: id, + DeployableTag: "tag1", + }, nil + }, + } + m := &InferFlow{ServiceDeployableConfigRepo: mockRepo} + got, err := m.GetDerivedConfigID("base-config", 1) + require.NoError(t, err) + assert.Equal(t, "base-config_tag1_scaleup", got) +} + +func TestInferFlow_GetDerivedConfigID_RepoError(t *testing.T) { + mockRepo := &mockServiceDeployableRepo{ + getById: func(id int) (*service_deployable_config.ServiceDeployableConfig, error) { + return nil, errors.New("db error") + }, + } + m := &InferFlow{ServiceDeployableConfigRepo: mockRepo} + _, err := m.GetDerivedConfigID("base-config", 1) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to fetch") +} + +// mockServiceDeployableRepo implements service_deployable_config.ServiceDeployableRepository for tests. 
+type mockServiceDeployableRepo struct { + getById func(id int) (*service_deployable_config.ServiceDeployableConfig, error) +} + +func (m *mockServiceDeployableRepo) GetById(id int) (*service_deployable_config.ServiceDeployableConfig, error) { + if m.getById != nil { + return m.getById(id) + } + return nil, errors.New("not implemented") +} + +func (m *mockServiceDeployableRepo) Create(_ *service_deployable_config.ServiceDeployableConfig) error { + return nil +} +func (m *mockServiceDeployableRepo) Update(_ *service_deployable_config.ServiceDeployableConfig) error { + return nil +} +func (m *mockServiceDeployableRepo) DeactivateServiceDeployable(_ int, _ string) error { + return nil +} +func (m *mockServiceDeployableRepo) GetByService(_ string) ([]service_deployable_config.ServiceDeployableConfig, error) { + return nil, nil +} +func (m *mockServiceDeployableRepo) GetAllActive() ([]service_deployable_config.ServiceDeployableConfig, error) { + return nil, nil +} +func (m *mockServiceDeployableRepo) GetByWorkflowStatus(_ string) ([]service_deployable_config.ServiceDeployableConfig, error) { + return nil, nil +} +func (m *mockServiceDeployableRepo) GetByDeployableHealth(_ string) ([]service_deployable_config.ServiceDeployableConfig, error) { + return nil, nil +} +func (m *mockServiceDeployableRepo) GetByNameAndService(_, _ string) (*service_deployable_config.ServiceDeployableConfig, error) { + return nil, nil +} +func (m *mockServiceDeployableRepo) GetByIds(_ []int) ([]service_deployable_config.ServiceDeployableConfig, error) { + return nil, nil +} diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index 824983c4..0356b6f9 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -1304,9 +1304,8 @@ func (p *Predator) processGCSCloneStageIndefaultFolder(requestIdPayloadMap map[u srcPath := pred.GcsModelBasePath srcModelName := originalModelName - 
log.Info().Msgf("Scale-up: Copying within model-source %s → %s", srcModelName, destModelName) - log.Info().Msgf("srcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", - srcBucket, srcPath, srcModelName, destBucket, destPath) + log.Info().Msgf("Scale-up: Copying within model-source %s → %s:\nsrcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", + srcModelName, destModelName, srcBucket, srcPath, srcModelName, destBucket, destPath) if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || srcModelName == constant.EmptyString || destBucket == constant.EmptyString || diff --git a/horizon/internal/predator/handler/predator_test.go b/horizon/internal/predator/handler/predator_test.go new file mode 100644 index 00000000..9ae81749 --- /dev/null +++ b/horizon/internal/predator/handler/predator_test.go @@ -0,0 +1,190 @@ +package handler + +import ( + "errors" + "net/http" + "testing" + + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPredator_ValidateRequest_InvalidGroupIDFormat(t *testing.T) { + p := &Predator{} + msg, code := p.ValidateRequest("not_a_number") + assert.Equal(t, "Invalid request ID format", msg) + assert.Equal(t, http.StatusBadRequest, code) +} + +func TestPredator_ReplaceModelNameInConfigPreservingFormat(t *testing.T) { + p := &Predator{} + tests := []struct { + name string + data []byte + destModelName string + wantContains string + }{ + { + name: "replaces top-level name", + data: []byte("name: \"old_model\"\n"), + destModelName: "new_model", + wantContains: "name: \"new_model\"", + }, + { + name: "preserves nested indented name", + data: []byte(`name: "top" + name: "nested" +`), + destModelName: "replaced", + wantContains: "name: \"replaced\"", + }, + { + name: "no name field unchanged", + data: []byte("platform: \"tensorflow\"\n"), + destModelName: "any", 
+ wantContains: "platform: \"tensorflow\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.replaceModelNameInConfigPreservingFormat(tt.data, tt.destModelName) + assert.Contains(t, string(got), tt.wantContains) + }) + } +} + +func TestPredator_GetDerivedModelName_NonScaleUp(t *testing.T) { + p := &Predator{} + payload := Payload{ModelName: "my_model", ConfigMapping: ConfigMapping{ServiceDeployableID: 1}} + got, err := p.GetDerivedModelName(payload, OnboardRequestType) + require.NoError(t, err) + assert.Equal(t, "my_model", got) +} + +func TestPredator_GetDerivedModelName_ScaleUp_EmptyTag(t *testing.T) { + mockRepo := &predatorMockServiceDeployableRepo{ + getById: func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) { + return &servicedeployableconfig.ServiceDeployableConfig{ + ID: id, + DeployableTag: "", + }, nil + }, + } + p := &Predator{ServiceDeployableRepo: mockRepo} + payload := Payload{ModelName: "base_model", ConfigMapping: ConfigMapping{ServiceDeployableID: 1}} + got, err := p.GetDerivedModelName(payload, ScaleUpRequestType) + require.NoError(t, err) + assert.Equal(t, "base_model", got) +} + +func TestPredator_GetDerivedModelName_ScaleUp_WithTag(t *testing.T) { + mockRepo := &predatorMockServiceDeployableRepo{ + getById: func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) { + return &servicedeployableconfig.ServiceDeployableConfig{ + ID: id, + DeployableTag: "tag1", + }, nil + }, + } + p := &Predator{ServiceDeployableRepo: mockRepo} + payload := Payload{ModelName: "base_model", ConfigMapping: ConfigMapping{ServiceDeployableID: 1}} + got, err := p.GetDerivedModelName(payload, ScaleUpRequestType) + require.NoError(t, err) + assert.Equal(t, "base_model_tag1_scaleup", got) +} + +func TestPredator_GetDerivedModelName_ScaleUp_RepoError(t *testing.T) { + mockRepo := &predatorMockServiceDeployableRepo{ + getById: func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) { + 
return nil, errors.New("db error") + }, + } + p := &Predator{ServiceDeployableRepo: mockRepo} + payload := Payload{ModelName: "base_model", ConfigMapping: ConfigMapping{ServiceDeployableID: 1}} + _, err := p.GetDerivedModelName(payload, ScaleUpRequestType) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to fetch") +} + +func TestPredator_GetOriginalModelName_EmptyTag(t *testing.T) { + mockRepo := &predatorMockServiceDeployableRepo{ + getById: func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) { + return &servicedeployableconfig.ServiceDeployableConfig{ + ID: id, + DeployableTag: "", + }, nil + }, + } + p := &Predator{ServiceDeployableRepo: mockRepo} + got, err := p.GetOriginalModelName("derived_model", 1) + require.NoError(t, err) + assert.Equal(t, "derived_model", got) +} + +func TestPredator_GetOriginalModelName_WithTag(t *testing.T) { + mockRepo := &predatorMockServiceDeployableRepo{ + getById: func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) { + return &servicedeployableconfig.ServiceDeployableConfig{ + ID: id, + DeployableTag: "tag1", + }, nil + }, + } + p := &Predator{ServiceDeployableRepo: mockRepo} + got, err := p.GetOriginalModelName("base_model_tag1_scaleup", 1) + require.NoError(t, err) + assert.Equal(t, "base_model", got) +} + +func TestPredator_GetOriginalModelName_RepoError(t *testing.T) { + mockRepo := &predatorMockServiceDeployableRepo{ + getById: func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) { + return nil, errors.New("db error") + }, + } + p := &Predator{ServiceDeployableRepo: mockRepo} + _, err := p.GetOriginalModelName("any", 1) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to fetch") +} + +// predatorMockServiceDeployableRepo implements servicedeployableconfig.ServiceDeployableRepository for tests. 
+type predatorMockServiceDeployableRepo struct {
+	getById func(id int) (*servicedeployableconfig.ServiceDeployableConfig, error)
+}
+
+func (m *predatorMockServiceDeployableRepo) GetById(id int) (*servicedeployableconfig.ServiceDeployableConfig, error) {
+	if m.getById != nil {
+		return m.getById(id)
+	}
+	return nil, errors.New("not implemented")
+}
+
+func (m *predatorMockServiceDeployableRepo) Create(_ *servicedeployableconfig.ServiceDeployableConfig) error {
+	return nil
+}
+func (m *predatorMockServiceDeployableRepo) Update(_ *servicedeployableconfig.ServiceDeployableConfig) error {
+	return nil
+}
+func (m *predatorMockServiceDeployableRepo) DeactivateServiceDeployable(_ int, _ string) error {
+	return nil
+}
+func (m *predatorMockServiceDeployableRepo) GetByService(_ string) ([]servicedeployableconfig.ServiceDeployableConfig, error) {
+	return nil, nil
+}
+func (m *predatorMockServiceDeployableRepo) GetAllActive() ([]servicedeployableconfig.ServiceDeployableConfig, error) {
+	return nil, nil
+}
+func (m *predatorMockServiceDeployableRepo) GetByWorkflowStatus(_ string) ([]servicedeployableconfig.ServiceDeployableConfig, error) {
+	return nil, nil
+}
+func (m *predatorMockServiceDeployableRepo) GetByDeployableHealth(_ string) ([]servicedeployableconfig.ServiceDeployableConfig, error) {
+	return nil, nil
+}
+func (m *predatorMockServiceDeployableRepo) GetByNameAndService(_, _ string) (*servicedeployableconfig.ServiceDeployableConfig, error) {
+	return nil, nil
+}
+func (m *predatorMockServiceDeployableRepo) GetByIds(_ []int) ([]servicedeployableconfig.ServiceDeployableConfig, error) {
+	return nil, nil
+}
From 16aeb6caa75b79c6dbbb103ec44b346649c481b1 Mon Sep 17 00:00:00 2001
From: pavan-adari-meesho
Date: Tue, 10 Feb 2026 11:34:00 +0530
Subject: [PATCH 21/24] predator handler refactor and gcs client minor fixes

---
 horizon/internal/externalcall/gcs_client.go | 12 +
 horizon/internal/predator/README.md | 36 +
 horizon/internal/predator/handler/predator.go | 2855
++--------------- .../predator/handler/predator_approval.go | 842 +++++ .../predator/handler/predator_constants.go | 115 + .../predator/handler/predator_fetch.go | 166 + .../handler/predator_functional_testing.go | 220 ++ .../predator/handler/predator_helpers.go | 108 + .../predator/handler/predator_upload.go | 532 +++ .../predator/handler/predator_validation.go | 502 +++ 10 files changed, 2713 insertions(+), 2675 deletions(-) create mode 100644 horizon/internal/predator/README.md create mode 100644 horizon/internal/predator/handler/predator_approval.go create mode 100644 horizon/internal/predator/handler/predator_constants.go create mode 100644 horizon/internal/predator/handler/predator_fetch.go create mode 100644 horizon/internal/predator/handler/predator_functional_testing.go create mode 100644 horizon/internal/predator/handler/predator_helpers.go create mode 100644 horizon/internal/predator/handler/predator_upload.go create mode 100644 horizon/internal/predator/handler/predator_validation.go diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go index cec8a89b..08a38ef5 100644 --- a/horizon/internal/externalcall/gcs_client.go +++ b/horizon/internal/externalcall/gcs_client.go @@ -773,6 +773,10 @@ func (g *GCSClient) FindFileWithSuffix(bucket, folderPath, suffix string) (bool, // forEachObject iterates over all objects with the given prefix and calls the visitor for each. func (g *GCSClient) forEachObject(bucket, prefix string, visitor ObjectVisitor) error { + if g.client == nil { + return fmt.Errorf("GCS client not initialized properly") + } + it := g.client.Bucket(bucket).Objects(g.ctx, &storage.Query{Prefix: prefix}) for { objAttrs, err := it.Next() @@ -796,6 +800,10 @@ func (g *GCSClient) forEachObject(bucket, prefix string, visitor ObjectVisitor) // listObjects returns all objects matching the prefix, optionally filtered. // Pass nil for filter to include all objects (except directory markers). 
func (g *GCSClient) listObjects(bucket, prefix string, filter ObjectFilter) ([]storage.ObjectAttrs, error) { + if g.client == nil { + return nil, fmt.Errorf("GCS client not initialized properly") + } + var objects []storage.ObjectAttrs err := g.forEachObject(bucket, prefix, func(attrs *storage.ObjectAttrs) error { @@ -819,6 +827,10 @@ func (g *GCSClient) listObjects(bucket, prefix string, filter ObjectFilter) ([]s // partitionObjects separates objects into two groups based on a predicate. // Objects matching the predicate go into the first slice, others into the second. func (g *GCSClient) partitionObjects(bucket, prefix string, predicate ObjectFilter) (matching, notMatching []storage.ObjectAttrs, err error) { + if g.client == nil { + return nil, nil, fmt.Errorf("GCS client not initialized properly") + } + err = g.forEachObject(bucket, prefix, func(attrs *storage.ObjectAttrs) error { // Skip directory markers if strings.HasSuffix(attrs.Name, "/") { diff --git a/horizon/internal/predator/README.md b/horizon/internal/predator/README.md new file mode 100644 index 00000000..3399f950 --- /dev/null +++ b/horizon/internal/predator/README.md @@ -0,0 +1,36 @@ +# Predator + +Predator handles model lifecycle operations: onboarding, approval workflows, validation, fetch/list, upload from local or GCS, and functional testing. + +## Package layout + +``` +internal/predator/ +├── controller/ (wires routes to handler) +├── handler/ (Config implementation + helpers) +├── route/ +├── proto/ and generated Go +├── init.go +└── README.md +``` + +--- + +## Handler package structure + +| File | Purpose | +|------|--------| +| **config.go** | Defines the **Config** interface (public API). Implemented by `*Predator`. | +| **init.go** | Singleton init: `InitV1ConfigHandler()` returns Config. | +| **model.go** / **models.go** | Request/response types, payloads, and shared structs. | +| **model_config.pb.go** | Generated from `proto/` (e.g. Triton model config). 
| +| **predator.go** | **Predator** struct, `InitV1ConfigHandler()`, and **public entrypoints** that implement Config (e.g. `HandleModelRequest`, `HandleDeleteModel`, `ProcessRequest`, `FetchModelConfig`, `FetchModels`, `ValidateRequest`, `GenerateFunctionalTestRequest`, `ExecuteFunctionalTestRequest`, `SendLoadTestRequest`, `UploadModelFolderFromLocal`, etc.). Remaining helpers used only here (e.g. `convertFields`, `convertInputWithFeatures`, `createModelParamsResponse`) also live in this file. | +| **predator_constants.go** | All `const` values: error messages, request types, stage names, config keys, etc. | +| **predator_approval.go** | Approval workflow: `processRequest`, onboard/scale-up/promote/delete/edit flows, GCS clone stages, DB population, restart deployable, revert, `createDiscoveryAndPredatorConfigTx`, `createPredatorConfigTx`, and related helpers. | +| **predator_validation.go** | Delete validation and async validation job: `ValidateDeleteRequest`, `validateEnsembleChildGroupDeletion`, lock release, `performAsyncValidation`, health checking, `clearTemporaryDeployable`, `copyExistingModelsToTemporary`, `copyRequestModelsToTemporary`, `restartTemporaryDeployable`, and related helpers. | +| **predator_fetch.go** | Fetch/list support: `batchFetchRelatedData`, `batchFetchDeployableConfigs`, `buildModelResponses` (used by `FetchModels`). | +| **predator_upload.go** | Upload-from-local and GCS sync: `uploadSingleModel`, `copyConfigToProdConfigSource`, `copyConfigToNewNameInConfigSource`, validation (metadata, online/offline/pricing features), `syncModelFiles`, `syncFullModel`, `syncPartialFiles`, `validateModelConfiguration`, `cleanEnsembleScheduling`, `replaceModelNameInConfigPreservingFormat`, and related helpers. 
| +| **predator_functional_testing.go** | Functional and load-test helpers: `flattenInputTo3DByteSlice`, `getElementSize`, `reshapeDataForBatch`, `convertDimsToIntSlice` (used by `GenerateFunctionalTestRequest`, `ExecuteFunctionalTestRequest`, `SendLoadTestRequest`). | +| **predator_helpers.go** | Shared helpers: GCS path parsing (`parseGCSURL`, `extractGCSPath`, `extractGCSDetails`), `GetDerivedModelName`, `GetOriginalModelName`, `isNonProductionEnvironment`. | +| **predator_test.go** | Unit tests for the handler (same package; no build tags). | + diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index 0356b6f9..be5f0b2d 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -9,7 +9,6 @@ import ( "io" "math" "os" - "regexp" "sync" "github.com/Meesho/BharatMLStack/horizon/internal/constant" @@ -37,7 +36,6 @@ import ( "github.com/Meesho/BharatMLStack/horizon/pkg/random" "github.com/Meesho/BharatMLStack/horizon/pkg/serializer" "github.com/rs/zerolog/log" - "gorm.io/gorm" ) type Predator struct { @@ -54,120 +52,6 @@ type Predator struct { validationJobRepo validationjob.Repository // Repository for tracking validation jobs } -const ( - OnboardRequestType = "Onboard" - ScaleUpRequestType = "ScaleUp" - PromoteRequestType = "Promote" - EditRequestType = "Edit" - DeleteRequestType = "Delete" - configFile = "config.pbtxt" - pendingApproval = "Pending Approval" - slashConstant = "/" - gcsPrefix = "gs://" - adminRole = "admin" - typeString = "TYPE_STRING" - bytesKeys = "BYTES" - typePrefix = "TYPE_" - errMsgFetchConfigs = "failed to fetch predator configs: %w" - errMsgParsePayload = "failed to parse payload for request ID %d: %w" - cpuRequestKey = "cpu_request" - cpuLimitKey = "cpu_limit" - memRequestKey = "mem_request" - memLimitKey = "mem_limit" - gpuRequestKey = "gpu_request" - gpuLimitKey = "gpu_limit" - minReplicaKey = "min_replica" - maxReplicaKey = 
"max_replica" - nodeSelectorKey = "node_selector" - statusFailed = "Failed" - statusInProgress = "In Progress" - errMsgMarshalMeta = "Failed to marshal metadata" - errMsgInsertConfig = "Failed to insert predator_config" - errMsgInsertDiscovery = "Failed to insert service discovery" - errMsgCreateConnection = "Error in creating connection" - errMsgTypeAssertion = "failed to cast connection to *infra.SQLConnection" - errMsgTypeAssertionLog = "Type assertion error" - errMsgCreateRequestRepo = "Error in creating predator request repository" - errMsgCreateDeployableRepo = "Error in creating service deployable repository" - errMsgCreateConfigRepo = "Error in creating predator config repository" - errMsgCreateDiscoveryRepo = "Error in creating service discovery repository" - errMsgCreateGroupIdCounterRepo = "Error in creating group id counter repository" - errMsgProcessPayload = "failed to process payload" - errMsgCreateRequestFormat = "could not create %s request" - successMsgFormat = "Model %s Request Raised Successfully." 
- fieldModelName = "model_name" - statusPendingApproval = "Pending Approval" - errModelNotFound = "model not found" - errFetchDiscoveryConfig = "failed to fetch service discovery config" - errFetchDeployableConfig = "failed to fetch service deployable config" - errUnmarshalDeployableConfig = "failed to unmarshal service deployable config" - errMarshalPayload = "failed to marshal payload" - errCreateDeleteRequest = "could not create delete request" - successDeleteRequestMsg = "Model deletion request raised successfully" - fieldModelSourcePath = "model_source_path" - fieldMetaData = "meta_data" - fieldDiscoveryConfigID = "discovery_config_id" - fieldConfigMapping = "config_mapping" - errReadConfigFileFormat = "failed to read config.pbtxt: %v" - errUnmarshalProtoFormat = "failed to unmarshal proto text: %v" - errNoInstanceGroup = "no instance group defined in model config" - errModelPathPrefix = "model_path must be provided and start with /" - errModelPathFormat = "invalid model_path format. 
Expected: /bucket/path/to/model" - errModelNameMissing = "model name is missing in config" - errMaxBatchSizeMissing = "max_batch_size is missing or zero in config" - errBackendMissing = "backend is missing in config" - errNoInputDefinitions = "no input definitions found in config" - errNoOutputDefinitions = "no output definitions found in config" - errInstanceGroupMissing = "instance group is missing in config" - errInvalidRequestIDFormat = "invalid group ID format" - errFailedToFetchRequest = "failed to fetch request for group id %s" - errInvalidRequestType = "invalid request type" - statusApproved = "Approved" - statusRejected = "Rejected" - errInvalidGcsBucketPath = "invalid gcs bucket path format for source or destination" - errFailedToUpdateRequest = "Failed to update request status" - successRejectMessage = "Request %d rejected successfully.\n" - errFailedToFindServiceDiscovery = "Failed to find service discovery entry" - errFailedToUpdateServiceDiscovery = "Failed to update service discovery to inactive" - errFailedToFindPredatorConfig = "Failed to find predator config entry" - errFailedToUpdatePredatorConfig = "Failed to update predator config to inactive" - - errFailedToParsePayload = "Failed to parse payload" - errChildModelNotInDeleteRequest = "ensemble model %s has child model %s which is not included in the delete request" - errChildModelDifferentDeployable = "ensemble model %s and its child model %s belong to different deployables (ensemble: %d, child: %d)" - errFailedToFetchDiscoveryConfigForModel = "failed to fetch discovery config for model %s: %w" - errFailedToFetchDiscoveryConfigForEnsemble = "failed to fetch discovery config for ensemble model %s: %w" - errFailedToFetchDiscoveryConfigForChild = "failed to fetch discovery config for child model %s: %w" - errDuplicateModelNameInDeployable = "duplicate model name %s found within deployable %d" - errNormalModelIsChildOfEnsemble = "model %s is a child of ensemble model %s in the same deployable %d, 
but ensemble is not included in delete request" - errEnsembleMissingChild = "ensemble model %s has child model %s which is not included in the delete request" - errChildMissingEnsemble = "child model %s is included in delete request but its parent ensemble model %s is not included" - errFailedToFindServiceDeployableEntry = "Failed to find service deployable entry" - errFailedToOperateGcsCloneStage = "Failed to operate gcs clone stage" - errFailedToRestartDeployable = "Failed to restart deployable" - errGCSCopyFailed = "GCS copy failed" - errFailedToUpdateRequestStatusAndStage = "Failed to update request status and stage %s" - onboardRequestFlow = "Onboard request" - cloneRequestFlow = "Clone request" - promoteRequestFlow = "Promote request" - predatorStageRestartDeployable = "Restart Deployable" - predatorStagePending = "Pending" - machineTypeKey = "machine_type" - cpuThresholdKey = "cpu_threshold" - gpuThresholdKey = "gpu_threshold" - tritonImageTagKey = "triton_image_tag" - basePathKey = "base_path" - predatorStageCloneToBucket = "Clone To Bucket" - predatorStageDBPopulation = "DB Population" - predatorStageRequestPayloadError = "Request Payload Error" - serviceDeployableNotFound = "ServiceDeployable not found" - failedToParseServiceConfig = "Failed to parse service config" - failedToCreateServiceDiscoveryAndConfig = "Failed to create service discovery and config" - predatorInferMethod = "inference.GRPCInferenceService/ModelInfer" - deployableTagDelimiter = "_" - scaleupTag = "scaleup" -) - func InitV1ConfigHandler() (Config, error) { var initErr error @@ -324,7 +208,7 @@ func (p *Predator) HandleModelRequest(req ModelRequest, requestType string) (str return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf("failed to get group id: %w", err) } - for i := range len(req.Payload) { + for i := range req.Payload { payloadObject := payloadObjects[i] payloadBytes, err := json.Marshal(payloadObject) if err != nil { @@ -428,125 +312,6 @@ func (p *Predator) 
HandleDeleteModel(deleteRequest DeleteRequest, createdBy stri return successDeleteRequestMsg, groupID, http.StatusOK, nil } -func (p *Predator) ValidateDeleteRequest(predatorConfigList []predatorconfig.PredatorConfig, ids []int) (bool, error) { - if len(predatorConfigList) != len(ids) { - log.Error().Err(errors.New(errModelNotFound)).Msgf("model not found for ids %v", ids) - return false, errors.New(errModelNotFound) - } - - // Create maps for quick lookup - requestedModelMap := make(map[string]predatorconfig.PredatorConfig) // modelName -> config - requestedDeployableMap := make(map[int]bool) // serviceDeployableID -> exists - deployableModelMap := make(map[int]map[string]predatorconfig.PredatorConfig) // deployableID -> modelName -> config - - // Build maps from requested models - for _, predatorConfig := range predatorConfigList { - // Get service deployable ID for this model - discoveryConfig, err := p.ServiceDiscoveryRepo.GetById(predatorConfig.DiscoveryConfigID) - if err != nil { - log.Error().Err(err).Msgf("failed to fetch discovery config for model %s", predatorConfig.ModelName) - return false, fmt.Errorf(errFailedToFetchDiscoveryConfigForModel, predatorConfig.ModelName, err) - } - - requestedModelMap[predatorConfig.ModelName] = predatorConfig - requestedDeployableMap[discoveryConfig.ServiceDeployableID] = true - - // Group models by deployable - if deployableModelMap[discoveryConfig.ServiceDeployableID] == nil { - deployableModelMap[discoveryConfig.ServiceDeployableID] = make(map[string]predatorconfig.PredatorConfig) - } - deployableModelMap[discoveryConfig.ServiceDeployableID][predatorConfig.ModelName] = predatorConfig - } - - // Check for duplicate model names within same deployable - for deployableID, models := range deployableModelMap { - if len(models) > 1 { - // Check if any model names are duplicated within this deployable - modelNameCount := make(map[string]int) - for modelName := range models { - modelNameCount[modelName]++ - } - for modelName, 
count := range modelNameCount { - if count > 1 { - return false, fmt.Errorf(errDuplicateModelNameInDeployable, modelName, deployableID) - } - } - } - } - - // Validate ensemble-child group deletion requirements - if err := p.validateEnsembleChildGroupDeletion(requestedModelMap, deployableModelMap); err != nil { - return false, err - } - - return true, nil -} - -func (p *Predator) validateEnsembleChildGroupDeletion(requestedModelMap map[string]predatorconfig.PredatorConfig, deployableModelMap map[int]map[string]predatorconfig.PredatorConfig) error { - // Get all active models to check for ensemble relationships - allModels, err := p.PredatorConfigRepo.FindAllActiveConfig() - if err != nil { - log.Error().Err(err).Msgf("failed to fetch all active models") - return fmt.Errorf("failed to fetch all active models: %w", err) - } - - // Group all models by deployable for easier lookup - allModelsByDeployable := make(map[int]map[string]predatorconfig.PredatorConfig) - for _, model := range allModels { - discoveryConfig, err := p.ServiceDiscoveryRepo.GetById(model.DiscoveryConfigID) - if err != nil { - log.Error().Err(err).Msgf("failed to fetch discovery config for model %s", model.ModelName) - continue - } - - if allModelsByDeployable[discoveryConfig.ServiceDeployableID] == nil { - allModelsByDeployable[discoveryConfig.ServiceDeployableID] = make(map[string]predatorconfig.PredatorConfig) - } - allModelsByDeployable[discoveryConfig.ServiceDeployableID][model.ModelName] = model - } - - // Check each deployable for ensemble-child relationships - for deployableID, modelsInDeployable := range allModelsByDeployable { - requestedModelsInDeployable := deployableModelMap[deployableID] - if requestedModelsInDeployable == nil { - continue // No models from this deployable in the delete request - } - - // Check each model in this deployable - for modelName, model := range modelsInDeployable { - var metadata MetaData - if err := json.Unmarshal(model.MetaData, &metadata); err != nil { - 
log.Error().Err(err).Msgf("failed to unmarshal metadata for model %s", modelName) - continue - } - - // Check if this is an ensemble model - if len(metadata.Ensembling.Step) > 0 { - // This is an ensemble model - isEnsembleInRequest := requestedModelsInDeployable[modelName].ID != 0 - - // Check each child of this ensemble - for _, step := range metadata.Ensembling.Step { - childModelName := step.ModelName - isChildInRequest := requestedModelsInDeployable[childModelName].ID != 0 - - // If ensemble is in request, all children must be in request - if isEnsembleInRequest && !isChildInRequest { - return fmt.Errorf(errEnsembleMissingChild, modelName, childModelName) - } - - // If child is in request, ensemble must be in request - if isChildInRequest && !isEnsembleInRequest { - return fmt.Errorf(errChildMissingEnsemble, childModelName, modelName) - } - } - } - } - } - - return nil -} - func (p *Predator) FetchModelConfig(req FetchModelConfigRequest) (ModelParamsResponse, int, error) { if err := validateModelPath(req.ModelPath); err != nil { return ModelParamsResponse{}, http.StatusBadRequest, err @@ -779,1695 +544,243 @@ func (p *Predator) ProcessRequest(req ApproveRequest) error { return nil } -func (p *Predator) processRequest(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - if req.Status == statusApproved { - switch predatorRequestList[0].RequestType { - case OnboardRequestType: - p.processOnboardFlow(requestIdPayloadMap, predatorRequestList, req) - case ScaleUpRequestType: - p.processScaleUpFlow(requestIdPayloadMap, predatorRequestList, req) - case PromoteRequestType: - p.processPromoteFlow(requestIdPayloadMap, predatorRequestList, req) - case DeleteRequestType: - p.processDeleteRequest(requestIdPayloadMap, predatorRequestList, req) - case EditRequestType: - p.processEditRequest(requestIdPayloadMap, predatorRequestList, req) - default: - 
log.Error().Err(errors.New(errInvalidRequestType)).Msg(errInvalidRequestType) - } - } else { - p.processRejectRequest(predatorRequestList, req) +func (p *Predator) FetchModels() ([]ModelResponse, error) { + predatorConfigs, err := p.PredatorConfigRepo.FindAllActiveConfig() + if err != nil { + return nil, fmt.Errorf(errMsgFetchConfigs, err) } -} -func (p *Predator) processRejectRequest(predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - for i := range predatorRequestList { - predatorRequestList[i].Status = statusRejected - predatorRequestList[i].RejectReason = req.RejectReason - predatorRequestList[i].Reviewer = req.ApprovedBy - predatorRequestList[i].UpdatedBy = req.ApprovedBy - predatorRequestList[i].UpdatedAt = time.Now() - predatorRequestList[i].Active = false + if len(predatorConfigs) == 0 { + return []ModelResponse{}, nil } - if err := p.Repo.UpdateMany(predatorRequestList); err != nil { - log.Printf(errFailedToUpdateRequestStatusAndStage, err) + // Phase 1: Batch fetch all required data to avoid N+1 queries + discoveryConfigs, serviceDeployables, err := p.batchFetchRelatedData(predatorConfigs) + if err != nil { + return nil, fmt.Errorf("failed to batch fetch related data: %w", err) } - log.Printf("Request %s rejected successfully.\n", req.GroupID) -} - -func (p *Predator) processDeleteRequest(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - transferredGcsModelData, err := p.processGCSCloneToDeleteBucket(req.ApprovedBy, predatorRequestList, requestIdPayloadMap) + // Phase 2: Concurrently fetch deployable configs + deployableConfigs, err := p.batchFetchDeployableConfigs(serviceDeployables) if err != nil { - log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - p.revertForDelete(transferredGcsModelData) - return + return nil, fmt.Errorf("failed to batch 
fetch deployable configs: %w", err) } - p.processDBPopulationStageForDelete(predatorRequestList, requestIdPayloadMap, req) - - if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { - log.Error().Err(err).Msg(errFailedToRestartDeployable) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) - return - } + // Phase 3: Build response objects + results := p.buildModelResponses(predatorConfigs, discoveryConfigs, serviceDeployables, deployableConfigs) + return results, nil } -func (p *Predator) processEditRequest(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - log.Info().Msgf("Starting edit request flow for group ID: %s", req.GroupID) +func (p *Predator) FetchAllPredatorRequests(role, email string) ([]map[string]interface{}, error) { + var requests []predatorrequest.PredatorRequest + var err error - // Step 1: Get target deployable configuration from the request - targetDeployableID := int(requestIdPayloadMap[predatorRequestList[0].RequestID].ConfigMapping.ServiceDeployableID) - targetServiceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) - if err != nil { - log.Error().Err(err).Msg("Failed to fetch target service deployable for edit request") - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - return + if role == adminRole { + requests, err = p.Repo.GetAll() + } else { + requests, err = p.Repo.GetAllByEmail(email) } - var targetDeployableConfig PredatorDeployableConfig - if err := json.Unmarshal(targetServiceDeployable.Config, &targetDeployableConfig); err != nil { - log.Error().Err(err).Msg("Failed to parse target service deployable config") - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - return + if err != nil { + return nil, 
fmt.Errorf("error fetching predator requests: %v", err) } - targetBucket, targetPath := extractGCSPath(strings.TrimSuffix(targetDeployableConfig.GCSBucketPath, "/*")) - log.Info().Msgf("Target deployable path: gs://%s/%s", targetBucket, targetPath) + groupedRequests := make(map[uint][]PredatorRequestResponse) - // Step 2: GCS Copy Stage - Copy models from source to target deployable path - transferredGcsModelData, err := p.processEditGCSCopyStage(requestIdPayloadMap, predatorRequestList, targetBucket, targetPath) - if err != nil { - log.Error().Err(err).Msg("Failed to copy models for edit request") - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - p.revert(transferredGcsModelData) - return - } + for _, req := range requests { + var parsedPayload map[string]interface{} + if err := json.Unmarshal([]byte(req.Payload), &parsedPayload); err != nil { + return nil, fmt.Errorf("error parsing payload for request ID %d: %v", req.RequestID, err) + } + + // Initialize response with default values + requestResponse := PredatorRequestResponse{ + RequestID: req.RequestID, + GroupID: req.GroupId, + Payload: parsedPayload, + CreatedBy: req.CreatedBy, + UpdatedBy: req.UpdatedBy, + Reviewer: req.Reviewer, + RequestStage: req.RequestStage, + RequestType: req.RequestType, + Status: req.Status, + RejectReason: req.RejectReason, + CreatedAt: req.CreatedAt, + UpdatedAt: req.UpdatedAt, + IsValid: req.IsValid, + HasNilData: false, + TestResults: json.RawMessage("{}"), + } - // Update stage to DB Population after successful GCS copy - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageDBPopulation) + // Extract model name from payload and fetch predator config + // Skip predator config lookup for edit requests as models might not exist in DB yet + if modelName, ok := parsedPayload["model_name"].(string); ok && modelName != "" { + if predatorConfig, err := 
p.PredatorConfigRepo.GetActiveModelByModelName(modelName); err == nil { + requestResponse.HasNilData = predatorConfig.HasNilData + if predatorConfig.TestResults != nil { + requestResponse.TestResults = predatorConfig.TestResults + } + } + } - // Step 3: DB Update Stage - Update existing predator config with new metadata from request - err = p.processEditDBUpdateStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy) - if err != nil { - log.Error().Err(err).Msg("Failed to update database for edit request") - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) - p.revert(transferredGcsModelData) - return + groupedRequests[req.GroupId] = append(groupedRequests[req.GroupId], requestResponse) } - // Update stage to Restart Deployable after successful DB update - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageRestartDeployable) + var response []map[string]interface{} - // Step 4: Restart Deployable Stage - Restart target deployable - if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { - log.Error().Err(err).Msg("Failed to restart deployable for edit request") - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) - return + for groupID, groupRequests := range groupedRequests { + groupData := map[string]interface{}{ + "group_id": groupID, + "groups": groupRequests, + } + response = append(response, groupData) } - // Mark request as approved and completed - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusApproved, constant.EmptyString) - log.Info().Msgf("Edit request completed successfully for group ID: %s", req.GroupID) + return response, nil } -// processEditGCSCopyStage copies models from source to target deployable path for edit approval -func (p *Predator) processEditGCSCopyStage(requestIdPayloadMap 
map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, targetBucket, targetPath string) ([]GcsModelData, error) { - var transferredGcsModelData []GcsModelData +func (p *Predator) ValidateRequest(groupId string) (string, int) { + // Validate input and basic checks first (before acquiring lock) + id, err := strconv.ParseUint(groupId, 10, 32) + if err != nil { + return "Invalid request ID format", http.StatusBadRequest + } - // Check if we're in the correct stage for GCS copy - if predatorRequestList[0].RequestStage != predatorStagePending && predatorRequestList[0].RequestStage != predatorStageCloneToBucket && predatorRequestList[0].RequestStage != constant.EmptyString { - log.Info().Msgf("Skipping GCS copy stage - current stage: %s", predatorRequestList[0].RequestStage) - return transferredGcsModelData, nil + request, err := p.Repo.GetAllByGroupID(uint(id)) + if err != nil { + return "Request not found", http.StatusNotFound } - isNotProd := p.isNonProductionEnvironment() + if len(request) == 0 { + return "Request Validation Failed. No requests found", http.StatusNotFound + } - for _, requestModel := range predatorRequestList { - payload := requestIdPayloadMap[requestModel.RequestID] - if payload == nil { - log.Error().Msgf("Payload not found for request ID %d", requestModel.RequestID) - continue - } + payload, err := p.processPayload(request[0]) + if err != nil { + log.Error().Err(err).Msg("Failed to parse payload for validation") + return "Request Validation Failed. Failed to parse request payload", http.StatusBadRequest + } - modelName := requestModel.ModelName + // Determine test deployable ID based on machine type + testDeployableID, err := p.getTestDeployableID(payload) + if err != nil { + log.Error().Err(err).Msg("Failed to determine test deployable ID") + return "Request Validation Failed. 
Failed to determine test deployable ID", http.StatusInternalServerError + } - // Use the source path from the payload, not the default GCS bucket - if payload.ModelSource == "" { - log.Error().Msgf("ModelSource is empty for request ID %d", requestModel.RequestID) - return transferredGcsModelData, fmt.Errorf("model source path is empty for model %s", modelName) - } + // Create deployable-specific lock key (allows parallel processing for different deployables) + lockKey := fmt.Sprintf("validation-deployable-%d", testDeployableID) - // Normalize GCS URL (handle gcs:// prefix) - normalizedModelSource := payload.ModelSource - if strings.HasPrefix(normalizedModelSource, "gcs://") { - normalizedModelSource = strings.Replace(normalizedModelSource, "gcs://", "gs://", 1) - log.Info().Msgf("Normalized GCS URL from %s to %s", payload.ModelSource, normalizedModelSource) - } + // Try to acquire deployable-specific distributed lock + lock, err := p.validationLockRepo.AcquireLock(lockKey, 30*time.Minute) + if err != nil { + log.Warn().Err(err).Msgf("Validation request for group ID %s rejected - failed to acquire lock for deployable %d", groupId, testDeployableID) + return fmt.Sprintf("Request Validation Failed. Another validation is already in progress for %s deployable. 
Please try again later.", + map[int]string{pred.TestDeployableID: "CPU", pred.TestGpuDeployableID: "GPU"}[testDeployableID]), http.StatusConflict + } - // Parse the source GCS path - sourceBucket, sourcePath := extractGCSPath(normalizedModelSource) - if sourceBucket == "" || sourcePath == "" { - log.Error().Msgf("Invalid source GCS path format: %s (normalized: %s)", payload.ModelSource, normalizedModelSource) - return transferredGcsModelData, fmt.Errorf("invalid source GCS path format: %s", normalizedModelSource) - } + log.Info().Msgf("Starting validation for group ID: %s on deployable %d (lock acquired by %s)", groupId, testDeployableID, lock.LockedBy) - log.Info().Msgf("Copying model %s from source gs://%s/%s to target gs://%s/%s for edit approval", - modelName, sourceBucket, sourcePath, targetBucket, targetPath) - - // Copy model from source to target deployable path - // Extract model folder name from source path and copy to target with the same model name - pathSegments := strings.Split(strings.TrimSuffix(sourcePath, "/"), "/") - sourceModelName := pathSegments[len(pathSegments)-1] - sourceBasePath := strings.TrimSuffix(sourcePath, "/"+sourceModelName) - - if isNotProd { - if err := p.GcsClient.TransferFolder( - sourceBucket, sourceBasePath, sourceModelName, - targetBucket, targetPath, modelName, - ); err != nil { - return transferredGcsModelData, err - } - } else { - configBucket := pred.GcsConfigBucket - configPath := pred.GcsConfigBasePath - if err := p.GcsClient.TransferFolderWithSplitSources( - sourceBucket, sourceBasePath, configBucket, configPath, - sourceModelName, targetBucket, targetPath, modelName, - ); err != nil { - return transferredGcsModelData, err - } + // Validate request status + for _, req := range request { + if req.Status == statusApproved { + p.releaseLockWithError(lock.ID, groupId, "Request already approved") + return "Request Validation Failed. 
Request is already approved", http.StatusBadRequest + } + if req.Status == statusRejected { + p.releaseLockWithError(lock.ID, groupId, "Request already rejected") + return "Request Validation Failed. Request is already rejected", http.StatusBadRequest } - - // Track transferred data for potential rollback - transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ - Bucket: targetBucket, - Path: targetPath, - Name: modelName, - }) - - log.Info().Msgf("Successfully copied model %s for edit approval", modelName) } - return transferredGcsModelData, nil -} - -// processEditDBUpdateStage updates predator config for edit approval -// This updates the existing predator config with new config.pbtxt and metadata.json changes -func (p *Predator) processEditDBUpdateStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, approvedBy string) error { - // Check if we're in the correct stage for DB update - if predatorRequestList[0].RequestStage != predatorStageDBPopulation { - log.Info().Msgf("Skipping DB update stage - current stage: %s", predatorRequestList[0].RequestStage) - return nil + // Get service name from deployable config + serviceName, err := p.getServiceNameFromDeployable(testDeployableID) + if err != nil { + log.Error().Err(err).Msg("Failed to get service name from deployable") + p.releaseLockWithError(lock.ID, groupId, "Failed to get service name") + return "Request Validation Failed. 
Failed to get service name", http.StatusInternalServerError } - log.Info().Msg("Starting DB update stage for edit approval") - - for _, requestModel := range predatorRequestList { - payload := requestIdPayloadMap[requestModel.RequestID] - if payload == nil { - log.Error().Msgf("Payload not found for request ID %d", requestModel.RequestID) - continue - } - - modelName := requestModel.ModelName - log.Info().Msgf("Updating predator config for model %s", modelName) + // Create validation job + validationJob := &validationjob.Table{ + GroupID: groupId, + LockID: lock.ID, + TestDeployableID: testDeployableID, + ServiceName: serviceName, + Status: validationjob.StatusPending, + MaxHealthChecks: 15, + HealthCheckInterval: 60, + } - // Find existing predator config for this model - existingPredatorConfig, err := p.PredatorConfigRepo.GetActiveModelByModelName(modelName) - if err != nil { - log.Error().Err(err).Msgf("Failed to fetch existing predator config for model %s", modelName) - return fmt.Errorf("failed to fetch existing predator config for model %s: %w", modelName, err) - } + if err := p.validationJobRepo.Create(validationJob); err != nil { + log.Error().Err(err).Msg("Failed to create validation job") + p.releaseLockWithError(lock.ID, groupId, "Failed to create validation job") + return "Request Validation Failed. 
Failed to create validation job", http.StatusInternalServerError + } - if existingPredatorConfig == nil { - log.Error().Msgf("No existing predator config found for model %s", modelName) - return fmt.Errorf("no existing predator config found for model %s", modelName) - } + // Start asynchronous validation process + go p.performAsyncValidation(validationJob, request, payload, testDeployableID) - // Clean up ensemble scheduling and update the predator config with new metadata from the request - cleanedMetaData := p.cleanEnsembleScheduling(payload.MetaData) + log.Info().Msgf("Validation job created for group ID: %s, job ID: %d", groupId, validationJob.ID) + return "Request Validation Started. The validation will run asynchronously and update the request status when complete.", http.StatusOK +} - metaDataBytes, err := json.Marshal(cleanedMetaData) - if err != nil { - log.Error().Err(err).Msgf("Failed to marshal metadata for model %s", modelName) - return fmt.Errorf("failed to marshal metadata for model %s: %w", modelName, err) - } +// CleanupExpiredValidationLocks removes expired validation locks +// This method can be called periodically to clean up stale locks +func (p *Predator) CleanupExpiredValidationLocks() error { + if p.validationLockRepo == nil { + return errors.New("validation lock repository not initialized") + } - // Update the existing config - existingPredatorConfig.MetaData = metaDataBytes - existingPredatorConfig.UpdatedBy = approvedBy - existingPredatorConfig.UpdatedAt = time.Now() - existingPredatorConfig.HasNilData = true - existingPredatorConfig.TestResults = nil - // Save the updated config - if err := p.PredatorConfigRepo.Update(existingPredatorConfig); err != nil { - log.Error().Err(err).Msgf("Failed to update predator config for model %s", modelName) - return fmt.Errorf("failed to update predator config for model %s: %w", modelName, err) - } + log.Info().Msg("Starting cleanup of expired validation locks") - log.Info().Msgf("Successfully updated 
predator config for model %s", modelName) + if err := p.validationLockRepo.CleanupExpiredLocks(); err != nil { + log.Error().Err(err).Msg("Failed to cleanup expired validation locks") + return err } - log.Info().Msg("DB update stage completed successfully for edit approval") + log.Info().Msg("Successfully cleaned up expired validation locks") return nil } -func (p *Predator) copyAllModelsFromActualToStaging(sourceBucket, sourcePath, targetBucket, targetPath string) error { - // List all models in the actual target path and copy them to staging - folders, err := p.GcsClient.ListFolders(sourceBucket, sourcePath) +// GetValidationStatus returns the current validation lock status +func (p *Predator) GetValidationStatus() (bool, *validationlock.Table, error) { + if p.validationLockRepo == nil { + return false, nil, errors.New("validation lock repository not initialized") + } + + isLocked, err := p.validationLockRepo.IsLocked(validationlock.ValidationLockKey) if err != nil { - return fmt.Errorf("failed to list models in actual target path: %w", err) + return false, nil, err } - // Copy each model folder from actual target to staging - for _, modelName := range folders { - log.Info().Msgf("Copying existing model %s from actual target to staging", modelName) + if !isLocked { + return false, nil, nil + } - if err := p.GcsClient.TransferFolder(sourceBucket, sourcePath, modelName, targetBucket, targetPath, modelName); err != nil { - log.Error().Err(err).Msgf("Failed to copy existing model %s to staging", modelName) - return fmt.Errorf("failed to copy existing model %s to staging: %w", modelName, err) - } + activeLock, err := p.validationLockRepo.GetActiveLock(validationlock.ValidationLockKey) + if err != nil { + return false, nil, err } - return nil + return true, activeLock, nil } -func (p *Predator) deleteServiceDiscoveryAndConfig(req ApproveRequest, predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload) error { - tx := 
p.Repo.DB().Begin() - if tx.Error != nil { - return tx.Error +// GetValidationJobStatus returns the status of a validation job for a given group ID +func (p *Predator) GetValidationJobStatus(groupId string) (*validationjob.Table, error) { + if p.validationJobRepo == nil { + return nil, errors.New("validation job repository not initialized") } - defer func() { - if r := recover(); r != nil { - tx.Rollback() - panic(r) // re-throw panic after rollback - } - }() - - for i := range predatorRequestList { - payload := requestIdPayloadMap[predatorRequestList[i].RequestID] - if payload == nil { - log.Error().Msgf(errFailedToParsePayload) - tx.Rollback() - return fmt.Errorf("failed to parse payload for request ID %d", predatorRequestList[i].RequestID) - } - discoveryConfigID := int(payload.DiscoveryConfigID) - log.Info().Msgf("Processing delete request for discovery config ID: %d", discoveryConfigID) - serviceDiscovery, err := p.ServiceDiscoveryRepo.WithTx(tx).GetById(discoveryConfigID) - if err != nil { - log.Error().Err(err).Msg(errFailedToFindServiceDiscovery) - tx.Rollback() - return err - } - serviceDiscovery.Active = false - serviceDiscovery.UpdatedAt = time.Now() - serviceDiscovery.UpdatedBy = req.ApprovedBy - - if err := p.ServiceDiscoveryRepo.WithTx(tx).Update(serviceDiscovery); err != nil { - log.Error().Err(err).Msg(errFailedToUpdateServiceDiscovery) - tx.Rollback() - return err - } - - predatorConfigs, err := p.PredatorConfigRepo.WithTx(tx).GetByDiscoveryConfigID(discoveryConfigID) - if err != nil { - log.Error().Err(err).Msg(errFailedToFindPredatorConfig) - tx.Rollback() - return err - } - - for j := range predatorConfigs { - predatorConfigs[j].Active = false - predatorConfigs[j].UpdatedAt = time.Now() - predatorConfigs[j].UpdatedBy = req.ApprovedBy - if err := p.PredatorConfigRepo.WithTx(tx).Update(&predatorConfigs[j]); err != nil { - log.Error().Err(err).Msg(errFailedToUpdatePredatorConfig) - tx.Rollback() - return err - } - } - - 
predatorRequestList[i].Status = statusInProgress - predatorRequestList[i].Reviewer = req.ApprovedBy - predatorRequestList[i].UpdatedBy = req.ApprovedBy - predatorRequestList[i].RequestStage = predatorStageRestartDeployable - predatorRequestList[i].UpdatedAt = time.Now() - - if err := p.Repo.WithTx(tx).Update(&predatorRequestList[i]); err != nil { - log.Error().Err(err).Msg(errFailedToUpdateRequestStatusAndStage) - tx.Rollback() - return err - } - } - - if err := tx.Commit().Error; err != nil { - log.Error().Err(err).Msg("transaction commit failed") - return err - } - - return nil -} - -func (p *Predator) processGCSCloneToDeleteBucket(email string, predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload) ([]GcsTransferredData, error) { - var transferredGcsModelData []GcsTransferredData - if predatorRequestList[0].RequestStage == constant.EmptyString || predatorRequestList[0].RequestStage == predatorStagePending || predatorRequestList[0].RequestStage == predatorStageCloneToBucket { - for _, requestModel := range predatorRequestList { - srcBucket, srcPath, srcModelName := extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) - destBucket, destPath := extractGCSPath(pred.DefaultModelPathKey) - log.Info().Msgf("srcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", srcBucket, srcPath, srcModelName, destBucket, destPath) - if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || srcModelName == constant.EmptyString || destBucket == constant.EmptyString || destPath == constant.EmptyString || requestIdPayloadMap[requestModel.RequestID].ModelName == constant.EmptyString { - log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) - return transferredGcsModelData, errors.New(errModelPathFormat) - } - - if err := p.GcsClient.TransferAndDeleteFolder(srcBucket, srcPath, srcModelName, destBucket, destPath, requestIdPayloadMap[requestModel.RequestID].ModelName); 
err != nil { - log.Error().Err(err).Msg(errGCSCopyFailed) - return transferredGcsModelData, err - } - - transferredGcsModelData = append(transferredGcsModelData, GcsTransferredData{ - SrcBucket: destBucket, - SrcPath: destPath, - SrcName: requestIdPayloadMap[requestModel.RequestID].ModelName, - DestBucket: srcBucket, - DestPath: srcPath, - DestName: srcModelName, - }) - } - p.updateRequestStatusAndStage(email, predatorRequestList, statusInProgress, predatorStageDBPopulation) - } - return transferredGcsModelData, nil -} - -func (p *Predator) processRestartDeployableStage(email string, predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload) error { - if predatorRequestList[0].RequestStage != predatorStageRestartDeployable { - return nil - } - var serviceDeployableIDList []int - for _, requestModel := range predatorRequestList { - serviceDeployableIDList = append(serviceDeployableIDList, int(requestIdPayloadMap[requestModel.RequestID].ConfigMapping.ServiceDeployableID)) - } - - for _, serviceDeployableID := range serviceDeployableIDList { - sd, err := p.ServiceDeployableRepo.GetById(int(serviceDeployableID)) - if err != nil { - log.Error().Err(err).Msg(errFailedToFindServiceDeployableEntry) - return err - } - // Extract isCanary from deployable config - var deployableConfig map[string]interface{} - isCanary := false - if err := json.Unmarshal(sd.Config, &deployableConfig); err == nil { - if strategy, ok := deployableConfig["deploymentStrategy"].(string); ok && strategy == "canary" { - isCanary = true - } - } - if err := p.infrastructureHandler.RestartDeployment(sd.Name, p.workingEnv, isCanary); err != nil { - log.Error().Err(err).Msg(errFailedToRestartDeployable) - return err - } - } - - p.updateRequestStatusAndStage(email, predatorRequestList, statusApproved, constant.EmptyString) - return nil -} - -func (p *Predator) processPayload(predatorRequest predatorrequest.PredatorRequest) (*Payload, error) { - var payload Payload - 
decoder := json.NewDecoder(strings.NewReader(predatorRequest.Payload)) - decoder.DisallowUnknownFields() - if err := decoder.Decode(&payload); err != nil { - log.Error().Err(err).Msg("Failed to parse payload with strict decoding") - return nil, err - } - return &payload, nil -} - -func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) ([]GcsModelData, error) { - var transferredGcsModelData []GcsModelData - if predatorRequestList[0].RequestStage == predatorStagePending || predatorRequestList[0].RequestStage == predatorStageCloneToBucket { - isNotProd := p.isNonProductionEnvironment() - for _, requestModel := range predatorRequestList { - - serviceDeployable, err := p.ServiceDeployableRepo.GetById(int(requestIdPayloadMap[requestModel.RequestID].ConfigMapping.ServiceDeployableID)) - - if err != nil { - log.Error().Err(err).Msg(serviceDeployableNotFound) - return transferredGcsModelData, err - } - - var deployableConfig PredatorDeployableConfig - if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { - log.Error().Err(err).Msg(failedToParseServiceConfig) - return transferredGcsModelData, err - } - - destBucket, destPath := extractGCSPath(strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*")) - destModelName := requestIdPayloadMap[requestModel.RequestID].ModelName - - var srcBucket, srcPath, srcModelName string - - srcBucket = pred.GcsModelBucket - srcPath = pred.GcsModelBasePath - if requestModel.RequestType == ScaleUpRequestType { - srcModelName = destModelName - log.Info().Msgf("Scale-up: Source from model-source gs://%s/%s/%s", - srcBucket, srcPath, srcModelName) - } else { - _, _, srcModelName = extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) - log.Info().Msgf("Onboard/Promote: Source from payload gs://%s/%s/%s", - srcBucket, srcPath, srcModelName) - } - - log.Info().Msgf("Copying to target deployable - src: 
%s/%s/%s, dest: %s/%s/%s", - srcBucket, srcPath, srcModelName, destBucket, destPath, destModelName) - - if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || - srcModelName == constant.EmptyString || destBucket == constant.EmptyString || - destPath == constant.EmptyString || destModelName == constant.EmptyString { - log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) - return transferredGcsModelData, errors.New(errModelPathFormat) - } - - if isNotProd { - if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, - destBucket, destPath, destModelName); err != nil { - log.Error().Err(err).Msg(errGCSCopyFailed) - return transferredGcsModelData, err - } - } else { - if err := p.GcsClient.TransferFolderWithSplitSources( - srcBucket, srcPath, pred.GcsConfigBucket, pred.GcsConfigBasePath, - srcModelName, destBucket, destPath, destModelName, - ); err != nil { - log.Error().Err(err).Msg(errGCSCopyFailed) - return transferredGcsModelData, err - } - } - - transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ - Bucket: destBucket, - Path: destPath, - Name: requestIdPayloadMap[requestModel.RequestID].ModelName, - }) - - log.Info().Msgf("Successfully copied model to target deployable: %s", destModelName) - } - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageDBPopulation) - } - return transferredGcsModelData, nil -} - -func (p *Predator) processGCSCloneStageIndefaultFolder(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) ([]GcsModelData, error) { - var transferredGcsModelData []GcsModelData - if predatorRequestList[0].RequestStage != predatorStagePending && - predatorRequestList[0].RequestStage != predatorStageCloneToBucket { - return transferredGcsModelData, nil - } - - isNotProd := p.isNonProductionEnvironment() - - for _, requestModel := range predatorRequestList { - payload := 
requestIdPayloadMap[requestModel.RequestID] - - destBucket := pred.GcsModelBucket - destPath := pred.GcsModelBasePath - destModelName := payload.ModelName - - _, _, originalModelName := extractGCSDetails(payload.ModelSource) - srcBucket := pred.GcsModelBucket - srcPath := pred.GcsModelBasePath - srcModelName := originalModelName - - log.Info().Msgf("Scale-up: Copying within model-source %s → %s:\nsrcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", - srcModelName, destModelName, srcBucket, srcPath, srcModelName, destBucket, destPath) - - if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || - srcModelName == constant.EmptyString || destBucket == constant.EmptyString || - destPath == constant.EmptyString || destModelName == constant.EmptyString { - log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) - return transferredGcsModelData, errors.New(errModelPathFormat) - } - - if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, - destBucket, destPath, destModelName); err != nil { - log.Error().Err(err).Msg(errGCSCopyFailed) - return transferredGcsModelData, err - } - - log.Info().Msgf("Successfully copied model in model-source: %s → %s", srcModelName, destModelName) - - if !isNotProd && srcModelName != destModelName { - if err := p.copyConfigToNewNameInConfigSource(srcModelName, destModelName); err != nil { - log.Error().Err(err).Msgf("Failed to copy config to config-source: %s → %s", - srcModelName, destModelName) - return transferredGcsModelData, err - } - } - - transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ - Bucket: destBucket, - Path: destPath, - Name: destModelName, - }) - } - - return transferredGcsModelData, nil -} - -func (p *Predator) processDBPopulationStageForDelete(predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload, req ApproveRequest) { - if predatorRequestList[0].RequestStage != predatorStageDBPopulation 
{ - return - } - - if err := p.deleteServiceDiscoveryAndConfig(req, predatorRequestList, requestIdPayloadMap); err != nil { - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) - return - } -} - -func (p *Predator) processDBPopulationStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, approvedBy string, successMessage string) error { - if predatorRequestList[0].RequestStage != predatorStageDBPopulation { - return nil - } - tx := p.Repo.DB().Begin() - for i := range predatorRequestList { - defer func() { - if r := recover(); r != nil { - tx.Rollback() - log.Printf("panic recovered, transaction rolled back") - } - }() - - if err := p.createDiscoveryAndPredatorConfigTx(tx, predatorRequestList[i], *requestIdPayloadMap[predatorRequestList[i].RequestID], approvedBy); err != nil { - tx.Rollback() - log.Error().Err(err).Msg(failedToCreateServiceDiscoveryAndConfig) - return err - } - - predatorRequestList[i].Status = statusInProgress - predatorRequestList[i].RequestStage = predatorStageRestartDeployable - if err := p.Repo.UpdateStatusAndStage(tx, &predatorRequestList[i]); err != nil { - tx.Rollback() - log.Printf(errFailedToUpdateRequestStatusAndStage, err) - } - } - if err := tx.Commit().Error; err != nil { - log.Printf("failed to commit transaction: %v", err) - return err - } - log.Printf("success %s %d\n", successMessage, predatorRequestList[0].GroupId) - return nil -} - -func (p *Predator) checkIfModelsExist(predatorRequestList []predatorrequest.PredatorRequest) bool { - for _, requestModel := range predatorRequestList { - modelName := requestModel.ModelName - if modelName == "" { - log.Error().Msgf("model name is empty for request ID %d", requestModel.RequestID) - continue - } - - predatorConfig, err := p.PredatorConfigRepo.GetActiveModelByModelName(modelName) - if err != nil { - log.Error().Err(err).Msgf("failed to fetch predator config for model %s", modelName) - 
continue - } - if predatorConfig != nil { - log.Error().Msgf("model %s already exists", modelName) - return true - } - } - return false -} - -func (p *Predator) processOnboardFlow(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - if p.checkIfModelsExist(predatorRequestList) { - req.RejectReason = "model already exists" - req.Status = statusRejected - p.processRejectRequest(predatorRequestList, req) - return - } - - transferredGcsModelData, err := p.processGCSCloneStage(requestIdPayloadMap, predatorRequestList, req) - if err != nil { - log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - p.revert(transferredGcsModelData) - return - } - - err = p.processDBPopulationStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy, onboardRequestFlow) - if err != nil { - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) - } - if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { - log.Error().Err(err).Msg(errFailedToRestartDeployable) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) - return - } -} - -func (p *Predator) revert(transferredGcsModelData []GcsModelData) error { - for _, data := range transferredGcsModelData { - if err := p.GcsClient.DeleteFolder(data.Bucket, data.Path, data.Name); err != nil { - log.Error().Err(err).Msg(errGCSCopyFailed) - return err - } - } - return nil -} - -func (p *Predator) revertForDelete(transferredGcsModelData []GcsTransferredData) error { - for _, data := range transferredGcsModelData { - if err := p.GcsClient.TransferAndDeleteFolder(data.SrcBucket, data.SrcPath, data.SrcName, data.DestBucket, data.DestPath, data.DestName); err != nil { - 
log.Error().Err(err).Msg(errGCSCopyFailed) - return err - } - } - return nil -} - -func (p *Predator) processScaleUpFlow(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - if p.checkIfModelsExist(predatorRequestList) { - req.RejectReason = fmt.Sprintf("model %s already exists", requestIdPayloadMap[predatorRequestList[0].RequestID].ModelName) - req.Status = statusRejected - p.processRejectRequest(predatorRequestList, req) - return - } - - transferredGcsModelData, err := p.processGCSCloneStageIndefaultFolder(requestIdPayloadMap, predatorRequestList, req) - if err != nil { - log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - p.revert(transferredGcsModelData) - return - } - - transferredGcsModelData, err = p.processGCSCloneStage(requestIdPayloadMap, predatorRequestList, req) - if err != nil { - log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - p.revert(transferredGcsModelData) - return - } - - err = p.processDBPopulationStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy, cloneRequestFlow) - if err != nil { - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) - } - if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { - log.Error().Err(err).Msg(errFailedToRestartDeployable) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) - return - } -} - -func (p *Predator) processPromoteFlow(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { - if p.checkIfModelsExist(predatorRequestList) { - req.RejectReason = 
fmt.Sprintf("model %s already exists", requestIdPayloadMap[predatorRequestList[0].RequestID].ModelName) - req.Status = statusRejected - p.processRejectRequest(predatorRequestList, req) - return - } - - transferredGcsModelData, err := p.processGCSCloneStage(requestIdPayloadMap, predatorRequestList, req) - if err != nil { - log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) - p.revert(transferredGcsModelData) - return - } - - err = p.processDBPopulationStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy, promoteRequestFlow) - if err != nil { - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) - } - if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { - log.Error().Err(err).Msg(errFailedToRestartDeployable) - p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) - return - } -} - -func (p *Predator) updateRequestStatusAndStage(approvedBy string, predatorRequestList []predatorrequest.PredatorRequest, status, stage string) { - for i := range predatorRequestList { - predatorRequestList[i].Status = status - predatorRequestList[i].Reviewer = approvedBy - predatorRequestList[i].UpdatedBy = approvedBy - if stage != constant.EmptyString { - predatorRequestList[i].RequestStage = stage - } - if predatorRequestList[i].Status == statusApproved || - predatorRequestList[i].Status == statusFailed || - predatorRequestList[i].Status == statusRejected { - predatorRequestList[i].Active = false - } - predatorRequestList[i].UpdatedAt = time.Now() - } - - if err := p.Repo.UpdateMany(predatorRequestList); err != nil { - log.Printf(errFailedToUpdateRequestStatusAndStage, err) - } -} - -func (p *Predator) createDiscoveryAndPredatorConfigTx(tx *gorm.DB, requestModel 
predatorrequest.PredatorRequest, payload Payload, approvedBy string) error { - discoveryConfig, err := p.createDiscoveryConfigTx(tx, &requestModel, payload) - if err != nil { - return err - } - return p.createPredatorConfigTx(tx, &requestModel, payload, approvedBy, discoveryConfig.ID) -} - -func (p *Predator) createDiscoveryConfigTx(tx *gorm.DB, requestModel *predatorrequest.PredatorRequest, payload Payload) (discoveryconfig.DiscoveryConfig, error) { - discoveryConfig := discoveryconfig.DiscoveryConfig{ - ServiceDeployableID: int(payload.ConfigMapping.ServiceDeployableID), - CreatedBy: requestModel.CreatedBy, - UpdatedBy: requestModel.UpdatedBy, - Active: true, - CreatedAt: requestModel.CreatedAt, - UpdatedAt: time.Now(), - } - if err := tx.Create(&discoveryConfig).Error; err != nil { - log.Error().Err(err).Msg(errMsgInsertDiscovery) - return discoveryConfig, err - } - return discoveryConfig, nil -} - -func (p *Predator) createPredatorConfigTx(tx *gorm.DB, requestModel *predatorrequest.PredatorRequest, payload Payload, approvedBy string, discoveryConfigID int) error { - // Clean up ensemble scheduling before marshaling - cleanedMetaData := p.cleanEnsembleScheduling(payload.MetaData) - - metaDataBytes, err := json.Marshal(cleanedMetaData) - if err != nil { - log.Error().Err(err).Msg(errMsgMarshalMeta) - return err - } - - serviceDeployableID := int(payload.ConfigMapping.ServiceDeployableID) - serviceDeployable, err := p.ServiceDeployableRepo.GetById(serviceDeployableID) - if err != nil { - log.Error().Err(err).Msgf("Failed to get service deployable config for ID %d", serviceDeployableID) - return fmt.Errorf("failed to get service deployable config: %w", err) - } - - config := predatorconfig.PredatorConfig{ - DiscoveryConfigID: discoveryConfigID, - ModelName: payload.ModelName, - MetaData: metaDataBytes, - CreatedBy: requestModel.CreatedBy, - UpdatedBy: approvedBy, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - Active: true, - SourceModelName: 
payload.ConfigMapping.SourceModelName, - } - - if serviceDeployable.OverrideTesting { - log.Info().Msgf("OverrideTesting is enabled for deployable %s. Setting test_results for model %s", - serviceDeployable.Name, payload.ModelName) - - config.TestResults = json.RawMessage(`{"is_functionally_tested": true}`) - config.HasNilData = false - } - - if err := tx.Create(&config).Error; err != nil { - log.Error().Err(err).Msg(errMsgInsertConfig) - return err - } - return nil -} - -func parseGCSURL(gcsURL string) (bucket, objectPath string, ok bool) { - // Handle both gs:// and gcs:// prefixes (normalize gcs:// to gs://) - if strings.HasPrefix(gcsURL, "gcs://") { - gcsURL = strings.Replace(gcsURL, "gcs://", "gs://", 1) - } - - if !strings.HasPrefix(gcsURL, gcsPrefix) { - return constant.EmptyString, constant.EmptyString, false - } - - trimmed := strings.TrimPrefix(gcsURL, gcsPrefix) - parts := strings.SplitN(trimmed, slashConstant, 2) - if len(parts) < 1 { - return constant.EmptyString, constant.EmptyString, false - } - - bucket = parts[0] - if len(parts) == 2 { - objectPath = parts[1] - } - return bucket, objectPath, true -} - -func extractGCSPath(gcsURL string) (bucket, objectPath string) { - bucket, objectPath, ok := parseGCSURL(gcsURL) - if !ok { - return constant.EmptyString, constant.EmptyString - } - return bucket, objectPath -} - -func extractGCSDetails(gcsURL string) (bucket, pathOnly, modelName string) { - bucket, objectPath, ok := parseGCSURL(gcsURL) - if !ok || objectPath == constant.EmptyString { - return constant.EmptyString, constant.EmptyString, constant.EmptyString - } - - segments := strings.Split(objectPath, slashConstant) - if len(segments) == 0 { - return bucket, constant.EmptyString, constant.EmptyString - } - - modelName = segments[len(segments)-1] - pathOnly = path.Join(segments[:len(segments)-1]...) 
- return bucket, pathOnly, modelName -} - -func (p *Predator) FetchModels() ([]ModelResponse, error) { - predatorConfigs, err := p.PredatorConfigRepo.FindAllActiveConfig() - if err != nil { - return nil, fmt.Errorf(errMsgFetchConfigs, err) - } - - if len(predatorConfigs) == 0 { - return []ModelResponse{}, nil - } - - // Phase 1: Batch fetch all required data to avoid N+1 queries - discoveryConfigs, serviceDeployables, err := p.batchFetchRelatedData(predatorConfigs) - if err != nil { - return nil, fmt.Errorf("failed to batch fetch related data: %w", err) - } - - // Phase 2: Concurrently fetch deployable configs - deployableConfigs, err := p.batchFetchDeployableConfigs(serviceDeployables) - if err != nil { - return nil, fmt.Errorf("failed to batch fetch deployable configs: %w", err) - } - - // Phase 3: Build response objects - results := p.buildModelResponses(predatorConfigs, discoveryConfigs, serviceDeployables, deployableConfigs) - - return results, nil -} - -// batchFetchRelatedData efficiently fetches all discovery configs and service deployables in batch -func (p *Predator) batchFetchRelatedData(predatorConfigs []predatorconfig.PredatorConfig) (map[int]*discoveryconfig.DiscoveryConfig, map[int]*servicedeployableconfig.ServiceDeployableConfig, error) { - // Collect all unique discovery config IDs - discoveryConfigIDs := make([]int, 0, len(predatorConfigs)) - discoveryIDSet := make(map[int]bool) - - for _, config := range predatorConfigs { - if !discoveryIDSet[config.DiscoveryConfigID] { - discoveryConfigIDs = append(discoveryConfigIDs, config.DiscoveryConfigID) - discoveryIDSet[config.DiscoveryConfigID] = true - } - } - - // Batch fetch all discovery configs - discoveryConfigs := make(map[int]*discoveryconfig.DiscoveryConfig) - for _, id := range discoveryConfigIDs { - config, err := p.ServiceDiscoveryRepo.GetById(id) - if err != nil { - continue // Skip failed configs, same behavior as original - } - discoveryConfigs[id] = config - } - - // Collect all unique 
service deployable IDs - serviceDeployableIDs := make([]int, 0, len(discoveryConfigs)) - serviceDeployableIDSet := make(map[int]bool) - - for _, config := range discoveryConfigs { - if !serviceDeployableIDSet[config.ServiceDeployableID] { - serviceDeployableIDs = append(serviceDeployableIDs, config.ServiceDeployableID) - serviceDeployableIDSet[config.ServiceDeployableID] = true - } - } - - // Batch fetch all service deployables - serviceDeployables := make(map[int]*servicedeployableconfig.ServiceDeployableConfig) - for _, id := range serviceDeployableIDs { - deployable, err := p.ServiceDeployableRepo.GetById(id) - if err != nil { - continue // Skip failed deployables, same behavior as original - } - serviceDeployables[id] = deployable - } - - return discoveryConfigs, serviceDeployables, nil -} - -// batchFetchDeployableConfigs concurrently fetches deployable configs for all service deployables -func (p *Predator) batchFetchDeployableConfigs(serviceDeployables map[int]*servicedeployableconfig.ServiceDeployableConfig) (map[int]externalcall.Config, error) { - deployableConfigs := make(map[int]externalcall.Config) - var mu sync.Mutex - var wg sync.WaitGroup - - // Use a semaphore to limit concurrent API calls - semaphore := make(chan struct{}, 10) // Limit to 10 concurrent calls - - for id, deployable := range serviceDeployables { - wg.Add(1) - go func(deployableID int, sd *servicedeployableconfig.ServiceDeployableConfig) { - defer wg.Done() - semaphore <- struct{}{} // Acquire semaphore - defer func() { <-semaphore }() // Release semaphore - - infraConfig := p.infrastructureHandler.GetConfig(sd.Name, p.workingEnv) - // Convert to externalcall.Config for compatibility - config := externalcall.Config{ - MinReplica: infraConfig.MinReplica, - MaxReplica: infraConfig.MaxReplica, - RunningStatus: infraConfig.RunningStatus, - } - - mu.Lock() - deployableConfigs[deployableID] = config - mu.Unlock() - }(id, deployable) - } - - wg.Wait() - return deployableConfigs, nil -} - -// 
buildModelResponses constructs the final ModelResponse objects -func (p *Predator) buildModelResponses( - predatorConfigs []predatorconfig.PredatorConfig, - discoveryConfigs map[int]*discoveryconfig.DiscoveryConfig, - serviceDeployables map[int]*servicedeployableconfig.ServiceDeployableConfig, - deployableConfigs map[int]externalcall.Config, -) []ModelResponse { - results := make([]ModelResponse, 0, len(predatorConfigs)) - - for _, config := range predatorConfigs { - // Get discovery config - serviceDiscovery, exists := discoveryConfigs[config.DiscoveryConfigID] - if !exists { - continue // Skip if discovery config not found - } - - // Get service deployable - serviceDeployable, exists := serviceDeployables[serviceDiscovery.ServiceDeployableID] - if !exists { - continue // Skip if service deployable not found - } - - // Parse deployable config - var deployableConfig PredatorDeployableConfig - if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { - continue // Skip if config parsing fails - } - - // Get infrastructure config (HPA/replica info) - infraConfig := deployableConfigs[serviceDiscovery.ServiceDeployableID] - - deploymentConfig := map[string]any{ - machineTypeKey: deployableConfig.MachineType, - cpuThresholdKey: deployableConfig.CPUThreshold, - gpuThresholdKey: deployableConfig.GPUThreshold, - cpuRequestKey: deployableConfig.CPURequest, - cpuLimitKey: deployableConfig.CPULimit, - memRequestKey: deployableConfig.MemoryRequest, - memLimitKey: deployableConfig.MemoryLimit, - gpuRequestKey: deployableConfig.GPURequest, - gpuLimitKey: deployableConfig.GPULimit, - minReplicaKey: infraConfig.MinReplica, - maxReplicaKey: infraConfig.MaxReplica, - nodeSelectorKey: deployableConfig.NodeSelectorValue, - tritonImageTagKey: deployableConfig.TritonImageTag, - basePathKey: deployableConfig.GCSBucketPath, - } - - modelResponse := ModelResponse{ - ID: config.ID, - ModelName: config.ModelName, - MetaData: config.MetaData, - Host: 
serviceDeployable.Host, - MachineType: deployableConfig.MachineType, - DeploymentConfig: deploymentConfig, - MonitoringUrl: serviceDeployable.MonitoringUrl, - GCSPath: strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*"), - CreatedBy: config.CreatedBy, - CreatedAt: config.CreatedAt, - UpdatedBy: config.UpdatedBy, - UpdatedAt: config.UpdatedAt, - DeployableRunningStatus: infraConfig.RunningStatus, - TestResults: config.TestResults, - HasNilData: config.HasNilData, - SourceModelName: config.SourceModelName, - } - - results = append(results, modelResponse) - } - - return results -} - -func (p *Predator) FetchAllPredatorRequests(role, email string) ([]map[string]interface{}, error) { - var requests []predatorrequest.PredatorRequest - var err error - - if role == adminRole { - requests, err = p.Repo.GetAll() - } else { - requests, err = p.Repo.GetAllByEmail(email) - } - - if err != nil { - return nil, fmt.Errorf("error fetching predator requests: %v", err) - } - - groupedRequests := make(map[uint][]PredatorRequestResponse) - - for _, req := range requests { - var parsedPayload map[string]interface{} - if err := json.Unmarshal([]byte(req.Payload), &parsedPayload); err != nil { - return nil, fmt.Errorf("error parsing payload for request ID %d: %v", req.RequestID, err) - } - - // Initialize response with default values - requestResponse := PredatorRequestResponse{ - RequestID: req.RequestID, - GroupID: req.GroupId, - Payload: parsedPayload, - CreatedBy: req.CreatedBy, - UpdatedBy: req.UpdatedBy, - Reviewer: req.Reviewer, - RequestStage: req.RequestStage, - RequestType: req.RequestType, - Status: req.Status, - RejectReason: req.RejectReason, - CreatedAt: req.CreatedAt, - UpdatedAt: req.UpdatedAt, - IsValid: req.IsValid, - HasNilData: false, - TestResults: json.RawMessage("{}"), - } - - // Extract model name from payload and fetch predator config - // Skip predator config lookup for edit requests as models might not exist in DB yet - if modelName, ok := 
parsedPayload["model_name"].(string); ok && modelName != "" { - if predatorConfig, err := p.PredatorConfigRepo.GetActiveModelByModelName(modelName); err == nil { - requestResponse.HasNilData = predatorConfig.HasNilData - if predatorConfig.TestResults != nil { - requestResponse.TestResults = predatorConfig.TestResults - } - } - } - - groupedRequests[req.GroupId] = append(groupedRequests[req.GroupId], requestResponse) - } - - var response []map[string]interface{} - - for groupID, groupRequests := range groupedRequests { - groupData := map[string]interface{}{ - "group_id": groupID, - "groups": groupRequests, - } - response = append(response, groupData) - } - - return response, nil -} - -func (p *Predator) ValidateRequest(groupId string) (string, int) { - // Validate input and basic checks first (before acquiring lock) - id, err := strconv.ParseUint(groupId, 10, 32) - if err != nil { - return "Invalid request ID format", http.StatusBadRequest - } - - request, err := p.Repo.GetAllByGroupID(uint(id)) - if err != nil { - return "Request not found", http.StatusNotFound - } - - if len(request) == 0 { - return "Request Validation Failed. No requests found", http.StatusNotFound - } - - payload, err := p.processPayload(request[0]) - if err != nil { - log.Error().Err(err).Msg("Failed to parse payload for validation") - return "Request Validation Failed. Failed to parse request payload", http.StatusBadRequest - } - - // Determine test deployable ID based on machine type - testDeployableID, err := p.getTestDeployableID(payload) - if err != nil { - log.Error().Err(err).Msg("Failed to determine test deployable ID") - return "Request Validation Failed. 
Failed to determine test deployable ID", http.StatusInternalServerError - } - - // Create deployable-specific lock key (allows parallel processing for different deployables) - lockKey := fmt.Sprintf("validation-deployable-%d", testDeployableID) - - // Try to acquire deployable-specific distributed lock - lock, err := p.validationLockRepo.AcquireLock(lockKey, 30*time.Minute) - if err != nil { - log.Warn().Err(err).Msgf("Validation request for group ID %s rejected - failed to acquire lock for deployable %d", groupId, testDeployableID) - return fmt.Sprintf("Request Validation Failed. Another validation is already in progress for %s deployable. Please try again later.", - map[int]string{pred.TestDeployableID: "CPU", pred.TestGpuDeployableID: "GPU"}[testDeployableID]), http.StatusConflict - } - - log.Info().Msgf("Starting validation for group ID: %s on deployable %d (lock acquired by %s)", groupId, testDeployableID, lock.LockedBy) - - // Validate request status - for _, req := range request { - if req.Status == statusApproved { - p.releaseLockWithError(lock.ID, groupId, "Request already approved") - return "Request Validation Failed. Request is already approved", http.StatusBadRequest - } - if req.Status == statusRejected { - p.releaseLockWithError(lock.ID, groupId, "Request already rejected") - return "Request Validation Failed. Request is already rejected", http.StatusBadRequest - } - } - - // Get service name from deployable config - serviceName, err := p.getServiceNameFromDeployable(testDeployableID) - if err != nil { - log.Error().Err(err).Msg("Failed to get service name from deployable") - p.releaseLockWithError(lock.ID, groupId, "Failed to get service name") - return "Request Validation Failed. 
Failed to get service name", http.StatusInternalServerError - } - - // Create validation job - validationJob := &validationjob.Table{ - GroupID: groupId, - LockID: lock.ID, - TestDeployableID: testDeployableID, - ServiceName: serviceName, - Status: validationjob.StatusPending, - MaxHealthChecks: 15, - HealthCheckInterval: 60, - } - - if err := p.validationJobRepo.Create(validationJob); err != nil { - log.Error().Err(err).Msg("Failed to create validation job") - p.releaseLockWithError(lock.ID, groupId, "Failed to create validation job") - return "Request Validation Failed. Failed to create validation job", http.StatusInternalServerError - } - - // Start asynchronous validation process - go p.performAsyncValidation(validationJob, request, payload, testDeployableID) - - log.Info().Msgf("Validation job created for group ID: %s, job ID: %d", groupId, validationJob.ID) - return "Request Validation Started. The validation will run asynchronously and update the request status when complete.", http.StatusOK -} - -// CleanupExpiredValidationLocks removes expired validation locks -// This method can be called periodically to clean up stale locks -func (p *Predator) CleanupExpiredValidationLocks() error { - if p.validationLockRepo == nil { - return errors.New("validation lock repository not initialized") - } - - log.Info().Msg("Starting cleanup of expired validation locks") - - if err := p.validationLockRepo.CleanupExpiredLocks(); err != nil { - log.Error().Err(err).Msg("Failed to cleanup expired validation locks") - return err - } - - log.Info().Msg("Successfully cleaned up expired validation locks") - return nil -} - -// GetValidationStatus returns the current validation lock status -func (p *Predator) GetValidationStatus() (bool, *validationlock.Table, error) { - if p.validationLockRepo == nil { - return false, nil, errors.New("validation lock repository not initialized") - } - - isLocked, err := p.validationLockRepo.IsLocked(validationlock.ValidationLockKey) - if err != 
nil { - return false, nil, err - } - - if !isLocked { - return false, nil, nil - } - - activeLock, err := p.validationLockRepo.GetActiveLock(validationlock.ValidationLockKey) - if err != nil { - return false, nil, err - } - - return true, activeLock, nil -} - -// releaseLockWithError is a helper function to release lock and log error -func (p *Predator) releaseLockWithError(lockID uint, groupID, errorMsg string) { - if releaseErr := p.validationLockRepo.ReleaseLock(lockID); releaseErr != nil { - log.Error().Err(releaseErr).Msgf("Failed to release validation lock for group ID %s after error: %s", groupID, errorMsg) - } - log.Error().Msgf("Validation failed for group ID %s: %s", groupID, errorMsg) -} - -// markModelWithNilData marks a model as having nil data issues - -// getTestDeployableID determines the appropriate test deployable ID based on machine type -func (p *Predator) getTestDeployableID(payload *Payload) (int, error) { - // Get the target deployable ID from the request - targetDeployableID := int(payload.ConfigMapping.ServiceDeployableID) - // Fetch the service deployable config to check machine type - serviceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) - if err != nil { - return 0, fmt.Errorf("failed to fetch service deployable config: %w", err) - } - - // Parse the deployable config to extract machine type - var deployableConfig PredatorDeployableConfig - if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { - return 0, fmt.Errorf("failed to parse service deployable config: %w", err) - } - - // Select test deployable ID based on machine type - switch strings.ToUpper(deployableConfig.MachineType) { - case "CPU": - log.Info().Msgf("Using CPU test deployable ID: %d", pred.TestDeployableID) - return pred.TestDeployableID, nil - case "GPU": - log.Info().Msgf("Using GPU test deployable ID: %d", pred.TestGpuDeployableID) - return pred.TestGpuDeployableID, nil - default: - // Default to CPU if machine type 
is not specified or unknown - log.Warn().Msgf("Unknown machine type '%s', defaulting to CPU test deployable ID: %d", - deployableConfig.MachineType, pred.TestDeployableID) - return pred.TestDeployableID, nil - } -} - -// getServiceNameFromDeployable extracts service name from deployable configuration -func (p *Predator) getServiceNameFromDeployable(deployableID int) (string, error) { - serviceDeployable, err := p.ServiceDeployableRepo.GetById(deployableID) - if err != nil { - return "", fmt.Errorf("failed to get deployable config: %w", err) - } - return serviceDeployable.Name, nil -} - -// performAsyncValidation performs the actual validation process asynchronously -func (p *Predator) performAsyncValidation(job *validationjob.Table, requests []predatorrequest.PredatorRequest, payload *Payload, testDeployableID int) { - defer func() { - // Always release the lock when validation completes - if releaseErr := p.validationLockRepo.ReleaseLock(job.LockID); releaseErr != nil { - log.Error().Err(releaseErr).Msgf("Failed to release validation lock for job %d", job.ID) - } - log.Info().Msgf("Released validation lock for job %d", job.ID) - }() - - log.Info().Msgf("Starting async validation for job %d, group %s", job.ID, job.GroupID) - - // Step 1: Clear temporary deployable - if err := p.clearTemporaryDeployable(testDeployableID); err != nil { - log.Error().Err(err).Msg("Failed to clear temporary deployable") - p.failValidationJob(job.ID, "Failed to clear temporary deployable: "+err.Error()) - return - } - - // Step 2: Copy existing models to temporary deployable - targetDeployableID := int(payload.ConfigMapping.ServiceDeployableID) - if err := p.copyExistingModelsToTemporary(targetDeployableID, testDeployableID); err != nil { - log.Error().Err(err).Msg("Failed to copy existing models to temporary deployable") - p.failValidationJob(job.ID, "Failed to copy existing models: "+err.Error()) - return - } - - // Step 3: Copy new models from request to temporary deployable - if err 
:= p.copyRequestModelsToTemporary(requests, testDeployableID); err != nil { - log.Error().Err(err).Msg("Failed to copy request models to temporary deployable") - p.failValidationJob(job.ID, "Failed to copy request models: "+err.Error()) - return - } - - // Step 4: Restart temporary deployable - if err := p.restartTemporaryDeployable(testDeployableID); err != nil { - log.Error().Err(err).Msg("Failed to restart temporary deployable") - p.failValidationJob(job.ID, "Failed to restart temporary deployable: "+err.Error()) - return - } - - // Update job status to checking and record restart time - now := time.Now() - if err := p.validationJobRepo.UpdateStatus(job.ID, validationjob.StatusChecking, ""); err != nil { - log.Error().Err(err).Msgf("Failed to update job %d status to checking", job.ID) - } - - // Update restart time in the job - job.RestartedAt = &now - job.Status = validationjob.StatusChecking - - // Step 5: Start health checking process - p.startHealthCheckingProcess(job) -} - -// startHealthCheckingProcess monitors the deployment health and updates validation status -func (p *Predator) startHealthCheckingProcess(job *validationjob.Table) { - log.Info().Msgf("Starting health check process for job %d, service %s", job.ID, job.ServiceName) - - for job.HealthCheckCount < job.MaxHealthChecks { - // Wait for the specified interval before checking - time.Sleep(time.Duration(job.HealthCheckInterval) * time.Second) - - // Increment health check count - if err := p.validationJobRepo.IncrementHealthCheck(job.ID); err != nil { - log.Error().Err(err).Msgf("Failed to increment health check count for job %d", job.ID) - } - job.HealthCheckCount++ - - // Check deployment health using infrastructure handler - isHealthy, err := p.checkDeploymentHealth(job.ServiceName) - if err != nil { - log.Error().Err(err).Msgf("Failed to check deployment health for job %d", job.ID) - continue // Continue checking, don't fail immediately on health check errors - } - - if isHealthy { - 
log.Info().Msgf("Deployment is healthy for job %d, validation successful", job.ID) - p.completeValidationJob(job.ID, true, "Deployment is healthy and running successfully") - p.updateRequestValidationStatus(job.GroupID, true) - return - } - - log.Info().Msgf("Deployment not yet healthy for job %d, check %d/%d", job.ID, job.HealthCheckCount, job.MaxHealthChecks) - } - - // If we reach here, max health checks exceeded - log.Warn().Msgf("Max health checks exceeded for job %d, marking as failed", job.ID) - p.completeValidationJob(job.ID, false, fmt.Sprintf("Deployment failed to become healthy after %d checks", job.MaxHealthChecks)) - p.updateRequestValidationStatus(job.GroupID, false) -} - -// checkDeploymentHealth checks if the deployment is healthy using infrastructure handler -func (p *Predator) checkDeploymentHealth(serviceName string) (bool, error) { - resourceDetail, err := p.infrastructureHandler.GetResourceDetail(serviceName, p.workingEnv) - if err != nil { - return false, fmt.Errorf("failed to get resource detail: %w", err) - } - - if resourceDetail == nil || len(resourceDetail.Nodes) == 0 { - return false, nil - } - - healthyPodCount := 0 - totalPodCount := 0 - - for _, node := range resourceDetail.Nodes { - if node.Kind == "Deployment" { - totalPodCount++ - if node.Health.Status == "Healthy" { - healthyPodCount++ - } - } - } - - log.Info().Msgf("Health check for service %s: %d/%d pods healthy and running", serviceName, healthyPodCount, totalPodCount) - - if totalPodCount == healthyPodCount { - return true, nil - } - // Consider deployment healthy if at least one pod is healthy and running - return false, nil -} - -// failValidationJob marks a validation job as failed -func (p *Predator) failValidationJob(jobID uint, errorMessage string) { - if err := p.validationJobRepo.UpdateValidationResult(jobID, false, errorMessage); err != nil { - log.Error().Err(err).Msgf("Failed to update validation job %d as failed", jobID) - } -} - -// completeValidationJob marks a 
validation job as completed -func (p *Predator) completeValidationJob(jobID uint, success bool, message string) { - if err := p.validationJobRepo.UpdateValidationResult(jobID, success, message); err != nil { - log.Error().Err(err).Msgf("Failed to update validation job %d as completed", jobID) - } -} - -// updateRequestValidationStatus updates the request table with validation results -func (p *Predator) updateRequestValidationStatus(groupID string, success bool) { - id, err := strconv.ParseUint(groupID, 10, 32) - if err != nil { - log.Error().Err(err).Msgf("Failed to parse group ID %s for status update", groupID) - return - } - - requests, err := p.Repo.GetAllByGroupID(uint(id)) - if err != nil { - log.Error().Err(err).Msgf("Failed to get requests for group ID %s", groupID) - return - } - - // Update all requests in the group - for _, request := range requests { - request.UpdatedAt = time.Now() - request.IsValid = success - if !success { - request.RejectReason = "Validation Failed" - request.Status = statusRejected - request.UpdatedBy = "Validation Job" - request.UpdatedAt = time.Now() - request.Active = false - } - if err := p.Repo.Update(&request); err != nil { - log.Error().Err(err).Msgf("Failed to update request %d status", request.RequestID) - } else { - log.Info().Msgf("Updated request %d status to %s", request.RequestID, request.Status) - } - } -} - -// GetValidationJobStatus returns the status of a validation job for a given group ID -func (p *Predator) GetValidationJobStatus(groupId string) (*validationjob.Table, error) { - if p.validationJobRepo == nil { - return nil, errors.New("validation job repository not initialized") - } - - job, err := p.validationJobRepo.GetByGroupID(groupId) - if err != nil { - return nil, fmt.Errorf("failed to get validation job for group %s: %w", groupId, err) - } + job, err := p.validationJobRepo.GetByGroupID(groupId) + if err != nil { + return nil, fmt.Errorf("failed to get validation job for group %s: %w", groupId, err) + } 
return job, nil } -// clearTemporaryDeployable clears all models from the temporary deployable GCS path -func (p *Predator) clearTemporaryDeployable(testDeployableID int) error { - // Get temporary deployable config - testServiceDeployable, err := p.ServiceDeployableRepo.GetById(testDeployableID) - if err != nil { - return fmt.Errorf("failed to fetch temporary service deployable: %w", err) - } - - var tempDeployableConfig PredatorDeployableConfig - if err := json.Unmarshal(testServiceDeployable.Config, &tempDeployableConfig); err != nil { - return fmt.Errorf("failed to parse temporary deployable config: %w", err) - } - - if tempDeployableConfig.GCSBucketPath != "NA" { - // Extract bucket and path from temporary deployable config - tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) - - // Clear all models from temporary deployable - log.Info().Msgf("Clearing temporary deployable GCS path: gs://%s/%s", tempBucket, tempPath) - if err := p.GcsClient.DeleteFolder(tempBucket, tempPath, ""); err != nil { - return fmt.Errorf("failed to clear temporary deployable GCS path: %w", err) - } - } - - return nil -} - -// copyExistingModelsToTemporary copies all existing models from target deployable to temporary deployable -func (p *Predator) copyExistingModelsToTemporary(targetDeployableID, tempDeployableID int) error { - // Get target deployable config - targetServiceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) - if err != nil { - return fmt.Errorf("failed to fetch target service deployable: %w", err) - } - - var targetDeployableConfig PredatorDeployableConfig - if err := json.Unmarshal(targetServiceDeployable.Config, &targetDeployableConfig); err != nil { - return fmt.Errorf("failed to parse target deployable config: %w", err) - } - - // Get temporary deployable config - tempServiceDeployable, err := p.ServiceDeployableRepo.GetById(tempDeployableID) - if err != nil { - return fmt.Errorf("failed to fetch 
temporary service deployable: %w", err) - } - - var tempDeployableConfig PredatorDeployableConfig - if err := json.Unmarshal(tempServiceDeployable.Config, &tempDeployableConfig); err != nil { - return fmt.Errorf("failed to parse temporary deployable config: %w", err) - } - - if targetDeployableConfig.GCSBucketPath != "NA" { - // Extract GCS paths - targetBucket, targetPath := extractGCSPath(strings.TrimSuffix(targetDeployableConfig.GCSBucketPath, "/*")) - tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) - - // Copy all existing models from target to temporary deployable - return p.copyAllModelsFromActualToStaging(targetBucket, targetPath, tempBucket, tempPath) - } else { - return nil - } -} - -// copyRequestModelsToTemporary copies the requested models to temporary deployable -func (p *Predator) copyRequestModelsToTemporary(requests []predatorrequest.PredatorRequest, tempDeployableID int) error { - // Get temporary deployable config - tempServiceDeployable, err := p.ServiceDeployableRepo.GetById(tempDeployableID) - if err != nil { - return fmt.Errorf("failed to fetch temporary service deployable: %w", err) - } - - var tempDeployableConfig PredatorDeployableConfig - if err := json.Unmarshal(tempServiceDeployable.Config, &tempDeployableConfig); err != nil { - return fmt.Errorf("failed to parse temporary deployable config: %w", err) - } - - tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) - - isNotProd := p.isNonProductionEnvironment() - - // Copy each requested model from default GCS location to temporary deployable - for _, request := range requests { - modelName := request.ModelName - payload, err := p.processPayload(request) - if err != nil { - log.Error().Err(err).Msgf("Failed to parse payload for request %d", request.RequestID) - return fmt.Errorf("failed to parse payload for request %d: %w", request.RequestID, err) - } - - var sourceBucket, sourcePath, 
sourceModelName string - if payload.ModelSource != "" { - sourceBucket, sourcePath, sourceModelName = extractGCSDetails(payload.ModelSource) - log.Info().Msgf("Using ModelSource from payload for validation: gs://%s/%s/%s", - sourceBucket, sourcePath, sourceModelName) - } else { - sourceBucket = pred.GcsModelBucket - sourcePath = pred.GcsModelBasePath - sourceModelName = modelName - log.Info().Msgf("Using default model source for validation: gs://%s/%s/%s", - sourceBucket, sourcePath, sourceModelName) - } - log.Info().Msgf("Copying model %s from gs://%s/%s/%s to temporary deployable gs://%s/%s", - modelName, sourceBucket, sourcePath, sourceModelName, tempBucket, tempPath) - - if isNotProd { - if err := p.GcsClient.TransferFolder(sourceBucket, sourcePath, sourceModelName, - tempBucket, tempPath, modelName); err != nil { - return fmt.Errorf("failed to copy requested model %s to temporary deployable: %w", modelName, err) - } - } else { - if err := p.GcsClient.TransferFolderWithSplitSources( - sourceBucket, sourcePath, pred.GcsConfigBucket, pred.GcsConfigBasePath, - sourceModelName, tempBucket, tempPath, modelName, - ); err != nil { - return fmt.Errorf("failed to copy requested model %s to temporary deployable: %w", modelName, err) - } - } - - log.Info().Msgf("Successfully copied model %s to temporary deployable", modelName) - } - - return nil -} - -// restartTemporaryDeployable restarts the temporary deployable for validation -func (p *Predator) restartTemporaryDeployable(tempDeployableID int) error { - tempServiceDeployable, err := p.ServiceDeployableRepo.GetById(tempDeployableID) - if err != nil { - return fmt.Errorf("failed to fetch temporary service deployable: %w", err) - } - - // Extract isCanary from deployable config - var deployableConfig map[string]interface{} - isCanary := false - if err := json.Unmarshal(tempServiceDeployable.Config, &deployableConfig); err == nil { - if strategy, ok := deployableConfig["deploymentStrategy"].(string); ok && strategy == 
"canary" { - isCanary = true - } - } - if err := p.infrastructureHandler.RestartDeployment(tempServiceDeployable.Name, p.workingEnv, isCanary); err != nil { - return fmt.Errorf("failed to restart temporary deployable: %w", err) - } - - log.Info().Msgf("Successfully restarted temporary deployable: %s for validation", tempServiceDeployable.Name) - return nil -} - -// convertDimsToIntSlice converts input.Dims to []int, handling nested interfaces and various types -// Dynamic dimensions (-1) are replaced with reasonable default values for test data generation -func convertDimsToIntSlice(dims interface{}) ([]int, error) { - var result []int - - switch v := dims.(type) { - case []int: - result = make([]int, len(v)) - copy(result, v) - case []int64: - result = make([]int, len(v)) - for i, dim := range v { - result[i] = int(dim) - } - case []interface{}: - result = make([]int, len(v)) - for i, dim := range v { - switch d := dim.(type) { - case int: - result[i] = d - case int64: - result[i] = int(d) - case float64: - result[i] = int(d) - default: - return nil, fmt.Errorf("unsupported dimension type in slice: %T", d) - } - } - case int: - result = []int{v} - case int64: - result = []int{int(v)} - case float64: - result = []int{int(v)} - default: - return nil, fmt.Errorf("unsupported dims type: %T", v) - } - - // Replace dynamic dimensions (-1) with reasonable default values for test data generation - for i, dim := range result { - if dim == -1 { - // Use different default sizes based on position - if i == 0 { - result[i] = 10 // First dimension (often sequence length): default to 10 - } else { - result[i] = 128 // Other dimensions: default to 128 - } - log.Debug().Msgf("Replaced dynamic dimension -1 at position %d with %d", i, result[i]) - } else if dim < 0 { - // Handle any other negative dimensions - result[i] = 1 - log.Debug().Msgf("Replaced negative dimension %d at position %d with 1", dim, i) - } - } - - return result, nil -} func (p *Predator) 
GenerateFunctionalTestRequest(req RequestGenerationRequest) (RequestGenerationResponse, error) { @@ -2520,186 +833,16 @@ func (p *Predator) GenerateFunctionalTestRequest(req RequestGenerationRequest) ( newshape[i+1] = int64(dim) } - response.RequestBody.Inputs = append(response.RequestBody.Inputs, Input{ - Name: input.Name, - Dims: newshape, - DataType: input.DataType, - Data: data, - Features: input.Features, - }) - } - - return response, nil -} - -// flattenInputTo3DByteSlice converts input data to 3D byte slice format [batch][feature][bytes] -// This matches the working adapter's data structure expectations -func (p *Predator) flattenInputTo3DByteSlice(data any, dataType string) ([][][]byte, error) { - // The input data comes as nested arrays [batch_size][feature_count] - // For FP16: each feature is a single float32 value converted to 2 bytes - // We need to convert this to [batch][feature][bytes] format exactly like the working adapter - - switch v := data.(type) { - case [][]float32: - // 2D array of float32 values [batch_size][feature_count] - batchSize := len(v) - if batchSize == 0 { - return [][][]byte{}, nil - } - featureCount := len(v[0]) - - result := make([][][]byte, batchSize) - for batchIdx := 0; batchIdx < batchSize; batchIdx++ { - result[batchIdx] = make([][]byte, featureCount) - for featureIdx := 0; featureIdx < featureCount; featureIdx++ { - val := v[batchIdx][featureIdx] - switch dataType { - case "FP16": - fp16Bytes, err := serializer.Float32ToFloat16Bytes(val) - if err != nil { - return nil, err - } - result[batchIdx][featureIdx] = fp16Bytes - case "FP32": - bytes := make([]byte, 4) - binary.LittleEndian.PutUint32(bytes, math.Float32bits(val)) - result[batchIdx][featureIdx] = bytes - default: - return nil, fmt.Errorf("unsupported numeric type %s for float32 data", dataType) - } - } - } - return result, nil - - default: - // Fallback: try to flatten and reshape based on expected structure - flattened, err := serializer.FlattenMatrixByType(data, 
dataType) - if err != nil { - return nil, err - } - - switch dataType { - case "FP16": - if f32slice, ok := flattened.([]float32); ok { - // We need to infer the batch structure from the input data - // For now, assume it matches the shape from the input tensor - // This is a fallback - the main case should handle [][]float32 - batchSize := 1 - featureCount := len(f32slice) - - result := make([][][]byte, batchSize) - result[0] = make([][]byte, featureCount) - for i, val := range f32slice { - fp16Bytes, err := serializer.Float32ToFloat16Bytes(val) - if err != nil { - return nil, err - } - result[0][i] = fp16Bytes - } - return result, nil - } - case "BYTES": - if byteSlice, ok := flattened.([][]byte); ok { - // For BYTES, each element is a separate feature - result := make([][][]byte, 1) - result[0] = byteSlice - return result, nil - } - } - - return nil, fmt.Errorf("unsupported data format: %T for type %s", data, dataType) - } -} - -// getElementSize returns the byte size of a single element for the given data type -func getElementSize(dataType string) int { - switch strings.ToUpper(dataType) { - case "FP32", "TYPE_FP32": - return 4 - case "FP64", "TYPE_FP64": - return 8 - case "INT32", "TYPE_INT32": - return 4 - case "INT64", "TYPE_INT64": - return 8 - case "INT16", "TYPE_INT16": - return 2 - case "INT8", "TYPE_INT8": - return 1 - case "UINT32", "TYPE_UINT32": - return 4 - case "UINT64", "TYPE_UINT64": - return 8 - case "UINT16", "TYPE_UINT16": - return 2 - case "UINT8", "TYPE_UINT8": - return 1 - case "BOOL", "TYPE_BOOL": - return 1 - case "FP16", "TYPE_FP16": - return 2 - default: - return 0 // Unknown type - } -} - -// reshapeDataForBatch reshapes flattened data to preserve batch dimension -func reshapeDataForBatch(data interface{}, dims []int64) interface{} { - if len(dims) == 0 { - return data - } - - batchSize := dims[0] - featureDims := dims[1:] - - // Calculate elements per batch - elementsPerBatch := int64(1) - for _, dim := range featureDims { - 
elementsPerBatch *= dim - } - - // Convert data to slice if it isn't already - var dataSlice []interface{} - switch v := data.(type) { - case []interface{}: - dataSlice = v - case []string: - for _, item := range v { - dataSlice = append(dataSlice, item) - } - case []float32: - for _, item := range v { - dataSlice = append(dataSlice, item) - } - case []float64: - for _, item := range v { - dataSlice = append(dataSlice, item) - } - case []int32: - for _, item := range v { - dataSlice = append(dataSlice, item) - } - case []int64: - for _, item := range v { - dataSlice = append(dataSlice, item) - } - default: - // If we can't convert, return as-is - return data - } - - // Reshape into batches - var result [][]interface{} - for i := int64(0); i < batchSize; i++ { - start := i * elementsPerBatch - end := start + elementsPerBatch - if end <= int64(len(dataSlice)) { - batch := dataSlice[start:end] - result = append(result, batch) - } + response.RequestBody.Inputs = append(response.RequestBody.Inputs, Input{ + Name: input.Name, + Dims: newshape, + DataType: input.DataType, + Data: data, + Features: input.Features, + }) } - return result + return response, nil } func (p *Predator) ExecuteFunctionalTestRequest(req ExecuteRequestFunctionalRequest) (ExecuteRequestFunctionalResponse, error) { @@ -3491,235 +1634,6 @@ func (p *Predator) UploadModelFolderFromLocal(req UploadModelFolderRequest, isPa }, statusCode, nil } -// uploadSingleModel processes a single model upload with improved validation and error handling -func (p *Predator) uploadSingleModel(modelItem ModelUploadItem, bucket, basePath string, isPartial bool, authToken string) ModelUploadResult { - // Step 1: Extract and validate model name - modelName, err := p.extractModelName(modelItem.Metadata) - if err != nil { - return p.createErrorResult("unknown", "Failed to extract model name", err) - } - - log.Info().Msgf("Processing %s upload for model: %s from %s", - map[bool]string{true: "partial", false: "full"}[isPartial], 
modelName, modelItem.GCSPath) - - // Step 2: Setup destination paths - destPath := path.Join(basePath, modelName) - fullGCSPath := fmt.Sprintf("gs://%s/%s", bucket, destPath) - - // Step 3: Validate upload prerequisites - if err := p.validateUploadPrerequisites(bucket, destPath, isPartial, modelName); err != nil { - return p.createErrorResult(modelName, "Upload prerequisites validation failed", err) - } - - // Step 4: Validate source model structure and configuration - if err := p.validateSourceModel(modelItem.GCSPath, isPartial); err != nil { - return p.createErrorResult(modelName, "Source model validation failed", err) - } - - // Step 5: Validate metadata features (after model structure validation) - if err := p.validateMetadataFeatures(modelItem.Metadata, authToken); err != nil { - return p.createErrorResult(modelName, "Feature validation failed", err) - } - - // Step 6: Download/sync model files based on upload type - if err := p.syncModelFiles(modelItem.GCSPath, bucket, destPath, modelName, isPartial); err != nil { - return p.createErrorResult(modelName, "Model file sync failed", err) - } - - // Step 7: Copy config.pbtxt to prod config source (only in production) - if err := p.copyConfigToProdConfigSource(modelItem.GCSPath, modelName); err != nil { - return p.createErrorResult(modelName, "Failed to copy config to prod config source", err) - } - - // Upload processed metadata.json (always done regardless of partial/full) - metadataPath, err := p.uploadModelMetadata(modelItem.Metadata, bucket, destPath) - if err != nil { - return p.createErrorResult(modelName, "Metadata upload failed", err) - } - - log.Info().Msgf("Successfully completed %s upload for model: %s", - map[bool]string{true: "partial", false: "full"}[isPartial], modelName) - return ModelUploadResult{ - ModelName: modelName, - GCSPath: fullGCSPath, - MetadataPath: metadataPath, - Status: "success", - } -} - -// copyConfigToProdConfigSource copies config.pbtxt to the prod config source path -// This is 
done in both int and prd environments so config is available for prod deployments -func (p *Predator) copyConfigToProdConfigSource(gcsPath, modelName string) error { - // Check if config source is configured - if pred.GcsConfigBucket == "" || pred.GcsConfigBasePath == "" { - log.Warn().Msg("Config source not configured, skipping config.pbtxt copy to config source") - return nil - } - - // Parse source GCS path - srcBucket, srcPath := extractGCSPath(gcsPath) - if srcBucket == "" || srcPath == "" { - return fmt.Errorf("invalid GCS path format: %s", gcsPath) - } - - // Read config.pbtxt from source - srcConfigPath := path.Join(srcPath, configFile) - configData, err := p.GcsClient.ReadFile(srcBucket, srcConfigPath) - if err != nil { - return fmt.Errorf("failed to read config.pbtxt from source: %w", err) - } - - // Update model name while preserving formatting - updatedConfigData := p.replaceModelNameInConfigPreservingFormat(configData, modelName) - - // Upload to prod config source path with updated model name - destConfigPath := path.Join(pred.GcsConfigBasePath, modelName, configFile) - if err := p.GcsClient.UploadFile(pred.GcsConfigBucket, destConfigPath, updatedConfigData); err != nil { - return fmt.Errorf("failed to upload config.pbtxt to config source: %w", err) - } - - log.Info().Msgf("Successfully copied config.pbtxt to config source with model name %s: gs://%s/%s", - modelName, pred.GcsConfigBucket, destConfigPath) - return nil -} - -// Helper functions for simplified upload flow - -// createErrorResult creates a standardized error result -func (p *Predator) createErrorResult(modelName, message string, err error) ModelUploadResult { - return ModelUploadResult{ - ModelName: modelName, - Status: "error", - Error: fmt.Sprintf("%s: %v", message, err), - } -} - -// generateUploadSummary creates response message and status code based on results -func (p *Predator) generateUploadSummary(successCount, failCount int, results []ModelUploadResult) (string, int) { - switch 
{ - case failCount == 0: - return fmt.Sprintf("%d model uploaded successfully", successCount), http.StatusOK - case successCount == 0: - return fmt.Sprintf("%d model failed to upload. Errors: %s", failCount, results[0].Error), http.StatusBadRequest - default: - return fmt.Sprintf("Mixed results: %d successful, %d failed. Errors: %s", successCount, failCount, results[0].Error), http.StatusPartialContent - } -} - -// validateUploadPrerequisites validates upload requirements based on type -func (p *Predator) validateUploadPrerequisites(bucket, destPath string, isPartial bool, modelName string) error { - exists, err := p.GcsClient.CheckFolderExists(bucket, destPath) - if err != nil { - return fmt.Errorf("failed to check model existence: %w", err) - } - - if isPartial { - // Partial upload requires existing model - if !exists { - return fmt.Errorf("partial upload requires existing model folder at destination") - } - log.Info().Msgf("Partial upload: updating existing model %s", modelName) - } else { - // Full upload can create new or replace existing - if exists { - log.Info().Msgf("Full upload: replacing existing model %s", modelName) - } else { - log.Info().Msgf("Full upload: creating new model %s", modelName) - } - } - - return nil -} - -// validateSourceModel validates the source model structure and configuration -func (p *Predator) validateSourceModel(gcsPath string, isPartial bool) error { - // Parse GCS path - srcBucket, srcPath := extractGCSPath(gcsPath) - if srcBucket == "" || srcPath == "" { - return fmt.Errorf("invalid GCS path format: %s", gcsPath) - } - - // Always validate config.pbtxt (required for both partial and full) - if err := p.validateModelConfiguration(gcsPath); err != nil { - return fmt.Errorf("config.pbtxt validation failed: %w", err) - } - - if !isPartial { - // For full upload, validate complete model structure - if err := p.validateCompleteModelStructure(srcBucket, srcPath); err != nil { - return fmt.Errorf("complete model structure 
validation failed: %w", err) - } - } - - return nil -} - -// validateCompleteModelStructure validates that version "1" folder exists with non-empty files -// Note: config.pbtxt is already validated above -func (p *Predator) validateCompleteModelStructure(srcBucket, srcPath string) error { - // Check if version "1" folder exists - versionPath := path.Join(srcPath, "1") - exists, err := p.GcsClient.CheckFolderExists(srcBucket, versionPath) - if err != nil { - return fmt.Errorf("failed to check version folder 1/: %w", err) - } - - if !exists { - return fmt.Errorf("version folder 1/ not found - required for complete model") - } - - // Check if version "1" folder has at least one non-empty file - if err := p.validateVersionHasFiles(srcBucket, versionPath); err != nil { - return fmt.Errorf("version folder 1/ validation failed: %w", err) - } - - log.Info().Msgf("Model structure validation passed - version 1/ folder exists with files") - return nil -} - -// validateVersionHasFiles checks if version folder has at least one non-empty file -func (p *Predator) validateVersionHasFiles(srcBucket, versionPath string) error { - // Simply check if the version folder exists and has any content - // CheckFolderExists returns true if there are any objects with the given prefix - exists, err := p.GcsClient.CheckFolderExists(srcBucket, versionPath) - if err != nil { - return fmt.Errorf("failed to check version folder contents: %w", err) - } - - if !exists { - return fmt.Errorf("version folder 1/ is empty - must contain model files") - } - - log.Info().Msgf("Version folder 1/ contains files") - return nil -} - -// syncModelFiles handles file synchronization based on upload type -func (p *Predator) syncModelFiles(gcsPath, destBucket, destPath, modelName string, isPartial bool) error { - if isPartial { - // Partial upload: only sync config.pbtxt - return p.syncPartialFiles(gcsPath, destBucket, destPath, modelName) - } else { - // Full upload: sync everything - return 
p.syncFullModel(gcsPath, destBucket, destPath, modelName) - } -} - -// uploadModelMetadata uploads metadata.json to GCS and returns the full path -func (p *Predator) uploadModelMetadata(metadata interface{}, bucket, destPath string) (string, error) { - metadataBytes, err := json.Marshal(metadata) - if err != nil { - return "", fmt.Errorf("failed to serialize metadata: %w", err) - } - - metadataPath := path.Join(destPath, "metadata.json") - if err := p.GcsClient.UploadFile(bucket, metadataPath, metadataBytes); err != nil { - return "", fmt.Errorf("failed to upload metadata: %w", err) - } - - return fmt.Sprintf("gs://%s/%s", bucket, metadataPath), nil -} - // Legacy functions for backward compatibility func (p *Predator) CheckModelExists(bucket, path string) (bool, error) { return p.GcsClient.CheckFolderExists(bucket, path) @@ -3728,412 +1642,3 @@ func (p *Predator) CheckModelExists(bucket, path string) (bool, error) { func (p *Predator) UploadFileToGCS(bucket, path string, data []byte) error { return p.GcsClient.UploadFile(bucket, path, data) } - -// validateMetadataFeatures validates the features in metadata against online/offline validation APIs -func (p *Predator) validateMetadataFeatures(metadata interface{}, authToken string) error { - // Parse metadata to extract features - metadataBytes, err := json.Marshal(metadata) - if err != nil { - return fmt.Errorf("failed to marshal metadata: %w", err) - } - - var featureMeta FeatureMetadata - if err := json.Unmarshal(metadataBytes, &featureMeta); err != nil { - return fmt.Errorf("failed to unmarshal metadata: %w", err) - } - - // Validate that auth token is provided - if authToken == "" { - return fmt.Errorf("authorization token is required for feature validation") - } - - // Group features by validation type - onlineFeaturesByEntity := make(map[string][]string) - pricingFeaturesByEntity := make(map[string][]string) - var offlineFeatures []string - - for _, input := range featureMeta.Inputs { - for _, feature := range 
input.Features { - featureType, entity, gf, featureName, isValid := externalcall.ParseFeatureString(feature) - if !isValid { - log.Error().Msgf("Invalid feature format: %s", feature) - return fmt.Errorf("invalid feature format: %s", feature) - } - - switch featureType { - case "ONLINE_FEATURE", "PARENT_ONLINE_FEATURE": - // Validate using online validation API - onlineFeaturesByEntity[entity] = append(onlineFeaturesByEntity[entity], gf) - log.Info().Msgf("Added online feature for validation - entity: %s, feature: %s", entity, gf) - - case "OFFLINE_FEATURE", "PARENT_OFFLINE_FEATURE": - // Validate using offline validation API - offlineFeatures = append(offlineFeatures, featureName) - log.Info().Msgf("Added offline feature for validation: %s", featureName) - - case "RTP_FEATURE", "PARENT_RTP_FEATURE": - // Validate using pricing service API - store full entity:feature_group:feature format - fullFeature := entity + ":" + gf // entity:feature_group:feature - pricingFeaturesByEntity[entity] = append(pricingFeaturesByEntity[entity], fullFeature) - log.Info().Msgf("Added pricing feature for validation - entity: %s, full feature: %s", entity, fullFeature) - - case "DEFAULT_FEATURE", "PARENT_DEFAULT_FEATURE", "MODEL_FEATURE", "CALIBRATION": - // These feature types don't need API validation - they are correct by default - log.Info().Msgf("Skipping API validation for feature type %s: %s (no validation required)", featureType, feature) - continue - - default: - log.Warn().Msgf("Unknown feature type %s for feature: %s", featureType, feature) - } - } - } - - // Validate online features - for entity, features := range onlineFeaturesByEntity { - if err := p.validateOnlineFeatures(entity, features, authToken); err != nil { - return fmt.Errorf("online feature validation failed for entity %s: %w", entity, err) - } - } - - // Validate offline features - if len(offlineFeatures) > 0 { - if err := p.validateOfflineFeatures(offlineFeatures, authToken); err != nil { - return 
fmt.Errorf("offline feature validation failed: %w", err) - } - } - - // Validate pricing features - for entity, features := range pricingFeaturesByEntity { - if err := p.validatePricingFeatures(entity, features); err != nil { - return fmt.Errorf("pricing feature validation failed for entity %s: %w", entity, err) - } - } - - return nil -} - -// validateOnlineFeatures validates online features for a specific entity -func (p *Predator) validateOnlineFeatures(entity string, features []string, token string) error { - response, err := p.featureValidationClient.ValidateOnlineFeatures(entity, token) - if err != nil { - return fmt.Errorf("failed to call online validation API: %w", err) - } - - // Check if all features exist in the response - for _, feature := range features { - if !externalcall.ValidateFeatureExists(feature, response) { - return fmt.Errorf("online feature '%s' does not exist for entity '%s'", feature, entity) - } - } - - log.Info().Msgf("Successfully validated %d online features for entity %s", len(features), entity) - return nil -} - -// validateOfflineFeatures validates offline features by checking online mapping -func (p *Predator) validateOfflineFeatures(features []string, token string) error { - response, err := p.featureValidationClient.ValidateOfflineFeatures(features, token) - if err != nil { - return fmt.Errorf("failed to call offline validation API: %w", err) - } - - if response.Error != "" { - return fmt.Errorf("offline validation API returned error: %s", response.Error) - } - - // Check if all offline features have online mappings - for _, feature := range features { - if _, exists := response.Data[feature]; !exists { - return fmt.Errorf("offline feature '%s' does not have an online mapping", feature) - } - } - - log.Info().Msgf("Successfully validated %d offline features", len(features)) - return nil -} - -// validatePricingFeatures validates pricing features for a specific entity -func (p *Predator) validatePricingFeatures(entity string, 
features []string) error { - if !pred.IsMeeshoEnabled { - return nil - } - response, err := externalcall.PricingClient.GetDataTypes(entity) - if err != nil { - return fmt.Errorf("failed to call pricing service API: %w", err) - } - - // Check if all features exist in the response - for _, feature := range features { - if !externalcall.ValidatePricingFeatureExists(feature, response) { - return fmt.Errorf("pricing feature '%s' does not exist for entity '%s'", feature, entity) - } - } - - log.Info().Msgf("Successfully validated %d pricing features for entity %s", len(features), entity) - return nil -} - -// extractModelName extracts model name from metadata -func (p *Predator) extractModelName(metadata interface{}) (string, error) { - // Parse metadata to extract model name - metadataBytes, err := json.Marshal(metadata) - if err != nil { - return "", fmt.Errorf("failed to marshal metadata: %w", err) - } - - var metadataMap map[string]interface{} - if err := json.Unmarshal(metadataBytes, &metadataMap); err != nil { - return "", fmt.Errorf("failed to unmarshal metadata: %w", err) - } - - modelName, exists := metadataMap["model_name"] - if !exists { - return "", fmt.Errorf("model_name not found in metadata") - } - - modelNameStr, ok := modelName.(string) - if !ok || modelNameStr == "" { - return "", fmt.Errorf("model_name must be a non-empty string") - } - - return modelNameStr, nil -} - -// syncFullModel syncs all model files for full upload -func (p *Predator) syncFullModel(gcsPath, destBucket, destPath, modelName string) error { - log.Info().Msgf("Syncing full model from GCS path: %s", gcsPath) - - // Parse the GCS path to extract bucket and object path - srcBucket, srcPath := extractGCSPath(gcsPath) - if srcBucket == "" || srcPath == "" { - return fmt.Errorf("invalid GCS path format: %s", gcsPath) - } - - // Extract the model folder name from the source path - pathSegments := strings.Split(strings.TrimSuffix(srcPath, "/"), "/") - srcModelName := 
pathSegments[len(pathSegments)-1] - srcBasePath := strings.TrimSuffix(srcPath, "/"+srcModelName) - - // Step 2: Transfer all files from source to destination - log.Info().Msgf("Full upload: transferring all files from %s/%s to %s/%s", - srcBucket, srcPath, destBucket, destPath) - - return p.GcsClient.TransferFolder(srcBucket, srcBasePath, srcModelName, - destBucket, strings.TrimSuffix(destPath, "/"+modelName), modelName) -} - -// syncPartialFiles syncs only config.pbtxt for partial upload -// Note: metadata.json is handled separately in uploadModelMetadata -func (p *Predator) syncPartialFiles(gcsPath, destBucket, destPath, modelName string) error { - // Parse GCS path - srcBucket, srcPath := extractGCSPath(gcsPath) - if srcBucket == "" || srcPath == "" { - return fmt.Errorf("invalid GCS path format: %s", gcsPath) - } - - // Files to sync for partial upload (only config.pbtxt) - // metadata.json is always uploaded from the request metadata - filesToSync := []string{"config.pbtxt"} - - log.Info().Msgf("Partial upload: syncing %v for model %s", filesToSync, modelName) - - for _, fileName := range filesToSync { - srcFilePath := path.Join(srcPath, fileName) - destFilePath := path.Join(destPath, fileName) - - // Read file from source - data, err := p.GcsClient.ReadFile(srcBucket, srcFilePath) - if err != nil { - return fmt.Errorf("required file %s not found in source %s/%s: %w", - fileName, srcBucket, srcFilePath, err) - } - - // Note: config.pbtxt modification is handled by GCS client during TransferFolder - - // Upload to destination - if err := p.GcsClient.UploadFile(destBucket, destFilePath, data); err != nil { - return fmt.Errorf("failed to upload %s: %w", fileName, err) - } - - log.Info().Msgf("Successfully synced %s for partial upload of model %s", fileName, modelName) - } - - return nil -} - -// validateModelConfiguration validates the model configuration by: -// 1. Downloading and parsing config.pbtxt to proto -// 2. Checking if backend is "python" -// 3. 
If backend is python, checking if preprocessing.tar.gz exists in source -func (p *Predator) validateModelConfiguration(gcsPath string) error { - log.Info().Msgf("Validating model configuration for GCS path: %s", gcsPath) - - // Parse the GCS path to extract bucket and object path - srcBucket, srcPath := extractGCSPath(gcsPath) - if srcBucket == "" || srcPath == "" { - return fmt.Errorf("invalid GCS path format: %s", gcsPath) - } - - // Step 1: Download config.pbtxt - configPath := path.Join(srcPath, configFile) - configData, err := p.GcsClient.ReadFile(srcBucket, configPath) - if err != nil { - return fmt.Errorf("failed to read config.pbtxt from %s/%s: %w", srcBucket, configPath, err) - } - - // Step 2: Parse config.pbtxt to proto - var modelConfig ModelConfig - if err := prototext.Unmarshal(configData, &modelConfig); err != nil { - return fmt.Errorf("failed to parse config.pbtxt as proto: %w", err) - } - - log.Info().Msgf("Parsed model config - Name: %s, Backend: %s", modelConfig.Name, modelConfig.Backend) - - return nil -} - -// cleanEnsembleScheduling cleans up ensemble scheduling to avoid storing {"step": null} -func (p *Predator) cleanEnsembleScheduling(metadata MetaData) MetaData { - // If ensemble scheduling step is empty, set to nil so omitempty works - if len(metadata.Ensembling.Step) == 0 { - metadata.Ensembling = Ensembling{Step: nil} - } - return metadata -} - -// Returns the derived model name with deployable tag -func (p *Predator) GetDerivedModelName(payloadObject Payload, requestType string) (string, error) { - if requestType != ScaleUpRequestType { - return payloadObject.ModelName, nil - } - serviceDeployableID := payloadObject.ConfigMapping.ServiceDeployableID - serviceDeployable, err := p.ServiceDeployableRepo.GetById(int(serviceDeployableID)) - if err != nil { - return constant.EmptyString, fmt.Errorf("%s: %w", errFetchDeployableConfig, err) - } - - deployableTag := serviceDeployable.DeployableTag - if deployableTag == "" { - return 
payloadObject.ModelName, nil - } - - derivedModelName := payloadObject.ModelName + deployableTagDelimiter + deployableTag - derivedModelName = derivedModelName + deployableTagDelimiter + scaleupTag - return derivedModelName, nil -} - -// Returns the original model name if no tag is found (backward compatibility). -func (p *Predator) GetOriginalModelName(derivedModelName string, serviceDeployableID int) (string, error) { - serviceDeployable, err := p.ServiceDeployableRepo.GetById(serviceDeployableID) - if err != nil { - return constant.EmptyString, fmt.Errorf("%s: %w", errFetchDeployableConfig, err) - } - - deployableTag := serviceDeployable.DeployableTag - if deployableTag == "" { - return derivedModelName, nil - } - - scaleupSuffix := deployableTagDelimiter + scaleupTag - derivedModelName = strings.TrimSuffix(derivedModelName, scaleupSuffix) - - deployableTagSuffix := deployableTagDelimiter + deployableTag - if originalName, foundSuffix := strings.CutSuffix(derivedModelName, deployableTagSuffix); foundSuffix { - return originalName, nil - } - - return derivedModelName, nil -} - -func (p *Predator) isNonProductionEnvironment() bool { - env := strings.ToLower(strings.TrimSpace(pred.AppEnv)) - if env == "prd" || env == "prod" { - return false - } - return true -} - -func (p *Predator) copyConfigToNewNameInConfigSource(oldModelName, newModelName string) error { - if oldModelName == newModelName { - return nil - } - - if pred.GcsConfigBucket == "" || pred.GcsConfigBasePath == "" { - log.Warn().Msg("Config source not configured, skipping config.pbtxt copy in config source") - return nil - } - - destConfigPath := path.Join(pred.GcsConfigBasePath, newModelName, configFile) - exists, err := p.GcsClient.CheckFileExists(pred.GcsConfigBucket, destConfigPath) - if err != nil { - log.Warn().Err(err).Msgf("Failed to check if config exists for %s, will attempt copy anyway", newModelName) - } else if exists { - log.Info().Msgf("Config already exists for %s in config source, 
skipping copy", newModelName) - return nil - } - - srcConfigPath := path.Join(pred.GcsConfigBasePath, oldModelName, configFile) - - configData, err := p.GcsClient.ReadFile(pred.GcsConfigBucket, srcConfigPath) - if err != nil { - return fmt.Errorf("failed to read config.pbtxt from %s: %w", srcConfigPath, err) - } - - // Use formatting-preserving function instead of marshal/unmarshal - updatedConfigData := p.replaceModelNameInConfigPreservingFormat(configData, newModelName) - - if err := p.GcsClient.UploadFile(pred.GcsConfigBucket, destConfigPath, updatedConfigData); err != nil { - return fmt.Errorf("failed to upload config.pbtxt to %s: %w", destConfigPath, err) - } - - log.Info().Msgf("Successfully copied config.pbtxt from %s to %s in config source", - oldModelName, newModelName) - return nil -} - -// replaceModelNameInConfigPreservingFormat updates only the top-level model name while preserving formatting -// It replaces only the first occurrence to avoid modifying nested names in inputs/outputs/instance_groups -func (p *Predator) replaceModelNameInConfigPreservingFormat(data []byte, destModelName string) []byte { - content := string(data) - lines := strings.Split(content, "\n") - - for i, line := range lines { - trimmed := strings.TrimSpace(line) - // Match top-level "name:" field - should be at the start of line (or minimal indentation) - // Skip nested names which are typically indented with 2+ spaces - if strings.HasPrefix(trimmed, "name:") { - // Check indentation: top-level fields have minimal/no indentation - leadingWhitespace := len(line) - len(strings.TrimLeft(line, " \t")) - // Skip if heavily indented (nested field) - if leadingWhitespace >= 2 { - continue - } - - // Match the first occurrence of name: "value" pattern - namePattern := regexp.MustCompile(`name\s*:\s*"([^"]+)"`) - matches := namePattern.FindStringSubmatch(line) - if len(matches) > 1 { - oldModelName := matches[1] - // Replace only the FIRST occurrence to avoid replacing nested names - loc 
:= namePattern.FindStringIndex(line) - if loc != nil { - // Replace only the matched portion (first occurrence) - before := line[:loc[0]] - matched := line[loc[0]:loc[1]] - after := line[loc[1]:] - // Replace the value inside quotes while preserving the "name:" format - valuePattern := regexp.MustCompile(`"([^"]+)"`) - valueReplaced := valuePattern.ReplaceAllString(matched, fmt.Sprintf(`"%s"`, destModelName)) - lines[i] = before + valueReplaced + after - } else { - // Fallback: replace all (shouldn't happen with valid input) - lines[i] = namePattern.ReplaceAllString(line, fmt.Sprintf(`name: "%s"`, destModelName)) - } - log.Info().Msgf("Replacing top-level model name in config.pbtxt: '%s' -> '%s'", oldModelName, destModelName) - break - } - } - } - - return []byte(strings.Join(lines, "\n")) -} diff --git a/horizon/internal/predator/handler/predator_approval.go b/horizon/internal/predator/handler/predator_approval.go new file mode 100644 index 00000000..0b67a44c --- /dev/null +++ b/horizon/internal/predator/handler/predator_approval.go @@ -0,0 +1,842 @@ +package handler + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/Meesho/BharatMLStack/horizon/internal/constant" + pred "github.com/Meesho/BharatMLStack/horizon/internal/predator" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/discoveryconfig" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorconfig" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorrequest" + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +func (p *Predator) processRequest(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + if req.Status == statusApproved { + switch predatorRequestList[0].RequestType { + case OnboardRequestType: + p.processOnboardFlow(requestIdPayloadMap, predatorRequestList, req) + case ScaleUpRequestType: + p.processScaleUpFlow(requestIdPayloadMap, 
predatorRequestList, req) + case PromoteRequestType: + p.processPromoteFlow(requestIdPayloadMap, predatorRequestList, req) + case DeleteRequestType: + p.processDeleteRequest(requestIdPayloadMap, predatorRequestList, req) + case EditRequestType: + p.processEditRequest(requestIdPayloadMap, predatorRequestList, req) + default: + log.Error().Err(errors.New(errInvalidRequestType)).Msg(errInvalidRequestType) + } + } else { + p.processRejectRequest(predatorRequestList, req) + } +} + +func (p *Predator) processRejectRequest(predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + for i := range predatorRequestList { + predatorRequestList[i].Status = statusRejected + predatorRequestList[i].RejectReason = req.RejectReason + predatorRequestList[i].Reviewer = req.ApprovedBy + predatorRequestList[i].UpdatedBy = req.ApprovedBy + predatorRequestList[i].UpdatedAt = time.Now() + predatorRequestList[i].Active = false + } + + if err := p.Repo.UpdateMany(predatorRequestList); err != nil { + log.Printf(errFailedToUpdateRequestStatusAndStage, err) + } + + log.Printf("Request %s rejected successfully.\n", req.GroupID) +} + +func (p *Predator) processDeleteRequest(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + transferredGcsModelData, err := p.processGCSCloneToDeleteBucket(req.ApprovedBy, predatorRequestList, requestIdPayloadMap) + if err != nil { + log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + p.revertForDelete(transferredGcsModelData) + return + } + + p.processDBPopulationStageForDelete(predatorRequestList, requestIdPayloadMap, req) + + if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { + log.Error().Err(err).Msg(errFailedToRestartDeployable) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, 
statusFailed, predatorStageRestartDeployable) + return + } + +} + +func (p *Predator) processEditRequest(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + log.Info().Msgf("Starting edit request flow for group ID: %s", req.GroupID) + + // Step 1: Get target deployable configuration from the request + targetDeployableID := int(requestIdPayloadMap[predatorRequestList[0].RequestID].ConfigMapping.ServiceDeployableID) + targetServiceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch target service deployable for edit request") + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + return + } + + var targetDeployableConfig PredatorDeployableConfig + if err := json.Unmarshal(targetServiceDeployable.Config, &targetDeployableConfig); err != nil { + log.Error().Err(err).Msg("Failed to parse target service deployable config") + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + return + } + + targetBucket, targetPath := extractGCSPath(strings.TrimSuffix(targetDeployableConfig.GCSBucketPath, "/*")) + log.Info().Msgf("Target deployable path: gs://%s/%s", targetBucket, targetPath) + + // Step 2: GCS Copy Stage - Copy models from source to target deployable path + transferredGcsModelData, err := p.processEditGCSCopyStage(requestIdPayloadMap, predatorRequestList, targetBucket, targetPath) + if err != nil { + log.Error().Err(err).Msg("Failed to copy models for edit request") + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + p.revert(transferredGcsModelData) + return + } + + // Update stage to DB Population after successful GCS copy + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageDBPopulation) + + // Step 3: 
DB Update Stage - Update existing predator config with new metadata from request + err = p.processEditDBUpdateStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy) + if err != nil { + log.Error().Err(err).Msg("Failed to update database for edit request") + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) + p.revert(transferredGcsModelData) + return + } + + // Update stage to Restart Deployable after successful DB update + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageRestartDeployable) + + // Step 4: Restart Deployable Stage - Restart target deployable + if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { + log.Error().Err(err).Msg("Failed to restart deployable for edit request") + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) + return + } + + // Mark request as approved and completed + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusApproved, constant.EmptyString) + log.Info().Msgf("Edit request completed successfully for group ID: %s", req.GroupID) +} + +// processEditGCSCopyStage copies models from source to target deployable path for edit approval +func (p *Predator) processEditGCSCopyStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, targetBucket, targetPath string) ([]GcsModelData, error) { + var transferredGcsModelData []GcsModelData + + // Check if we're in the correct stage for GCS copy + if predatorRequestList[0].RequestStage != predatorStagePending && predatorRequestList[0].RequestStage != predatorStageCloneToBucket && predatorRequestList[0].RequestStage != constant.EmptyString { + log.Info().Msgf("Skipping GCS copy stage - current stage: %s", predatorRequestList[0].RequestStage) + return transferredGcsModelData, nil + } + + isNotProd 
:= p.isNonProductionEnvironment() + + for _, requestModel := range predatorRequestList { + payload := requestIdPayloadMap[requestModel.RequestID] + if payload == nil { + log.Error().Msgf("Payload not found for request ID %d", requestModel.RequestID) + continue + } + + modelName := requestModel.ModelName + + // Use the source path from the payload, not the default GCS bucket + if payload.ModelSource == "" { + log.Error().Msgf("ModelSource is empty for request ID %d", requestModel.RequestID) + return transferredGcsModelData, fmt.Errorf("model source path is empty for model %s", modelName) + } + + // Normalize GCS URL (handle gcs:// prefix) + normalizedModelSource := payload.ModelSource + if strings.HasPrefix(normalizedModelSource, "gcs://") { + normalizedModelSource = strings.Replace(normalizedModelSource, "gcs://", "gs://", 1) + log.Info().Msgf("Normalized GCS URL from %s to %s", payload.ModelSource, normalizedModelSource) + } + + // Parse the source GCS path + sourceBucket, sourcePath := extractGCSPath(normalizedModelSource) + if sourceBucket == "" || sourcePath == "" { + log.Error().Msgf("Invalid source GCS path format: %s (normalized: %s)", payload.ModelSource, normalizedModelSource) + return transferredGcsModelData, fmt.Errorf("invalid source GCS path format: %s", normalizedModelSource) + } + + log.Info().Msgf("Copying model %s from source gs://%s/%s to target gs://%s/%s for edit approval", + modelName, sourceBucket, sourcePath, targetBucket, targetPath) + + // Copy model from source to target deployable path + // Extract model folder name from source path and copy to target with the same model name + pathSegments := strings.Split(strings.TrimSuffix(sourcePath, "/"), "/") + sourceModelName := pathSegments[len(pathSegments)-1] + sourceBasePath := strings.TrimSuffix(sourcePath, "/"+sourceModelName) + + if isNotProd { + if err := p.GcsClient.TransferFolder( + sourceBucket, sourceBasePath, sourceModelName, + targetBucket, targetPath, modelName, + ); err != nil { + 
return transferredGcsModelData, err + } + } else { + configBucket := pred.GcsConfigBucket + configPath := pred.GcsConfigBasePath + if err := p.GcsClient.TransferFolderWithSplitSources( + sourceBucket, sourceBasePath, configBucket, configPath, + sourceModelName, targetBucket, targetPath, modelName, + ); err != nil { + return transferredGcsModelData, err + } + } + + // Track transferred data for potential rollback + transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ + Bucket: targetBucket, + Path: targetPath, + Name: modelName, + }) + + log.Info().Msgf("Successfully copied model %s for edit approval", modelName) + } + + return transferredGcsModelData, nil +} + +// processEditDBUpdateStage updates predator config for edit approval +// This updates the existing predator config with new config.pbtxt and metadata.json changes +func (p *Predator) processEditDBUpdateStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, approvedBy string) error { + // Check if we're in the correct stage for DB update + if predatorRequestList[0].RequestStage != predatorStageDBPopulation { + log.Info().Msgf("Skipping DB update stage - current stage: %s", predatorRequestList[0].RequestStage) + return nil + } + + log.Info().Msg("Starting DB update stage for edit approval") + + for _, requestModel := range predatorRequestList { + payload := requestIdPayloadMap[requestModel.RequestID] + if payload == nil { + log.Error().Msgf("Payload not found for request ID %d", requestModel.RequestID) + continue + } + + modelName := requestModel.ModelName + log.Info().Msgf("Updating predator config for model %s", modelName) + + // Find existing predator config for this model + existingPredatorConfig, err := p.PredatorConfigRepo.GetActiveModelByModelName(modelName) + if err != nil { + log.Error().Err(err).Msgf("Failed to fetch existing predator config for model %s", modelName) + return fmt.Errorf("failed to fetch existing predator config for 
model %s: %w", modelName, err) + } + + if existingPredatorConfig == nil { + log.Error().Msgf("No existing predator config found for model %s", modelName) + return fmt.Errorf("no existing predator config found for model %s", modelName) + } + + // Clean up ensemble scheduling and update the predator config with new metadata from the request + cleanedMetaData := p.cleanEnsembleScheduling(payload.MetaData) + + metaDataBytes, err := json.Marshal(cleanedMetaData) + if err != nil { + log.Error().Err(err).Msgf("Failed to marshal metadata for model %s", modelName) + return fmt.Errorf("failed to marshal metadata for model %s: %w", modelName, err) + } + + // Update the existing config + existingPredatorConfig.MetaData = metaDataBytes + existingPredatorConfig.UpdatedBy = approvedBy + existingPredatorConfig.UpdatedAt = time.Now() + existingPredatorConfig.HasNilData = true + existingPredatorConfig.TestResults = nil + // Save the updated config + if err := p.PredatorConfigRepo.Update(existingPredatorConfig); err != nil { + log.Error().Err(err).Msgf("Failed to update predator config for model %s", modelName) + return fmt.Errorf("failed to update predator config for model %s: %w", modelName, err) + } + + log.Info().Msgf("Successfully updated predator config for model %s", modelName) + } + + log.Info().Msg("DB update stage completed successfully for edit approval") + return nil +} + +func (p *Predator) copyAllModelsFromActualToStaging(sourceBucket, sourcePath, targetBucket, targetPath string) error { + // List all models in the actual target path and copy them to staging + folders, err := p.GcsClient.ListFolders(sourceBucket, sourcePath) + if err != nil { + return fmt.Errorf("failed to list models in actual target path: %w", err) + } + + // Copy each model folder from actual target to staging + for _, modelName := range folders { + log.Info().Msgf("Copying existing model %s from actual target to staging", modelName) + + if err := p.GcsClient.TransferFolder(sourceBucket, sourcePath, 
modelName, targetBucket, targetPath, modelName); err != nil { + log.Error().Err(err).Msgf("Failed to copy existing model %s to staging", modelName) + return fmt.Errorf("failed to copy existing model %s to staging: %w", modelName, err) + } + } + + return nil +} + +func (p *Predator) deleteServiceDiscoveryAndConfig(req ApproveRequest, predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload) error { + tx := p.Repo.DB().Begin() + if tx.Error != nil { + return tx.Error + } + + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) // re-throw panic after rollback + } + }() + + for i := range predatorRequestList { + payload := requestIdPayloadMap[predatorRequestList[i].RequestID] + if payload == nil { + log.Error().Msgf(errFailedToParsePayload) + tx.Rollback() + return fmt.Errorf("failed to parse payload for request ID %d", predatorRequestList[i].RequestID) + } + discoveryConfigID := int(payload.DiscoveryConfigID) + log.Info().Msgf("Processing delete request for discovery config ID: %d", discoveryConfigID) + serviceDiscovery, err := p.ServiceDiscoveryRepo.WithTx(tx).GetById(discoveryConfigID) + if err != nil { + log.Error().Err(err).Msg(errFailedToFindServiceDiscovery) + tx.Rollback() + return err + } + serviceDiscovery.Active = false + serviceDiscovery.UpdatedAt = time.Now() + serviceDiscovery.UpdatedBy = req.ApprovedBy + + if err := p.ServiceDiscoveryRepo.WithTx(tx).Update(serviceDiscovery); err != nil { + log.Error().Err(err).Msg(errFailedToUpdateServiceDiscovery) + tx.Rollback() + return err + } + + predatorConfigs, err := p.PredatorConfigRepo.WithTx(tx).GetByDiscoveryConfigID(discoveryConfigID) + if err != nil { + log.Error().Err(err).Msg(errFailedToFindPredatorConfig) + tx.Rollback() + return err + } + + for j := range predatorConfigs { + predatorConfigs[j].Active = false + predatorConfigs[j].UpdatedAt = time.Now() + predatorConfigs[j].UpdatedBy = req.ApprovedBy + if err := 
p.PredatorConfigRepo.WithTx(tx).Update(&predatorConfigs[j]); err != nil { + log.Error().Err(err).Msg(errFailedToUpdatePredatorConfig) + tx.Rollback() + return err + } + } + + predatorRequestList[i].Status = statusInProgress + predatorRequestList[i].Reviewer = req.ApprovedBy + predatorRequestList[i].UpdatedBy = req.ApprovedBy + predatorRequestList[i].RequestStage = predatorStageRestartDeployable + predatorRequestList[i].UpdatedAt = time.Now() + + if err := p.Repo.WithTx(tx).Update(&predatorRequestList[i]); err != nil { + log.Error().Err(err).Msg(errFailedToUpdateRequestStatusAndStage) + tx.Rollback() + return err + } + } + + if err := tx.Commit().Error; err != nil { + log.Error().Err(err).Msg("transaction commit failed") + return err + } + + return nil +} + +func (p *Predator) processGCSCloneToDeleteBucket(email string, predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload) ([]GcsTransferredData, error) { + var transferredGcsModelData []GcsTransferredData + if predatorRequestList[0].RequestStage == constant.EmptyString || predatorRequestList[0].RequestStage == predatorStagePending || predatorRequestList[0].RequestStage == predatorStageCloneToBucket { + for _, requestModel := range predatorRequestList { + srcBucket, srcPath, srcModelName := extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) + destBucket, destPath := extractGCSPath(pred.DefaultModelPathKey) + log.Info().Msgf("srcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", srcBucket, srcPath, srcModelName, destBucket, destPath) + if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || srcModelName == constant.EmptyString || destBucket == constant.EmptyString || destPath == constant.EmptyString || requestIdPayloadMap[requestModel.RequestID].ModelName == constant.EmptyString { + log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) + return transferredGcsModelData, 
errors.New(errModelPathFormat) + } + + if err := p.GcsClient.TransferAndDeleteFolder(srcBucket, srcPath, srcModelName, destBucket, destPath, requestIdPayloadMap[requestModel.RequestID].ModelName); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return transferredGcsModelData, err + } + + transferredGcsModelData = append(transferredGcsModelData, GcsTransferredData{ + SrcBucket: destBucket, + SrcPath: destPath, + SrcName: requestIdPayloadMap[requestModel.RequestID].ModelName, + DestBucket: srcBucket, + DestPath: srcPath, + DestName: srcModelName, + }) + } + p.updateRequestStatusAndStage(email, predatorRequestList, statusInProgress, predatorStageDBPopulation) + } + return transferredGcsModelData, nil +} + +func (p *Predator) processRestartDeployableStage(email string, predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload) error { + if predatorRequestList[0].RequestStage != predatorStageRestartDeployable { + return nil + } + var serviceDeployableIDList []int + for _, requestModel := range predatorRequestList { + serviceDeployableIDList = append(serviceDeployableIDList, int(requestIdPayloadMap[requestModel.RequestID].ConfigMapping.ServiceDeployableID)) + } + + for _, serviceDeployableID := range serviceDeployableIDList { + sd, err := p.ServiceDeployableRepo.GetById(int(serviceDeployableID)) + if err != nil { + log.Error().Err(err).Msg(errFailedToFindServiceDeployableEntry) + return err + } + // Extract isCanary from deployable config + var deployableConfig map[string]interface{} + isCanary := false + if err := json.Unmarshal(sd.Config, &deployableConfig); err == nil { + if strategy, ok := deployableConfig["deploymentStrategy"].(string); ok && strategy == "canary" { + isCanary = true + } + } + if err := p.infrastructureHandler.RestartDeployment(sd.Name, p.workingEnv, isCanary); err != nil { + log.Error().Err(err).Msg(errFailedToRestartDeployable) + return err + } + } + + p.updateRequestStatusAndStage(email, 
predatorRequestList, statusApproved, constant.EmptyString) + return nil +} + +func (p *Predator) processPayload(predatorRequest predatorrequest.PredatorRequest) (*Payload, error) { + var payload Payload + decoder := json.NewDecoder(strings.NewReader(predatorRequest.Payload)) + decoder.DisallowUnknownFields() + if err := decoder.Decode(&payload); err != nil { + log.Error().Err(err).Msg("Failed to parse payload with strict decoding") + return nil, err + } + return &payload, nil +} + +func (p *Predator) processGCSCloneStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) ([]GcsModelData, error) { + var transferredGcsModelData []GcsModelData + if predatorRequestList[0].RequestStage == predatorStagePending || predatorRequestList[0].RequestStage == predatorStageCloneToBucket { + isNotProd := p.isNonProductionEnvironment() + for _, requestModel := range predatorRequestList { + + serviceDeployable, err := p.ServiceDeployableRepo.GetById(int(requestIdPayloadMap[requestModel.RequestID].ConfigMapping.ServiceDeployableID)) + + if err != nil { + log.Error().Err(err).Msg(serviceDeployableNotFound) + return transferredGcsModelData, err + } + + var deployableConfig PredatorDeployableConfig + if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { + log.Error().Err(err).Msg(failedToParseServiceConfig) + return transferredGcsModelData, err + } + + destBucket, destPath := extractGCSPath(strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*")) + destModelName := requestIdPayloadMap[requestModel.RequestID].ModelName + + var srcBucket, srcPath, srcModelName string + + srcBucket = pred.GcsModelBucket + srcPath = pred.GcsModelBasePath + if requestModel.RequestType == ScaleUpRequestType { + srcModelName = destModelName + log.Info().Msgf("Scale-up: Source from model-source gs://%s/%s/%s", + srcBucket, srcPath, srcModelName) + } else { + _, _, srcModelName = 
extractGCSDetails(requestIdPayloadMap[requestModel.RequestID].ModelSource) + log.Info().Msgf("Onboard/Promote: Source from payload gs://%s/%s/%s", + srcBucket, srcPath, srcModelName) + } + + log.Info().Msgf("Copying to target deployable - src: %s/%s/%s, dest: %s/%s/%s", + srcBucket, srcPath, srcModelName, destBucket, destPath, destModelName) + + if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || + srcModelName == constant.EmptyString || destBucket == constant.EmptyString || + destPath == constant.EmptyString || destModelName == constant.EmptyString { + log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) + return transferredGcsModelData, errors.New(errModelPathFormat) + } + + if isNotProd { + if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, + destBucket, destPath, destModelName); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return transferredGcsModelData, err + } + } else { + if err := p.GcsClient.TransferFolderWithSplitSources( + srcBucket, srcPath, pred.GcsConfigBucket, pred.GcsConfigBasePath, + srcModelName, destBucket, destPath, destModelName, + ); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return transferredGcsModelData, err + } + } + + transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ + Bucket: destBucket, + Path: destPath, + Name: requestIdPayloadMap[requestModel.RequestID].ModelName, + }) + + log.Info().Msgf("Successfully copied model to target deployable: %s", destModelName) + } + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusInProgress, predatorStageDBPopulation) + } + return transferredGcsModelData, nil +} + +func (p *Predator) processGCSCloneStageIndefaultFolder(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) ([]GcsModelData, error) { + var transferredGcsModelData []GcsModelData + if predatorRequestList[0].RequestStage != 
predatorStagePending && + predatorRequestList[0].RequestStage != predatorStageCloneToBucket { + return transferredGcsModelData, nil + } + + isNotProd := p.isNonProductionEnvironment() + + for _, requestModel := range predatorRequestList { + payload := requestIdPayloadMap[requestModel.RequestID] + + destBucket := pred.GcsModelBucket + destPath := pred.GcsModelBasePath + destModelName := payload.ModelName + + _, _, originalModelName := extractGCSDetails(payload.ModelSource) + srcBucket := pred.GcsModelBucket + srcPath := pred.GcsModelBasePath + srcModelName := originalModelName + + log.Info().Msgf("Scale-up: Copying within model-source %s → %s:\nsrcBucket: %s, srcPath: %s, srcModelName: %s, destBucket: %s, destPath: %s", + srcModelName, destModelName, srcBucket, srcPath, srcModelName, destBucket, destPath) + + if srcBucket == constant.EmptyString || srcPath == constant.EmptyString || + srcModelName == constant.EmptyString || destBucket == constant.EmptyString || + destPath == constant.EmptyString || destModelName == constant.EmptyString { + log.Error().Err(errors.New(errModelPathFormat)).Msg(errInvalidGcsBucketPath) + return transferredGcsModelData, errors.New(errModelPathFormat) + } + + if err := p.GcsClient.TransferFolder(srcBucket, srcPath, srcModelName, + destBucket, destPath, destModelName); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return transferredGcsModelData, err + } + + log.Info().Msgf("Successfully copied model in model-source: %s → %s", srcModelName, destModelName) + + if !isNotProd && srcModelName != destModelName { + if err := p.copyConfigToNewNameInConfigSource(srcModelName, destModelName); err != nil { + log.Error().Err(err).Msgf("Failed to copy config to config-source: %s → %s", + srcModelName, destModelName) + return transferredGcsModelData, err + } + } + + transferredGcsModelData = append(transferredGcsModelData, GcsModelData{ + Bucket: destBucket, + Path: destPath, + Name: destModelName, + }) + } + + return 
transferredGcsModelData, nil +} + +func (p *Predator) processDBPopulationStageForDelete(predatorRequestList []predatorrequest.PredatorRequest, requestIdPayloadMap map[uint]*Payload, req ApproveRequest) { + if predatorRequestList[0].RequestStage != predatorStageDBPopulation { + return + } + + if err := p.deleteServiceDiscoveryAndConfig(req, predatorRequestList, requestIdPayloadMap); err != nil { + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) + return + } +} + +func (p *Predator) processDBPopulationStage(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, approvedBy string, successMessage string) error { + if predatorRequestList[0].RequestStage != predatorStageDBPopulation { + return nil + } + tx := p.Repo.DB().Begin() + for i := range predatorRequestList { + defer func() { + if r := recover(); r != nil { + tx.Rollback() + log.Printf("panic recovered, transaction rolled back") + } + }() + + if err := p.createDiscoveryAndPredatorConfigTx(tx, predatorRequestList[i], *requestIdPayloadMap[predatorRequestList[i].RequestID], approvedBy); err != nil { + tx.Rollback() + log.Error().Err(err).Msg(failedToCreateServiceDiscoveryAndConfig) + return err + } + + predatorRequestList[i].Status = statusInProgress + predatorRequestList[i].RequestStage = predatorStageRestartDeployable + if err := p.Repo.UpdateStatusAndStage(tx, &predatorRequestList[i]); err != nil { + tx.Rollback() + log.Printf(errFailedToUpdateRequestStatusAndStage, err) + } + } + if err := tx.Commit().Error; err != nil { + log.Printf("failed to commit transaction: %v", err) + return err + } + log.Printf("success %s %d\n", successMessage, predatorRequestList[0].GroupId) + return nil +} + +func (p *Predator) checkIfModelsExist(predatorRequestList []predatorrequest.PredatorRequest) bool { + for _, requestModel := range predatorRequestList { + modelName := requestModel.ModelName + if modelName == "" { + 
log.Error().Msgf("model name is empty for request ID %d", requestModel.RequestID) + continue + } + + predatorConfig, err := p.PredatorConfigRepo.GetActiveModelByModelName(modelName) + if err != nil { + log.Error().Err(err).Msgf("failed to fetch predator config for model %s", modelName) + continue + } + if predatorConfig != nil { + log.Error().Msgf("model %s already exists", modelName) + return true + } + } + return false +} + +func (p *Predator) processOnboardFlow(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + if p.checkIfModelsExist(predatorRequestList) { + req.RejectReason = "model already exists" + req.Status = statusRejected + p.processRejectRequest(predatorRequestList, req) + return + } + + transferredGcsModelData, err := p.processGCSCloneStage(requestIdPayloadMap, predatorRequestList, req) + if err != nil { + log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + p.revert(transferredGcsModelData) + return + } + + err = p.processDBPopulationStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy, onboardRequestFlow) + if err != nil { + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) + } + if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { + log.Error().Err(err).Msg(errFailedToRestartDeployable) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) + return + } +} + +func (p *Predator) revert(transferredGcsModelData []GcsModelData) error { + for _, data := range transferredGcsModelData { + if err := p.GcsClient.DeleteFolder(data.Bucket, data.Path, data.Name); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return err + } + } + return nil +} + +func (p *Predator) 
revertForDelete(transferredGcsModelData []GcsTransferredData) error { + for _, data := range transferredGcsModelData { + if err := p.GcsClient.TransferAndDeleteFolder(data.SrcBucket, data.SrcPath, data.SrcName, data.DestBucket, data.DestPath, data.DestName); err != nil { + log.Error().Err(err).Msg(errGCSCopyFailed) + return err + } + } + return nil +} + +func (p *Predator) processScaleUpFlow(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + if p.checkIfModelsExist(predatorRequestList) { + req.RejectReason = fmt.Sprintf("model %s already exists", requestIdPayloadMap[predatorRequestList[0].RequestID].ModelName) + req.Status = statusRejected + p.processRejectRequest(predatorRequestList, req) + return + } + + transferredGcsModelData, err := p.processGCSCloneStageIndefaultFolder(requestIdPayloadMap, predatorRequestList, req) + if err != nil { + log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + p.revert(transferredGcsModelData) + return + } + + transferredGcsModelData, err = p.processGCSCloneStage(requestIdPayloadMap, predatorRequestList, req) + if err != nil { + log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + p.revert(transferredGcsModelData) + return + } + + err = p.processDBPopulationStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy, cloneRequestFlow) + if err != nil { + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) + } + if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { + log.Error().Err(err).Msg(errFailedToRestartDeployable) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, 
predatorStageRestartDeployable) + return + } +} + +func (p *Predator) processPromoteFlow(requestIdPayloadMap map[uint]*Payload, predatorRequestList []predatorrequest.PredatorRequest, req ApproveRequest) { + if p.checkIfModelsExist(predatorRequestList) { + req.RejectReason = fmt.Sprintf("model %s already exists", requestIdPayloadMap[predatorRequestList[0].RequestID].ModelName) + req.Status = statusRejected + p.processRejectRequest(predatorRequestList, req) + return + } + + transferredGcsModelData, err := p.processGCSCloneStage(requestIdPayloadMap, predatorRequestList, req) + if err != nil { + log.Error().Err(err).Msg(errFailedToOperateGcsCloneStage) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageCloneToBucket) + p.revert(transferredGcsModelData) + return + } + + err = p.processDBPopulationStage(requestIdPayloadMap, predatorRequestList, req.ApprovedBy, promoteRequestFlow) + if err != nil { + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageDBPopulation) + } + if err := p.processRestartDeployableStage(req.ApprovedBy, predatorRequestList, requestIdPayloadMap); err != nil { + log.Error().Err(err).Msg(errFailedToRestartDeployable) + p.updateRequestStatusAndStage(req.ApprovedBy, predatorRequestList, statusFailed, predatorStageRestartDeployable) + return + } +} + +func (p *Predator) updateRequestStatusAndStage(approvedBy string, predatorRequestList []predatorrequest.PredatorRequest, status, stage string) { + for i := range predatorRequestList { + predatorRequestList[i].Status = status + predatorRequestList[i].Reviewer = approvedBy + predatorRequestList[i].UpdatedBy = approvedBy + if stage != constant.EmptyString { + predatorRequestList[i].RequestStage = stage + } + if predatorRequestList[i].Status == statusApproved || + predatorRequestList[i].Status == statusFailed || + predatorRequestList[i].Status == statusRejected { + predatorRequestList[i].Active = false + } + 
predatorRequestList[i].UpdatedAt = time.Now() + } + + if err := p.Repo.UpdateMany(predatorRequestList); err != nil { + log.Printf(errFailedToUpdateRequestStatusAndStage, err) + } +} + +func (p *Predator) createDiscoveryAndPredatorConfigTx(tx *gorm.DB, requestModel predatorrequest.PredatorRequest, payload Payload, approvedBy string) error { + discoveryConfig, err := p.createDiscoveryConfigTx(tx, &requestModel, payload) + if err != nil { + return err + } + return p.createPredatorConfigTx(tx, &requestModel, payload, approvedBy, discoveryConfig.ID) +} + +func (p *Predator) createDiscoveryConfigTx(tx *gorm.DB, requestModel *predatorrequest.PredatorRequest, payload Payload) (discoveryconfig.DiscoveryConfig, error) { + discoveryConfig := discoveryconfig.DiscoveryConfig{ + ServiceDeployableID: int(payload.ConfigMapping.ServiceDeployableID), + CreatedBy: requestModel.CreatedBy, + UpdatedBy: requestModel.UpdatedBy, + Active: true, + CreatedAt: requestModel.CreatedAt, + UpdatedAt: time.Now(), + } + if err := tx.Create(&discoveryConfig).Error; err != nil { + log.Error().Err(err).Msg(errMsgInsertDiscovery) + return discoveryConfig, err + } + return discoveryConfig, nil +} + +func (p *Predator) createPredatorConfigTx(tx *gorm.DB, requestModel *predatorrequest.PredatorRequest, payload Payload, approvedBy string, discoveryConfigID int) error { + // Clean up ensemble scheduling before marshaling + cleanedMetaData := p.cleanEnsembleScheduling(payload.MetaData) + + metaDataBytes, err := json.Marshal(cleanedMetaData) + if err != nil { + log.Error().Err(err).Msg(errMsgMarshalMeta) + return err + } + + serviceDeployableID := int(payload.ConfigMapping.ServiceDeployableID) + serviceDeployable, err := p.ServiceDeployableRepo.GetById(serviceDeployableID) + if err != nil { + log.Error().Err(err).Msgf("Failed to get service deployable config for ID %d", serviceDeployableID) + return fmt.Errorf("failed to get service deployable config: %w", err) + } + + config := predatorconfig.PredatorConfig{ 
+ DiscoveryConfigID: discoveryConfigID, + ModelName: payload.ModelName, + MetaData: metaDataBytes, + CreatedBy: requestModel.CreatedBy, + UpdatedBy: approvedBy, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + Active: true, + SourceModelName: payload.ConfigMapping.SourceModelName, + } + + if serviceDeployable.OverrideTesting { + log.Info().Msgf("OverrideTesting is enabled for deployable %s. Setting test_results for model %s", + serviceDeployable.Name, payload.ModelName) + + config.TestResults = json.RawMessage(`{"is_functionally_tested": true}`) + config.HasNilData = false + } + + if err := tx.Create(&config).Error; err != nil { + log.Error().Err(err).Msg(errMsgInsertConfig) + return err + } + return nil +} diff --git a/horizon/internal/predator/handler/predator_constants.go b/horizon/internal/predator/handler/predator_constants.go new file mode 100644 index 00000000..4d304dbb --- /dev/null +++ b/horizon/internal/predator/handler/predator_constants.go @@ -0,0 +1,115 @@ +package handler + +const ( + OnboardRequestType = "Onboard" + ScaleUpRequestType = "ScaleUp" + PromoteRequestType = "Promote" + EditRequestType = "Edit" + DeleteRequestType = "Delete" + configFile = "config.pbtxt" + pendingApproval = "Pending Approval" + slashConstant = "/" + gcsPrefix = "gs://" + adminRole = "admin" + typeString = "TYPE_STRING" + bytesKeys = "BYTES" + typePrefix = "TYPE_" + errMsgFetchConfigs = "failed to fetch predator configs: %w" + errMsgParsePayload = "failed to parse payload for request ID %d: %w" + cpuRequestKey = "cpu_request" + cpuLimitKey = "cpu_limit" + memRequestKey = "mem_request" + memLimitKey = "mem_limit" + gpuRequestKey = "gpu_request" + gpuLimitKey = "gpu_limit" + minReplicaKey = "min_replica" + maxReplicaKey = "max_replica" + nodeSelectorKey = "node_selector" + statusFailed = "Failed" + statusInProgress = "In Progress" + errMsgMarshalMeta = "Failed to marshal metadata" + errMsgInsertConfig = "Failed to insert predator_config" + errMsgInsertDiscovery = "Failed 
to insert service discovery" + errMsgCreateConnection = "Error in creating connection" + errMsgTypeAssertion = "failed to cast connection to *infra.SQLConnection" + errMsgTypeAssertionLog = "Type assertion error" + errMsgCreateRequestRepo = "Error in creating predator request repository" + errMsgCreateDeployableRepo = "Error in creating service deployable repository" + errMsgCreateConfigRepo = "Error in creating predator config repository" + errMsgCreateDiscoveryRepo = "Error in creating service discovery repository" + errMsgCreateGroupIdCounterRepo = "Error in creating group id counter repository" + errMsgProcessPayload = "failed to process payload" + errMsgCreateRequestFormat = "could not create %s request" + successMsgFormat = "Model %s Request Raised Successfully." + fieldModelName = "model_name" + statusPendingApproval = "Pending Approval" + errModelNotFound = "model not found" + errFetchDiscoveryConfig = "failed to fetch service discovery config" + errFetchDeployableConfig = "failed to fetch service deployable config" + errUnmarshalDeployableConfig = "failed to unmarshal service deployable config" + errMarshalPayload = "failed to marshal payload" + errCreateDeleteRequest = "could not create delete request" + successDeleteRequestMsg = "Model deletion request raised successfully" + fieldModelSourcePath = "model_source_path" + fieldMetaData = "meta_data" + fieldDiscoveryConfigID = "discovery_config_id" + fieldConfigMapping = "config_mapping" + errReadConfigFileFormat = "failed to read config.pbtxt: %v" + errUnmarshalProtoFormat = "failed to unmarshal proto text: %v" + errNoInstanceGroup = "no instance group defined in model config" + errModelPathPrefix = "model_path must be provided and start with /" + errModelPathFormat = "invalid model_path format. 
Expected: /bucket/path/to/model" + errModelNameMissing = "model name is missing in config" + errMaxBatchSizeMissing = "max_batch_size is missing or zero in config" + errBackendMissing = "backend is missing in config" + errNoInputDefinitions = "no input definitions found in config" + errNoOutputDefinitions = "no output definitions found in config" + errInstanceGroupMissing = "instance group is missing in config" + errInvalidRequestIDFormat = "invalid group ID format" + errFailedToFetchRequest = "failed to fetch request for group id %s" + errInvalidRequestType = "invalid request type" + statusApproved = "Approved" + statusRejected = "Rejected" + errInvalidGcsBucketPath = "invalid gcs bucket path format for source or destination" + errFailedToUpdateRequest = "Failed to update request status" + successRejectMessage = "Request %d rejected successfully.\n" + errFailedToFindServiceDiscovery = "Failed to find service discovery entry" + errFailedToUpdateServiceDiscovery = "Failed to update service discovery to inactive" + errFailedToFindPredatorConfig = "Failed to find predator config entry" + errFailedToUpdatePredatorConfig = "Failed to update predator config to inactive" + + errFailedToParsePayload = "Failed to parse payload" + errChildModelNotInDeleteRequest = "ensemble model %s has child model %s which is not included in the delete request" + errChildModelDifferentDeployable = "ensemble model %s and its child model %s belong to different deployables (ensemble: %d, child: %d)" + errFailedToFetchDiscoveryConfigForModel = "failed to fetch discovery config for model %s: %w" + errFailedToFetchDiscoveryConfigForEnsemble = "failed to fetch discovery config for ensemble model %s: %w" + errFailedToFetchDiscoveryConfigForChild = "failed to fetch discovery config for child model %s: %w" + errDuplicateModelNameInDeployable = "duplicate model name %s found within deployable %d" + errNormalModelIsChildOfEnsemble = "model %s is a child of ensemble model %s in the same deployable %d, 
but ensemble is not included in delete request" + errEnsembleMissingChild = "ensemble model %s has child model %s which is not included in the delete request" + errChildMissingEnsemble = "child model %s is included in delete request but its parent ensemble model %s is not included" + errFailedToFindServiceDeployableEntry = "Failed to find service deployable entry" + errFailedToOperateGcsCloneStage = "Failed to operate gcs clone stage" + errFailedToRestartDeployable = "Failed to restart deployable" + errGCSCopyFailed = "GCS copy failed" + errFailedToUpdateRequestStatusAndStage = "Failed to update request status and stage %s" + onboardRequestFlow = "Onboard request" + cloneRequestFlow = "Clone request" + promoteRequestFlow = "Promote request" + predatorStageRestartDeployable = "Restart Deployable" + predatorStagePending = "Pending" + machineTypeKey = "machine_type" + cpuThresholdKey = "cpu_threshold" + gpuThresholdKey = "gpu_threshold" + tritonImageTagKey = "triton_image_tag" + basePathKey = "base_path" + predatorStageCloneToBucket = "Clone To Bucket" + predatorStageDBPopulation = "DB Population" + predatorStageRequestPayloadError = "Request Payload Error" + serviceDeployableNotFound = "ServiceDeployable not found" + failedToParseServiceConfig = "Failed to parse service config" + failedToCreateServiceDiscoveryAndConfig = "Failed to create service discovery and config" + predatorInferMethod = "inference.GRPCInferenceService/ModelInfer" + deployableTagDelimiter = "_" + scaleupTag = "scaleup" +) diff --git a/horizon/internal/predator/handler/predator_fetch.go b/horizon/internal/predator/handler/predator_fetch.go new file mode 100644 index 00000000..074cd800 --- /dev/null +++ b/horizon/internal/predator/handler/predator_fetch.go @@ -0,0 +1,166 @@ +package handler + +import ( + "encoding/json" + "strings" + "sync" + + "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/discoveryconfig" + 
"github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorconfig" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" +) + +// batchFetchRelatedData efficiently fetches all discovery configs and service deployables in batch +func (p *Predator) batchFetchRelatedData(predatorConfigs []predatorconfig.PredatorConfig) (map[int]*discoveryconfig.DiscoveryConfig, map[int]*servicedeployableconfig.ServiceDeployableConfig, error) { + // Collect all unique discovery config IDs + discoveryConfigIDs := make([]int, 0, len(predatorConfigs)) + discoveryIDSet := make(map[int]bool) + + for _, config := range predatorConfigs { + if !discoveryIDSet[config.DiscoveryConfigID] { + discoveryConfigIDs = append(discoveryConfigIDs, config.DiscoveryConfigID) + discoveryIDSet[config.DiscoveryConfigID] = true + } + } + + // Batch fetch all discovery configs + discoveryConfigs := make(map[int]*discoveryconfig.DiscoveryConfig) + for _, id := range discoveryConfigIDs { + config, err := p.ServiceDiscoveryRepo.GetById(id) + if err != nil { + continue // Skip failed configs, same behavior as original + } + discoveryConfigs[id] = config + } + + // Collect all unique service deployable IDs + serviceDeployableIDs := make([]int, 0, len(discoveryConfigs)) + serviceDeployableIDSet := make(map[int]bool) + + for _, config := range discoveryConfigs { + if !serviceDeployableIDSet[config.ServiceDeployableID] { + serviceDeployableIDs = append(serviceDeployableIDs, config.ServiceDeployableID) + serviceDeployableIDSet[config.ServiceDeployableID] = true + } + } + + // Batch fetch all service deployables + serviceDeployables := make(map[int]*servicedeployableconfig.ServiceDeployableConfig) + for _, id := range serviceDeployableIDs { + deployable, err := p.ServiceDeployableRepo.GetById(id) + if err != nil { + continue // Skip failed deployables, same behavior as original + } + serviceDeployables[id] = deployable + } + + return discoveryConfigs, 
serviceDeployables, nil +} + +// batchFetchDeployableConfigs concurrently fetches deployable configs for all service deployables +func (p *Predator) batchFetchDeployableConfigs(serviceDeployables map[int]*servicedeployableconfig.ServiceDeployableConfig) (map[int]externalcall.Config, error) { + deployableConfigs := make(map[int]externalcall.Config) + var mu sync.Mutex + var wg sync.WaitGroup + + // Use a semaphore to limit concurrent API calls + semaphore := make(chan struct{}, 10) // Limit to 10 concurrent calls + + for id, deployable := range serviceDeployables { + wg.Add(1) + go func(deployableID int, sd *servicedeployableconfig.ServiceDeployableConfig) { + defer wg.Done() + semaphore <- struct{}{} // Acquire semaphore + defer func() { <-semaphore }() // Release semaphore + + infraConfig := p.infrastructureHandler.GetConfig(sd.Name, p.workingEnv) + // Convert to externalcall.Config for compatibility + config := externalcall.Config{ + MinReplica: infraConfig.MinReplica, + MaxReplica: infraConfig.MaxReplica, + RunningStatus: infraConfig.RunningStatus, + } + + mu.Lock() + deployableConfigs[deployableID] = config + mu.Unlock() + }(id, deployable) + } + + wg.Wait() + return deployableConfigs, nil +} + +// buildModelResponses constructs the final ModelResponse objects +func (p *Predator) buildModelResponses( + predatorConfigs []predatorconfig.PredatorConfig, + discoveryConfigs map[int]*discoveryconfig.DiscoveryConfig, + serviceDeployables map[int]*servicedeployableconfig.ServiceDeployableConfig, + deployableConfigs map[int]externalcall.Config, +) []ModelResponse { + results := make([]ModelResponse, 0, len(predatorConfigs)) + + for _, config := range predatorConfigs { + // Get discovery config + serviceDiscovery, exists := discoveryConfigs[config.DiscoveryConfigID] + if !exists { + continue // Skip if discovery config not found + } + + // Get service deployable + serviceDeployable, exists := serviceDeployables[serviceDiscovery.ServiceDeployableID] + if !exists { + 
continue // Skip if service deployable not found + } + + // Parse deployable config + var deployableConfig PredatorDeployableConfig + if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { + continue // Skip if config parsing fails + } + + // Get infrastructure config (HPA/replica info) + infraConfig := deployableConfigs[serviceDiscovery.ServiceDeployableID] + + deploymentConfig := map[string]any{ + machineTypeKey: deployableConfig.MachineType, + cpuThresholdKey: deployableConfig.CPUThreshold, + gpuThresholdKey: deployableConfig.GPUThreshold, + cpuRequestKey: deployableConfig.CPURequest, + cpuLimitKey: deployableConfig.CPULimit, + memRequestKey: deployableConfig.MemoryRequest, + memLimitKey: deployableConfig.MemoryLimit, + gpuRequestKey: deployableConfig.GPURequest, + gpuLimitKey: deployableConfig.GPULimit, + minReplicaKey: infraConfig.MinReplica, + maxReplicaKey: infraConfig.MaxReplica, + nodeSelectorKey: deployableConfig.NodeSelectorValue, + tritonImageTagKey: deployableConfig.TritonImageTag, + basePathKey: deployableConfig.GCSBucketPath, + } + + modelResponse := ModelResponse{ + ID: config.ID, + ModelName: config.ModelName, + MetaData: config.MetaData, + Host: serviceDeployable.Host, + MachineType: deployableConfig.MachineType, + DeploymentConfig: deploymentConfig, + MonitoringUrl: serviceDeployable.MonitoringUrl, + GCSPath: strings.TrimSuffix(deployableConfig.GCSBucketPath, "/*"), + CreatedBy: config.CreatedBy, + CreatedAt: config.CreatedAt, + UpdatedBy: config.UpdatedBy, + UpdatedAt: config.UpdatedAt, + DeployableRunningStatus: infraConfig.RunningStatus, + TestResults: config.TestResults, + HasNilData: config.HasNilData, + SourceModelName: config.SourceModelName, + } + + results = append(results, modelResponse) + } + + return results +} diff --git a/horizon/internal/predator/handler/predator_functional_testing.go b/horizon/internal/predator/handler/predator_functional_testing.go new file mode 100644 index 00000000..f66aa773 --- 
/dev/null +++ b/horizon/internal/predator/handler/predator_functional_testing.go @@ -0,0 +1,220 @@ +package handler + +import ( + "encoding/binary" + "fmt" + "math" + "strings" + + "github.com/Meesho/BharatMLStack/horizon/pkg/serializer" + "github.com/rs/zerolog/log" +) + +// flattenInputTo3DByteSlice converts input data to 3D byte slice format [batch][feature][bytes] +// This matches the working adapter's data structure expectations +func (p *Predator) flattenInputTo3DByteSlice(data any, dataType string) ([][][]byte, error) { + switch v := data.(type) { + case [][]float32: + batchSize := len(v) + if batchSize == 0 { + return [][][]byte{}, nil + } + featureCount := len(v[0]) + + result := make([][][]byte, batchSize) + for batchIdx := 0; batchIdx < batchSize; batchIdx++ { + result[batchIdx] = make([][]byte, featureCount) + for featureIdx := 0; featureIdx < featureCount; featureIdx++ { + val := v[batchIdx][featureIdx] + switch dataType { + case "FP16": + fp16Bytes, err := serializer.Float32ToFloat16Bytes(val) + if err != nil { + return nil, err + } + result[batchIdx][featureIdx] = fp16Bytes + case "FP32": + bytes := make([]byte, 4) + binary.LittleEndian.PutUint32(bytes, math.Float32bits(val)) + result[batchIdx][featureIdx] = bytes + default: + return nil, fmt.Errorf("unsupported numeric type %s for float32 data", dataType) + } + } + } + return result, nil + + default: + flattened, err := serializer.FlattenMatrixByType(data, dataType) + if err != nil { + return nil, err + } + + switch dataType { + case "FP16": + if f32slice, ok := flattened.([]float32); ok { + batchSize := 1 + featureCount := len(f32slice) + result := make([][][]byte, batchSize) + result[0] = make([][]byte, featureCount) + for i, val := range f32slice { + fp16Bytes, err := serializer.Float32ToFloat16Bytes(val) + if err != nil { + return nil, err + } + result[0][i] = fp16Bytes + } + return result, nil + } + case "BYTES": + if byteSlice, ok := flattened.([][]byte); ok { + result := make([][][]byte, 1) + 
result[0] = byteSlice + return result, nil + } + } + + return nil, fmt.Errorf("unsupported data format: %T for type %s", data, dataType) + } +} + +// getElementSize returns the byte size of a single element for the given data type +func getElementSize(dataType string) int { + switch strings.ToUpper(dataType) { + case "FP32", "TYPE_FP32": + return 4 + case "FP64", "TYPE_FP64": + return 8 + case "INT32", "TYPE_INT32": + return 4 + case "INT64", "TYPE_INT64": + return 8 + case "INT16", "TYPE_INT16": + return 2 + case "INT8", "TYPE_INT8": + return 1 + case "UINT32", "TYPE_UINT32": + return 4 + case "UINT64", "TYPE_UINT64": + return 8 + case "UINT16", "TYPE_UINT16": + return 2 + case "UINT8", "TYPE_UINT8": + return 1 + case "BOOL", "TYPE_BOOL": + return 1 + case "FP16", "TYPE_FP16": + return 2 + default: + return 0 + } +} + +// reshapeDataForBatch reshapes flattened data to preserve batch dimension +func reshapeDataForBatch(data interface{}, dims []int64) interface{} { + if len(dims) == 0 { + return data + } + + batchSize := dims[0] + featureDims := dims[1:] + + elementsPerBatch := int64(1) + for _, dim := range featureDims { + elementsPerBatch *= dim + } + + var dataSlice []interface{} + switch v := data.(type) { + case []interface{}: + dataSlice = v + case []string: + for _, item := range v { + dataSlice = append(dataSlice, item) + } + case []float32: + for _, item := range v { + dataSlice = append(dataSlice, item) + } + case []float64: + for _, item := range v { + dataSlice = append(dataSlice, item) + } + case []int32: + for _, item := range v { + dataSlice = append(dataSlice, item) + } + case []int64: + for _, item := range v { + dataSlice = append(dataSlice, item) + } + default: + return data + } + + var result [][]interface{} + for i := int64(0); i < batchSize; i++ { + start := i * elementsPerBatch + end := start + elementsPerBatch + if end <= int64(len(dataSlice)) { + batch := dataSlice[start:end] + result = append(result, batch) + } + } + + return result +} + 
+// convertDimsToIntSlice converts input.Dims to []int, handling nested interfaces and various types +func convertDimsToIntSlice(dims interface{}) ([]int, error) { + var result []int + + switch v := dims.(type) { + case []int: + result = make([]int, len(v)) + copy(result, v) + case []int64: + result = make([]int, len(v)) + for i, dim := range v { + result[i] = int(dim) + } + case []interface{}: + result = make([]int, len(v)) + for i, dim := range v { + switch d := dim.(type) { + case int: + result[i] = d + case int64: + result[i] = int(d) + case float64: + result[i] = int(d) + default: + return nil, fmt.Errorf("unsupported dimension type in slice: %T", d) + } + } + case int: + result = []int{v} + case int64: + result = []int{int(v)} + case float64: + result = []int{int(v)} + default: + return nil, fmt.Errorf("unsupported dims type: %T", v) + } + + for i, dim := range result { + if dim == -1 { + if i == 0 { + result[i] = 10 + } else { + result[i] = 128 + } + log.Debug().Msgf("Replaced dynamic dimension -1 at position %d with %d", i, result[i]) + } else if dim < 0 { + result[i] = 1 + log.Debug().Msgf("Replaced negative dimension %d at position %d with 1", dim, i) + } + } + + return result, nil +} diff --git a/horizon/internal/predator/handler/predator_helpers.go b/horizon/internal/predator/handler/predator_helpers.go new file mode 100644 index 00000000..91d53a03 --- /dev/null +++ b/horizon/internal/predator/handler/predator_helpers.go @@ -0,0 +1,108 @@ +package handler + +import ( + "fmt" + "path" + "strings" + + "github.com/Meesho/BharatMLStack/horizon/internal/constant" + pred "github.com/Meesho/BharatMLStack/horizon/internal/predator" +) + +func parseGCSURL(gcsURL string) (bucket, objectPath string, ok bool) { + if strings.HasPrefix(gcsURL, "gcs://") { + gcsURL = strings.Replace(gcsURL, "gcs://", "gs://", 1) + } + + if !strings.HasPrefix(gcsURL, gcsPrefix) { + return constant.EmptyString, constant.EmptyString, false + } + + trimmed := strings.TrimPrefix(gcsURL, 
gcsPrefix) + parts := strings.SplitN(trimmed, slashConstant, 2) + if len(parts) < 1 { + return constant.EmptyString, constant.EmptyString, false + } + + bucket = parts[0] + if len(parts) == 2 { + objectPath = parts[1] + } + return bucket, objectPath, true +} + +func extractGCSPath(gcsURL string) (bucket, objectPath string) { + bucket, objectPath, ok := parseGCSURL(gcsURL) + if !ok { + return constant.EmptyString, constant.EmptyString + } + return bucket, objectPath +} + +func extractGCSDetails(gcsURL string) (bucket, pathOnly, modelName string) { + bucket, objectPath, ok := parseGCSURL(gcsURL) + if !ok || objectPath == constant.EmptyString { + return constant.EmptyString, constant.EmptyString, constant.EmptyString + } + + segments := strings.Split(objectPath, slashConstant) + if len(segments) == 0 { + return bucket, constant.EmptyString, constant.EmptyString + } + + modelName = segments[len(segments)-1] + pathOnly = path.Join(segments[:len(segments)-1]...) + return bucket, pathOnly, modelName +} + +// GetDerivedModelName returns the derived model name with deployable tag +func (p *Predator) GetDerivedModelName(payloadObject Payload, requestType string) (string, error) { + if requestType != ScaleUpRequestType { + return payloadObject.ModelName, nil + } + serviceDeployableID := payloadObject.ConfigMapping.ServiceDeployableID + serviceDeployable, err := p.ServiceDeployableRepo.GetById(int(serviceDeployableID)) + if err != nil { + return constant.EmptyString, fmt.Errorf("%s: %w", errFetchDeployableConfig, err) + } + + deployableTag := serviceDeployable.DeployableTag + if deployableTag == "" { + return payloadObject.ModelName, nil + } + + derivedModelName := payloadObject.ModelName + deployableTagDelimiter + deployableTag + derivedModelName = derivedModelName + deployableTagDelimiter + scaleupTag + return derivedModelName, nil +} + +// GetOriginalModelName returns the original model name if no tag is found (backward compatibility) +func (p *Predator) 
GetOriginalModelName(derivedModelName string, serviceDeployableID int) (string, error) { + serviceDeployable, err := p.ServiceDeployableRepo.GetById(serviceDeployableID) + if err != nil { + return constant.EmptyString, fmt.Errorf("%s: %w", errFetchDeployableConfig, err) + } + + deployableTag := serviceDeployable.DeployableTag + if deployableTag == "" { + return derivedModelName, nil + } + + scaleupSuffix := deployableTagDelimiter + scaleupTag + derivedModelName = strings.TrimSuffix(derivedModelName, scaleupSuffix) + + deployableTagSuffix := deployableTagDelimiter + deployableTag + if originalName, foundSuffix := strings.CutSuffix(derivedModelName, deployableTagSuffix); foundSuffix { + return originalName, nil + } + + return derivedModelName, nil +} + +func (p *Predator) isNonProductionEnvironment() bool { + env := strings.ToLower(strings.TrimSpace(pred.AppEnv)) + if env == "prd" || env == "prod" { + return false + } + return true +} diff --git a/horizon/internal/predator/handler/predator_upload.go b/horizon/internal/predator/handler/predator_upload.go new file mode 100644 index 00000000..835a4ec3 --- /dev/null +++ b/horizon/internal/predator/handler/predator_upload.go @@ -0,0 +1,532 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + "path" + "regexp" + "strings" + + "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" + pred "github.com/Meesho/BharatMLStack/horizon/internal/predator" + "github.com/rs/zerolog/log" + "google.golang.org/protobuf/encoding/prototext" +) + +// uploadSingleModel processes a single model upload with improved validation and error handling +func (p *Predator) uploadSingleModel(modelItem ModelUploadItem, bucket, basePath string, isPartial bool, authToken string) ModelUploadResult { + // Step 1: Extract and validate model name + modelName, err := p.extractModelName(modelItem.Metadata) + if err != nil { + return p.createErrorResult("unknown", "Failed to extract model name", err) + } + + 
log.Info().Msgf("Processing %s upload for model: %s from %s", + map[bool]string{true: "partial", false: "full"}[isPartial], modelName, modelItem.GCSPath) + + // Step 2: Setup destination paths + destPath := path.Join(basePath, modelName) + fullGCSPath := fmt.Sprintf("gs://%s/%s", bucket, destPath) + + // Step 3: Validate upload prerequisites + if err := p.validateUploadPrerequisites(bucket, destPath, isPartial, modelName); err != nil { + return p.createErrorResult(modelName, "Upload prerequisites validation failed", err) + } + + // Step 4: Validate source model structure and configuration + if err := p.validateSourceModel(modelItem.GCSPath, isPartial); err != nil { + return p.createErrorResult(modelName, "Source model validation failed", err) + } + + // Step 5: Validate metadata features (after model structure validation) + if err := p.validateMetadataFeatures(modelItem.Metadata, authToken); err != nil { + return p.createErrorResult(modelName, "Feature validation failed", err) + } + + // Step 6: Download/sync model files based on upload type + if err := p.syncModelFiles(modelItem.GCSPath, bucket, destPath, modelName, isPartial); err != nil { + return p.createErrorResult(modelName, "Model file sync failed", err) + } + + // Step 7: Copy config.pbtxt to prod config source (only in production) + if err := p.copyConfigToProdConfigSource(modelItem.GCSPath, modelName); err != nil { + return p.createErrorResult(modelName, "Failed to copy config to prod config source", err) + } + + // Upload processed metadata.json (always done regardless of partial/full) + metadataPath, err := p.uploadModelMetadata(modelItem.Metadata, bucket, destPath) + if err != nil { + return p.createErrorResult(modelName, "Metadata upload failed", err) + } + + log.Info().Msgf("Successfully completed %s upload for model: %s", + map[bool]string{true: "partial", false: "full"}[isPartial], modelName) + return ModelUploadResult{ + ModelName: modelName, + GCSPath: fullGCSPath, + MetadataPath: metadataPath, + 
Status: "success", + } +} + +// copyConfigToProdConfigSource copies config.pbtxt to the prod config source path +func (p *Predator) copyConfigToProdConfigSource(gcsPath, modelName string) error { + if pred.GcsConfigBucket == "" || pred.GcsConfigBasePath == "" { + log.Warn().Msg("Config source not configured, skipping config.pbtxt copy to config source") + return nil + } + + srcBucket, srcPath := extractGCSPath(gcsPath) + if srcBucket == "" || srcPath == "" { + return fmt.Errorf("invalid GCS path format: %s", gcsPath) + } + + srcConfigPath := path.Join(srcPath, configFile) + configData, err := p.GcsClient.ReadFile(srcBucket, srcConfigPath) + if err != nil { + return fmt.Errorf("failed to read config.pbtxt from source: %w", err) + } + + updatedConfigData := p.replaceModelNameInConfigPreservingFormat(configData, modelName) + + destConfigPath := path.Join(pred.GcsConfigBasePath, modelName, configFile) + if err := p.GcsClient.UploadFile(pred.GcsConfigBucket, destConfigPath, updatedConfigData); err != nil { + return fmt.Errorf("failed to upload config.pbtxt to config source: %w", err) + } + + log.Info().Msgf("Successfully copied config.pbtxt to config source with model name %s: gs://%s/%s", + modelName, pred.GcsConfigBucket, destConfigPath) + return nil +} + +// createErrorResult creates a standardized error result +func (p *Predator) createErrorResult(modelName, message string, err error) ModelUploadResult { + return ModelUploadResult{ + ModelName: modelName, + Status: "error", + Error: fmt.Sprintf("%s: %v", message, err), + } +} + +// generateUploadSummary creates response message and status code based on results +func (p *Predator) generateUploadSummary(successCount, failCount int, results []ModelUploadResult) (string, int) { + switch { + case failCount == 0: + return fmt.Sprintf("%d model uploaded successfully", successCount), http.StatusOK + case successCount == 0: + return fmt.Sprintf("%d model failed to upload. 
Errors: %s", failCount, results[0].Error), http.StatusBadRequest + default: + return fmt.Sprintf("Mixed results: %d successful, %d failed. Errors: %s", successCount, failCount, results[0].Error), http.StatusPartialContent + } +} + +// validateUploadPrerequisites validates upload requirements based on type +func (p *Predator) validateUploadPrerequisites(bucket, destPath string, isPartial bool, modelName string) error { + exists, err := p.GcsClient.CheckFolderExists(bucket, destPath) + if err != nil { + return fmt.Errorf("failed to check model existence: %w", err) + } + + if isPartial { + if !exists { + return fmt.Errorf("partial upload requires existing model folder at destination") + } + log.Info().Msgf("Partial upload: updating existing model %s", modelName) + } else { + if exists { + log.Info().Msgf("Full upload: replacing existing model %s", modelName) + } else { + log.Info().Msgf("Full upload: creating new model %s", modelName) + } + } + + return nil +} + +// validateSourceModel validates the source model structure and configuration +func (p *Predator) validateSourceModel(gcsPath string, isPartial bool) error { + srcBucket, srcPath := extractGCSPath(gcsPath) + if srcBucket == "" || srcPath == "" { + return fmt.Errorf("invalid GCS path format: %s", gcsPath) + } + + if err := p.validateModelConfiguration(gcsPath); err != nil { + return fmt.Errorf("config.pbtxt validation failed: %w", err) + } + + if !isPartial { + if err := p.validateCompleteModelStructure(srcBucket, srcPath); err != nil { + return fmt.Errorf("complete model structure validation failed: %w", err) + } + } + + return nil +} + +// validateCompleteModelStructure validates that version "1" folder exists with non-empty files +func (p *Predator) validateCompleteModelStructure(srcBucket, srcPath string) error { + versionPath := path.Join(srcPath, "1") + exists, err := p.GcsClient.CheckFolderExists(srcBucket, versionPath) + if err != nil { + return fmt.Errorf("failed to check version folder 1/: %w", err) 
+	}
+
+	if !exists {
+		return fmt.Errorf("version folder 1/ not found - required for complete model")
+	}
+
+	if err := p.validateVersionHasFiles(srcBucket, versionPath); err != nil {
+		return fmt.Errorf("version folder 1/ validation failed: %w", err)
+	}
+
+	log.Info().Msg("Model structure validation passed - version 1/ folder exists with files")
+	return nil
+}
+
+// validateVersionHasFiles checks that the version folder exists in GCS; NOTE(review): despite the name it does not verify that individual files are present or non-empty
+func (p *Predator) validateVersionHasFiles(srcBucket, versionPath string) error {
+	exists, err := p.GcsClient.CheckFolderExists(srcBucket, versionPath)
+	if err != nil {
+		return fmt.Errorf("failed to check version folder contents: %w", err)
+	}
+
+	if !exists {
+		return fmt.Errorf("version folder 1/ is empty - must contain model files")
+	}
+
+	log.Info().Msg("Version folder 1/ contains files")
+	return nil
+}
+
+// syncModelFiles handles file synchronization based on upload type
+func (p *Predator) syncModelFiles(gcsPath, destBucket, destPath, modelName string, isPartial bool) error {
+	if isPartial {
+		return p.syncPartialFiles(gcsPath, destBucket, destPath, modelName)
+	}
+	return p.syncFullModel(gcsPath, destBucket, destPath, modelName)
+}
+
+// uploadModelMetadata uploads metadata.json to GCS and returns the full path
+func (p *Predator) uploadModelMetadata(metadata interface{}, bucket, destPath string) (string, error) {
+	metadataBytes, err := json.Marshal(metadata)
+	if err != nil {
+		return "", fmt.Errorf("failed to serialize metadata: %w", err)
+	}
+
+	metadataPath := path.Join(destPath, "metadata.json")
+	if err := p.GcsClient.UploadFile(bucket, metadataPath, metadataBytes); err != nil {
+		return "", fmt.Errorf("failed to upload metadata: %w", err)
+	}
+
+	return fmt.Sprintf("gs://%s/%s", bucket, metadataPath), nil
+}
+
+// validateMetadataFeatures validates the features in metadata against online/offline validation APIs
+func (p *Predator) validateMetadataFeatures(metadata interface{}, authToken string)
error { + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("failed to marshal metadata: %w", err) + } + + var featureMeta FeatureMetadata + if err := json.Unmarshal(metadataBytes, &featureMeta); err != nil { + return fmt.Errorf("failed to unmarshal metadata: %w", err) + } + + if authToken == "" { + return fmt.Errorf("authorization token is required for feature validation") + } + + onlineFeaturesByEntity := make(map[string][]string) + pricingFeaturesByEntity := make(map[string][]string) + var offlineFeatures []string + + for _, input := range featureMeta.Inputs { + for _, feature := range input.Features { + featureType, entity, gf, featureName, isValid := externalcall.ParseFeatureString(feature) + if !isValid { + log.Error().Msgf("Invalid feature format: %s", feature) + return fmt.Errorf("invalid feature format: %s", feature) + } + + switch featureType { + case "ONLINE_FEATURE", "PARENT_ONLINE_FEATURE": + onlineFeaturesByEntity[entity] = append(onlineFeaturesByEntity[entity], gf) + log.Info().Msgf("Added online feature for validation - entity: %s, feature: %s", entity, gf) + case "OFFLINE_FEATURE", "PARENT_OFFLINE_FEATURE": + offlineFeatures = append(offlineFeatures, featureName) + log.Info().Msgf("Added offline feature for validation: %s", featureName) + case "RTP_FEATURE", "PARENT_RTP_FEATURE": + fullFeature := entity + ":" + gf + pricingFeaturesByEntity[entity] = append(pricingFeaturesByEntity[entity], fullFeature) + log.Info().Msgf("Added pricing feature for validation - entity: %s, full feature: %s", entity, fullFeature) + case "DEFAULT_FEATURE", "PARENT_DEFAULT_FEATURE", "MODEL_FEATURE", "CALIBRATION": + log.Info().Msgf("Skipping API validation for feature type %s: %s (no validation required)", featureType, feature) + continue + default: + log.Warn().Msgf("Unknown feature type %s for feature: %s", featureType, feature) + } + } + } + + for entity, features := range onlineFeaturesByEntity { + if err := 
p.validateOnlineFeatures(entity, features, authToken); err != nil { + return fmt.Errorf("online feature validation failed for entity %s: %w", entity, err) + } + } + + if len(offlineFeatures) > 0 { + if err := p.validateOfflineFeatures(offlineFeatures, authToken); err != nil { + return fmt.Errorf("offline feature validation failed: %w", err) + } + } + + for entity, features := range pricingFeaturesByEntity { + if err := p.validatePricingFeatures(entity, features); err != nil { + return fmt.Errorf("pricing feature validation failed for entity %s: %w", entity, err) + } + } + + return nil +} + +// validateOnlineFeatures validates online features for a specific entity +func (p *Predator) validateOnlineFeatures(entity string, features []string, token string) error { + response, err := p.featureValidationClient.ValidateOnlineFeatures(entity, token) + if err != nil { + return fmt.Errorf("failed to call online validation API: %w", err) + } + + for _, feature := range features { + if !externalcall.ValidateFeatureExists(feature, response) { + return fmt.Errorf("online feature '%s' does not exist for entity '%s'", feature, entity) + } + } + + log.Info().Msgf("Successfully validated %d online features for entity %s", len(features), entity) + return nil +} + +// validateOfflineFeatures validates offline features by checking online mapping +func (p *Predator) validateOfflineFeatures(features []string, token string) error { + response, err := p.featureValidationClient.ValidateOfflineFeatures(features, token) + if err != nil { + return fmt.Errorf("failed to call offline validation API: %w", err) + } + + if response.Error != "" { + return fmt.Errorf("offline validation API returned error: %s", response.Error) + } + + for _, feature := range features { + if _, exists := response.Data[feature]; !exists { + return fmt.Errorf("offline feature '%s' does not have an online mapping", feature) + } + } + + log.Info().Msgf("Successfully validated %d offline features", len(features)) + return 
nil +} + +// validatePricingFeatures validates pricing features for a specific entity +func (p *Predator) validatePricingFeatures(entity string, features []string) error { + if !pred.IsMeeshoEnabled { + return nil + } + response, err := externalcall.PricingClient.GetDataTypes(entity) + if err != nil { + return fmt.Errorf("failed to call pricing service API: %w", err) + } + + for _, feature := range features { + if !externalcall.ValidatePricingFeatureExists(feature, response) { + return fmt.Errorf("pricing feature '%s' does not exist for entity '%s'", feature, entity) + } + } + + log.Info().Msgf("Successfully validated %d pricing features for entity %s", len(features), entity) + return nil +} + +// extractModelName extracts model name from metadata +func (p *Predator) extractModelName(metadata interface{}) (string, error) { + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return "", fmt.Errorf("failed to marshal metadata: %w", err) + } + + var metadataMap map[string]interface{} + if err := json.Unmarshal(metadataBytes, &metadataMap); err != nil { + return "", fmt.Errorf("failed to unmarshal metadata: %w", err) + } + + modelName, exists := metadataMap["model_name"] + if !exists { + return "", fmt.Errorf("model_name not found in metadata") + } + + modelNameStr, ok := modelName.(string) + if !ok || modelNameStr == "" { + return "", fmt.Errorf("model_name must be a non-empty string") + } + + return modelNameStr, nil +} + +// syncFullModel syncs all model files for full upload +func (p *Predator) syncFullModel(gcsPath, destBucket, destPath, modelName string) error { + log.Info().Msgf("Syncing full model from GCS path: %s", gcsPath) + + srcBucket, srcPath := extractGCSPath(gcsPath) + if srcBucket == "" || srcPath == "" { + return fmt.Errorf("invalid GCS path format: %s", gcsPath) + } + + pathSegments := strings.Split(strings.TrimSuffix(srcPath, "/"), "/") + srcModelName := pathSegments[len(pathSegments)-1] + srcBasePath := strings.TrimSuffix(srcPath, 
"/"+srcModelName) + + log.Info().Msgf("Full upload: transferring all files from %s/%s to %s/%s", + srcBucket, srcPath, destBucket, destPath) + + return p.GcsClient.TransferFolder(srcBucket, srcBasePath, srcModelName, + destBucket, strings.TrimSuffix(destPath, "/"+modelName), modelName) +} + +// syncPartialFiles syncs only config.pbtxt for partial upload +func (p *Predator) syncPartialFiles(gcsPath, destBucket, destPath, modelName string) error { + srcBucket, srcPath := extractGCSPath(gcsPath) + if srcBucket == "" || srcPath == "" { + return fmt.Errorf("invalid GCS path format: %s", gcsPath) + } + + filesToSync := []string{"config.pbtxt"} + log.Info().Msgf("Partial upload: syncing %v for model %s", filesToSync, modelName) + + for _, fileName := range filesToSync { + srcFilePath := path.Join(srcPath, fileName) + destFilePath := path.Join(destPath, fileName) + + data, err := p.GcsClient.ReadFile(srcBucket, srcFilePath) + if err != nil { + return fmt.Errorf("required file %s not found in source %s/%s: %w", + fileName, srcBucket, srcFilePath, err) + } + + if err := p.GcsClient.UploadFile(destBucket, destFilePath, data); err != nil { + return fmt.Errorf("failed to upload %s: %w", fileName, err) + } + + log.Info().Msgf("Successfully synced %s for partial upload of model %s", fileName, modelName) + } + + return nil +} + +// validateModelConfiguration validates the model configuration +func (p *Predator) validateModelConfiguration(gcsPath string) error { + log.Info().Msgf("Validating model configuration for GCS path: %s", gcsPath) + + srcBucket, srcPath := extractGCSPath(gcsPath) + if srcBucket == "" || srcPath == "" { + return fmt.Errorf("invalid GCS path format: %s", gcsPath) + } + + configPath := path.Join(srcPath, configFile) + configData, err := p.GcsClient.ReadFile(srcBucket, configPath) + if err != nil { + return fmt.Errorf("failed to read config.pbtxt from %s/%s: %w", srcBucket, configPath, err) + } + + var modelConfig ModelConfig + if err := 
prototext.Unmarshal(configData, &modelConfig); err != nil { + return fmt.Errorf("failed to parse config.pbtxt as proto: %w", err) + } + + log.Info().Msgf("Parsed model config - Name: %s, Backend: %s", modelConfig.Name, modelConfig.Backend) + return nil +} + +// cleanEnsembleScheduling cleans up ensemble scheduling to avoid storing {"step": null} +func (p *Predator) cleanEnsembleScheduling(metadata MetaData) MetaData { + if len(metadata.Ensembling.Step) == 0 { + metadata.Ensembling = Ensembling{Step: nil} + } + return metadata +} + +// copyConfigToNewNameInConfigSource copies config from old model name to new in config source +func (p *Predator) copyConfigToNewNameInConfigSource(oldModelName, newModelName string) error { + if oldModelName == newModelName { + return nil + } + + if pred.GcsConfigBucket == "" || pred.GcsConfigBasePath == "" { + log.Warn().Msg("Config source not configured, skipping config.pbtxt copy in config source") + return nil + } + + destConfigPath := path.Join(pred.GcsConfigBasePath, newModelName, configFile) + exists, err := p.GcsClient.CheckFileExists(pred.GcsConfigBucket, destConfigPath) + if err != nil { + log.Warn().Err(err).Msgf("Failed to check if config exists for %s, will attempt copy anyway", newModelName) + } else if exists { + log.Info().Msgf("Config already exists for %s in config source, skipping copy", newModelName) + return nil + } + + srcConfigPath := path.Join(pred.GcsConfigBasePath, oldModelName, configFile) + + configData, err := p.GcsClient.ReadFile(pred.GcsConfigBucket, srcConfigPath) + if err != nil { + return fmt.Errorf("failed to read config.pbtxt from %s: %w", srcConfigPath, err) + } + + updatedConfigData := p.replaceModelNameInConfigPreservingFormat(configData, newModelName) + + if err := p.GcsClient.UploadFile(pred.GcsConfigBucket, destConfigPath, updatedConfigData); err != nil { + return fmt.Errorf("failed to upload config.pbtxt to %s: %w", destConfigPath, err) + } + + log.Info().Msgf("Successfully copied config.pbtxt 
from %s to %s in config source", + oldModelName, newModelName) + return nil +} + +// replaceModelNameInConfigPreservingFormat updates only the top-level model name while preserving formatting +func (p *Predator) replaceModelNameInConfigPreservingFormat(data []byte, destModelName string) []byte { + content := string(data) + lines := strings.Split(content, "\n") + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "name:") { + leadingWhitespace := len(line) - len(strings.TrimLeft(line, " \t")) + if leadingWhitespace >= 2 { + continue + } + + namePattern := regexp.MustCompile(`name\s*:\s*"([^"]+)"`) + matches := namePattern.FindStringSubmatch(line) + if len(matches) > 1 { + oldModelName := matches[1] + loc := namePattern.FindStringIndex(line) + if loc != nil { + before := line[:loc[0]] + matched := line[loc[0]:loc[1]] + after := line[loc[1]:] + valuePattern := regexp.MustCompile(`"([^"]+)"`) + valueReplaced := valuePattern.ReplaceAllString(matched, fmt.Sprintf(`"%s"`, destModelName)) + lines[i] = before + valueReplaced + after + } else { + lines[i] = namePattern.ReplaceAllString(line, fmt.Sprintf(`name: "%s"`, destModelName)) + } + log.Info().Msgf("Replacing top-level model name in config.pbtxt: '%s' -> '%s'", oldModelName, destModelName) + break + } + } + } + + return []byte(strings.Join(lines, "\n")) +} diff --git a/horizon/internal/predator/handler/predator_validation.go b/horizon/internal/predator/handler/predator_validation.go new file mode 100644 index 00000000..15021baf --- /dev/null +++ b/horizon/internal/predator/handler/predator_validation.go @@ -0,0 +1,502 @@ +package handler + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "time" + + pred "github.com/Meesho/BharatMLStack/horizon/internal/predator" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorconfig" + "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorrequest" + 
"github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/validationjob" + "github.com/rs/zerolog/log" +) + +func (p *Predator) ValidateDeleteRequest(predatorConfigList []predatorconfig.PredatorConfig, ids []int) (bool, error) { + if len(predatorConfigList) != len(ids) { + log.Error().Err(errors.New(errModelNotFound)).Msgf("model not found for ids %v", ids) + return false, errors.New(errModelNotFound) + } + + // Create maps for quick lookup + requestedModelMap := make(map[string]predatorconfig.PredatorConfig) // modelName -> config + requestedDeployableMap := make(map[int]bool) // serviceDeployableID -> exists + deployableModelMap := make(map[int]map[string]predatorconfig.PredatorConfig) // deployableID -> modelName -> config + + // Build maps from requested models + for _, predatorConfig := range predatorConfigList { + // Get service deployable ID for this model + discoveryConfig, err := p.ServiceDiscoveryRepo.GetById(predatorConfig.DiscoveryConfigID) + if err != nil { + log.Error().Err(err).Msgf("failed to fetch discovery config for model %s", predatorConfig.ModelName) + return false, fmt.Errorf(errFailedToFetchDiscoveryConfigForModel, predatorConfig.ModelName, err) + } + + requestedModelMap[predatorConfig.ModelName] = predatorConfig + requestedDeployableMap[discoveryConfig.ServiceDeployableID] = true + + // Group models by deployable + if deployableModelMap[discoveryConfig.ServiceDeployableID] == nil { + deployableModelMap[discoveryConfig.ServiceDeployableID] = make(map[string]predatorconfig.PredatorConfig) + } + deployableModelMap[discoveryConfig.ServiceDeployableID][predatorConfig.ModelName] = predatorConfig + } + + // Check for duplicate model names within same deployable + for deployableID, models := range deployableModelMap { + if len(models) > 1 { + // Check if any model names are duplicated within this deployable + modelNameCount := make(map[string]int) + for modelName := range models { + modelNameCount[modelName]++ + } + for modelName, count := 
range modelNameCount { + if count > 1 { + return false, fmt.Errorf(errDuplicateModelNameInDeployable, modelName, deployableID) + } + } + } + } + + // Validate ensemble-child group deletion requirements + if err := p.validateEnsembleChildGroupDeletion(requestedModelMap, deployableModelMap); err != nil { + return false, err + } + + return true, nil +} + +func (p *Predator) validateEnsembleChildGroupDeletion(requestedModelMap map[string]predatorconfig.PredatorConfig, deployableModelMap map[int]map[string]predatorconfig.PredatorConfig) error { + // Get all active models to check for ensemble relationships + allModels, err := p.PredatorConfigRepo.FindAllActiveConfig() + if err != nil { + log.Error().Err(err).Msgf("failed to fetch all active models") + return fmt.Errorf("failed to fetch all active models: %w", err) + } + + // Group all models by deployable for easier lookup + allModelsByDeployable := make(map[int]map[string]predatorconfig.PredatorConfig) + for _, model := range allModels { + discoveryConfig, err := p.ServiceDiscoveryRepo.GetById(model.DiscoveryConfigID) + if err != nil { + log.Error().Err(err).Msgf("failed to fetch discovery config for model %s", model.ModelName) + continue + } + + if allModelsByDeployable[discoveryConfig.ServiceDeployableID] == nil { + allModelsByDeployable[discoveryConfig.ServiceDeployableID] = make(map[string]predatorconfig.PredatorConfig) + } + allModelsByDeployable[discoveryConfig.ServiceDeployableID][model.ModelName] = model + } + + // Check each deployable for ensemble-child relationships + for deployableID, modelsInDeployable := range allModelsByDeployable { + requestedModelsInDeployable := deployableModelMap[deployableID] + if requestedModelsInDeployable == nil { + continue // No models from this deployable in the delete request + } + + // Check each model in this deployable + for modelName, model := range modelsInDeployable { + var metadata MetaData + if err := json.Unmarshal(model.MetaData, &metadata); err != nil { + 
log.Error().Err(err).Msgf("failed to unmarshal metadata for model %s", modelName) + continue + } + + // Check if this is an ensemble model + if len(metadata.Ensembling.Step) > 0 { + // This is an ensemble model + isEnsembleInRequest := requestedModelsInDeployable[modelName].ID != 0 + + // Check each child of this ensemble + for _, step := range metadata.Ensembling.Step { + childModelName := step.ModelName + isChildInRequest := requestedModelsInDeployable[childModelName].ID != 0 + + // If ensemble is in request, all children must be in request + if isEnsembleInRequest && !isChildInRequest { + return fmt.Errorf(errEnsembleMissingChild, modelName, childModelName) + } + + // If child is in request, ensemble must be in request + if isChildInRequest && !isEnsembleInRequest { + return fmt.Errorf(errChildMissingEnsemble, childModelName, modelName) + } + } + } + } + } + + return nil +} + +// releaseLockWithError is a helper function to release lock and log error +func (p *Predator) releaseLockWithError(lockID uint, groupID, errorMsg string) { + if releaseErr := p.validationLockRepo.ReleaseLock(lockID); releaseErr != nil { + log.Error().Err(releaseErr).Msgf("Failed to release validation lock for group ID %s after error: %s", groupID, errorMsg) + } + log.Error().Msgf("Validation failed for group ID %s: %s", groupID, errorMsg) +} + +// getTestDeployableID determines the appropriate test deployable ID based on machine type +func (p *Predator) getTestDeployableID(payload *Payload) (int, error) { + // Get the target deployable ID from the request + targetDeployableID := int(payload.ConfigMapping.ServiceDeployableID) + // Fetch the service deployable config to check machine type + serviceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) + if err != nil { + return 0, fmt.Errorf("failed to fetch service deployable config: %w", err) + } + + // Parse the deployable config to extract machine type + var deployableConfig PredatorDeployableConfig + if err := 
json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { + return 0, fmt.Errorf("failed to parse service deployable config: %w", err) + } + + // Select test deployable ID based on machine type + switch strings.ToUpper(deployableConfig.MachineType) { + case "CPU": + log.Info().Msgf("Using CPU test deployable ID: %d", pred.TestDeployableID) + return pred.TestDeployableID, nil + case "GPU": + log.Info().Msgf("Using GPU test deployable ID: %d", pred.TestGpuDeployableID) + return pred.TestGpuDeployableID, nil + default: + // Default to CPU if machine type is not specified or unknown + log.Warn().Msgf("Unknown machine type '%s', defaulting to CPU test deployable ID: %d", + deployableConfig.MachineType, pred.TestDeployableID) + return pred.TestDeployableID, nil + } +} + +// getServiceNameFromDeployable extracts service name from deployable configuration +func (p *Predator) getServiceNameFromDeployable(deployableID int) (string, error) { + serviceDeployable, err := p.ServiceDeployableRepo.GetById(deployableID) + if err != nil { + return "", fmt.Errorf("failed to get deployable config: %w", err) + } + return serviceDeployable.Name, nil +} + +// performAsyncValidation performs the actual validation process asynchronously +func (p *Predator) performAsyncValidation(job *validationjob.Table, requests []predatorrequest.PredatorRequest, payload *Payload, testDeployableID int) { + defer func() { + // Always release the lock when validation completes + if releaseErr := p.validationLockRepo.ReleaseLock(job.LockID); releaseErr != nil { + log.Error().Err(releaseErr).Msgf("Failed to release validation lock for job %d", job.ID) + } + log.Info().Msgf("Released validation lock for job %d", job.ID) + }() + + log.Info().Msgf("Starting async validation for job %d, group %s", job.ID, job.GroupID) + + // Step 1: Clear temporary deployable + if err := p.clearTemporaryDeployable(testDeployableID); err != nil { + log.Error().Err(err).Msg("Failed to clear temporary deployable") + 
p.failValidationJob(job.ID, "Failed to clear temporary deployable: "+err.Error()) + return + } + + // Step 2: Copy existing models to temporary deployable + targetDeployableID := int(payload.ConfigMapping.ServiceDeployableID) + if err := p.copyExistingModelsToTemporary(targetDeployableID, testDeployableID); err != nil { + log.Error().Err(err).Msg("Failed to copy existing models to temporary deployable") + p.failValidationJob(job.ID, "Failed to copy existing models: "+err.Error()) + return + } + + // Step 3: Copy new models from request to temporary deployable + if err := p.copyRequestModelsToTemporary(requests, testDeployableID); err != nil { + log.Error().Err(err).Msg("Failed to copy request models to temporary deployable") + p.failValidationJob(job.ID, "Failed to copy request models: "+err.Error()) + return + } + + // Step 4: Restart temporary deployable + if err := p.restartTemporaryDeployable(testDeployableID); err != nil { + log.Error().Err(err).Msg("Failed to restart temporary deployable") + p.failValidationJob(job.ID, "Failed to restart temporary deployable: "+err.Error()) + return + } + + // Update job status to checking and record restart time + now := time.Now() + if err := p.validationJobRepo.UpdateStatus(job.ID, validationjob.StatusChecking, ""); err != nil { + log.Error().Err(err).Msgf("Failed to update job %d status to checking", job.ID) + } + + // Update restart time in the job + job.RestartedAt = &now + job.Status = validationjob.StatusChecking + + // Step 5: Start health checking process + p.startHealthCheckingProcess(job) +} + +// startHealthCheckingProcess monitors the deployment health and updates validation status +func (p *Predator) startHealthCheckingProcess(job *validationjob.Table) { + log.Info().Msgf("Starting health check process for job %d, service %s", job.ID, job.ServiceName) + + for job.HealthCheckCount < job.MaxHealthChecks { + // Wait for the specified interval before checking + time.Sleep(time.Duration(job.HealthCheckInterval) * 
time.Second) + + // Increment health check count + if err := p.validationJobRepo.IncrementHealthCheck(job.ID); err != nil { + log.Error().Err(err).Msgf("Failed to increment health check count for job %d", job.ID) + } + job.HealthCheckCount++ + + // Check deployment health using infrastructure handler + isHealthy, err := p.checkDeploymentHealth(job.ServiceName) + if err != nil { + log.Error().Err(err).Msgf("Failed to check deployment health for job %d", job.ID) + continue // Continue checking, don't fail immediately on health check errors + } + + if isHealthy { + log.Info().Msgf("Deployment is healthy for job %d, validation successful", job.ID) + p.completeValidationJob(job.ID, true, "Deployment is healthy and running successfully") + p.updateRequestValidationStatus(job.GroupID, true) + return + } + + log.Info().Msgf("Deployment not yet healthy for job %d, check %d/%d", job.ID, job.HealthCheckCount, job.MaxHealthChecks) + } + + // If we reach here, max health checks exceeded + log.Warn().Msgf("Max health checks exceeded for job %d, marking as failed", job.ID) + p.completeValidationJob(job.ID, false, fmt.Sprintf("Deployment failed to become healthy after %d checks", job.MaxHealthChecks)) + p.updateRequestValidationStatus(job.GroupID, false) +} + +// checkDeploymentHealth checks if the deployment is healthy using infrastructure handler +func (p *Predator) checkDeploymentHealth(serviceName string) (bool, error) { + resourceDetail, err := p.infrastructureHandler.GetResourceDetail(serviceName, p.workingEnv) + if err != nil { + return false, fmt.Errorf("failed to get resource detail: %w", err) + } + + if resourceDetail == nil || len(resourceDetail.Nodes) == 0 { + return false, nil + } + + healthyPodCount := 0 + totalPodCount := 0 + + for _, node := range resourceDetail.Nodes { + if node.Kind == "Deployment" { + totalPodCount++ + if node.Health.Status == "Healthy" { + healthyPodCount++ + } + } + } + + log.Info().Msgf("Health check for service %s: %d/%d pods healthy and 
running", serviceName, healthyPodCount, totalPodCount)
+
+	if totalPodCount == healthyPodCount {
+		return true, nil
+	}
+	// NOTE(review): healthy only when ALL Deployment nodes report Healthy — the earlier "at least one pod" wording did not match this all-or-nothing check
+	return false, nil
+}
+
+// failValidationJob marks a validation job as failed
+func (p *Predator) failValidationJob(jobID uint, errorMessage string) {
+	if err := p.validationJobRepo.UpdateValidationResult(jobID, false, errorMessage); err != nil {
+		log.Error().Err(err).Msgf("Failed to update validation job %d as failed", jobID)
+	}
+}
+
+// completeValidationJob marks a validation job as completed
+func (p *Predator) completeValidationJob(jobID uint, success bool, message string) {
+	if err := p.validationJobRepo.UpdateValidationResult(jobID, success, message); err != nil {
+		log.Error().Err(err).Msgf("Failed to update validation job %d as completed", jobID)
+	}
+}
+
+// updateRequestValidationStatus updates the request table with validation results
+func (p *Predator) updateRequestValidationStatus(groupID string, success bool) {
+	id, err := strconv.ParseUint(groupID, 10, 32)
+	if err != nil {
+		log.Error().Err(err).Msgf("Failed to parse group ID %s for status update", groupID)
+		return
+	}
+
+	requests, err := p.Repo.GetAllByGroupID(uint(id))
+	if err != nil {
+		log.Error().Err(err).Msgf("Failed to get requests for group ID %s", groupID)
+		return
+	}
+
+	// Update all requests in the group
+	for _, request := range requests {
+		request.UpdatedAt = time.Now()
+		request.IsValid = success
+		if !success {
+			request.RejectReason = "Validation Failed"
+			request.Status = statusRejected
+			request.UpdatedBy = "Validation Job"
+			request.UpdatedAt = time.Now()
+			request.Active = false
+		}
+		if err := p.Repo.Update(&request); err != nil {
+			log.Error().Err(err).Msgf("Failed to update request %d status", request.RequestID)
+		} else {
+			log.Info().Msgf("Updated request %d status to %s", request.RequestID, request.Status)
+		}
+	}
+}
+
+// clearTemporaryDeployable clears all models
from the temporary deployable GCS path +func (p *Predator) clearTemporaryDeployable(testDeployableID int) error { + // Get temporary deployable config + testServiceDeployable, err := p.ServiceDeployableRepo.GetById(testDeployableID) + if err != nil { + return fmt.Errorf("failed to fetch temporary service deployable: %w", err) + } + + var tempDeployableConfig PredatorDeployableConfig + if err := json.Unmarshal(testServiceDeployable.Config, &tempDeployableConfig); err != nil { + return fmt.Errorf("failed to parse temporary deployable config: %w", err) + } + + if tempDeployableConfig.GCSBucketPath != "NA" { + // Extract bucket and path from temporary deployable config + tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) + + // Clear all models from temporary deployable + log.Info().Msgf("Clearing temporary deployable GCS path: gs://%s/%s", tempBucket, tempPath) + if err := p.GcsClient.DeleteFolder(tempBucket, tempPath, ""); err != nil { + return fmt.Errorf("failed to clear temporary deployable GCS path: %w", err) + } + } + + return nil +} + +// copyExistingModelsToTemporary copies all existing models from target deployable to temporary deployable +func (p *Predator) copyExistingModelsToTemporary(targetDeployableID, tempDeployableID int) error { + // Get target deployable config + targetServiceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) + if err != nil { + return fmt.Errorf("failed to fetch target service deployable: %w", err) + } + + var targetDeployableConfig PredatorDeployableConfig + if err := json.Unmarshal(targetServiceDeployable.Config, &targetDeployableConfig); err != nil { + return fmt.Errorf("failed to parse target deployable config: %w", err) + } + + // Get temporary deployable config + tempServiceDeployable, err := p.ServiceDeployableRepo.GetById(tempDeployableID) + if err != nil { + return fmt.Errorf("failed to fetch temporary service deployable: %w", err) + } + + var 
tempDeployableConfig PredatorDeployableConfig + if err := json.Unmarshal(tempServiceDeployable.Config, &tempDeployableConfig); err != nil { + return fmt.Errorf("failed to parse temporary deployable config: %w", err) + } + + if targetDeployableConfig.GCSBucketPath != "NA" { + // Extract GCS paths + targetBucket, targetPath := extractGCSPath(strings.TrimSuffix(targetDeployableConfig.GCSBucketPath, "/*")) + tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) + + // Copy all existing models from target to temporary deployable + return p.copyAllModelsFromActualToStaging(targetBucket, targetPath, tempBucket, tempPath) + } else { + return nil + } +} + +// copyRequestModelsToTemporary copies the requested models to temporary deployable +func (p *Predator) copyRequestModelsToTemporary(requests []predatorrequest.PredatorRequest, tempDeployableID int) error { + // Get temporary deployable config + tempServiceDeployable, err := p.ServiceDeployableRepo.GetById(tempDeployableID) + if err != nil { + return fmt.Errorf("failed to fetch temporary service deployable: %w", err) + } + + var tempDeployableConfig PredatorDeployableConfig + if err := json.Unmarshal(tempServiceDeployable.Config, &tempDeployableConfig); err != nil { + return fmt.Errorf("failed to parse temporary deployable config: %w", err) + } + + tempBucket, tempPath := extractGCSPath(strings.TrimSuffix(tempDeployableConfig.GCSBucketPath, "/*")) + + isNotProd := p.isNonProductionEnvironment() + + // Copy each requested model from default GCS location to temporary deployable + for _, request := range requests { + modelName := request.ModelName + payload, err := p.processPayload(request) + if err != nil { + log.Error().Err(err).Msgf("Failed to parse payload for request %d", request.RequestID) + return fmt.Errorf("failed to parse payload for request %d: %w", request.RequestID, err) + } + + var sourceBucket, sourcePath, sourceModelName string + if payload.ModelSource != "" { + 
sourceBucket, sourcePath, sourceModelName = extractGCSDetails(payload.ModelSource) + log.Info().Msgf("Using ModelSource from payload for validation: gs://%s/%s/%s", + sourceBucket, sourcePath, sourceModelName) + } else { + sourceBucket = pred.GcsModelBucket + sourcePath = pred.GcsModelBasePath + sourceModelName = modelName + log.Info().Msgf("Using default model source for validation: gs://%s/%s/%s", + sourceBucket, sourcePath, sourceModelName) + } + log.Info().Msgf("Copying model %s from gs://%s/%s/%s to temporary deployable gs://%s/%s", + modelName, sourceBucket, sourcePath, sourceModelName, tempBucket, tempPath) + + if isNotProd { + if err := p.GcsClient.TransferFolder(sourceBucket, sourcePath, sourceModelName, + tempBucket, tempPath, modelName); err != nil { + return fmt.Errorf("failed to copy requested model %s to temporary deployable: %w", modelName, err) + } + } else { + if err := p.GcsClient.TransferFolderWithSplitSources( + sourceBucket, sourcePath, pred.GcsConfigBucket, pred.GcsConfigBasePath, + sourceModelName, tempBucket, tempPath, modelName, + ); err != nil { + return fmt.Errorf("failed to copy requested model %s to temporary deployable: %w", modelName, err) + } + } + + log.Info().Msgf("Successfully copied model %s to temporary deployable", modelName) + } + + return nil +} + +// restartTemporaryDeployable restarts the temporary deployable for validation +func (p *Predator) restartTemporaryDeployable(tempDeployableID int) error { + tempServiceDeployable, err := p.ServiceDeployableRepo.GetById(tempDeployableID) + if err != nil { + return fmt.Errorf("failed to fetch temporary service deployable: %w", err) + } + + // Extract isCanary from deployable config + var deployableConfig map[string]interface{} + isCanary := false + if err := json.Unmarshal(tempServiceDeployable.Config, &deployableConfig); err == nil { + if strategy, ok := deployableConfig["deploymentStrategy"].(string); ok && strategy == "canary" { + isCanary = true + } + } + if err := 
p.infrastructureHandler.RestartDeployment(tempServiceDeployable.Name, p.workingEnv, isCanary); err != nil { + return fmt.Errorf("failed to restart temporary deployable: %w", err) + } + + log.Info().Msgf("Successfully restarted temporary deployable: %s for validation", tempServiceDeployable.Name) + return nil +} From 2968ade428e9bd4cbeb883eec3b20ed42a8a7928 Mon Sep 17 00:00:00 2001 From: pavan-adari-meesho Date: Tue, 10 Feb 2026 12:22:28 +0530 Subject: [PATCH 22/24] inferflow refractor into multiple files --- horizon/internal/inferflow/README.md | 38 + .../internal/inferflow/handler/inferflow.go | 906 ------------------ .../inferflow/handler/inferflow_constants.go | 26 + .../inferflow/handler/inferflow_fetch.go | 86 ++ .../handler/inferflow_functional_testing.go | 182 ++++ .../inferflow/handler/inferflow_helpers.go | 25 + .../inferflow/handler/inferflow_review.go | 442 +++++++++ .../inferflow/handler/inferflow_validation.go | 194 ++++ 8 files changed, 993 insertions(+), 906 deletions(-) create mode 100644 horizon/internal/inferflow/README.md create mode 100644 horizon/internal/inferflow/handler/inferflow_constants.go create mode 100644 horizon/internal/inferflow/handler/inferflow_fetch.go create mode 100644 horizon/internal/inferflow/handler/inferflow_functional_testing.go create mode 100644 horizon/internal/inferflow/handler/inferflow_helpers.go create mode 100644 horizon/internal/inferflow/handler/inferflow_review.go create mode 100644 horizon/internal/inferflow/handler/inferflow_validation.go diff --git a/horizon/internal/inferflow/README.md b/horizon/internal/inferflow/README.md new file mode 100644 index 00000000..e5cdeaef --- /dev/null +++ b/horizon/internal/inferflow/README.md @@ -0,0 +1,38 @@ +# Inferflow + +Inferflow handles inferflow config lifecycle operations: onboarding, review/approval, promote, edit, clone, scale-up, delete, cancel, list (get-all / get-all-requests), validation, functional testing, and feature schema generation. 
+ +## Package layout + +``` +internal/inferflow/ +├── controller/ (wires routes to handler) +├── handler/ (Config implementation + helpers) +├── etcd/ (ETCD config read/write for inferflow) +├── route/ +├── handler/proto/ and generated Go +├── init.go +└── README.md +``` + +--- + +## Handler package structure + +| File | Purpose | +|------|--------| +| **config.go** | Defines the **Config** interface (public API). Implemented by `*InferFlow`. | +| **init.go** | Singleton init: `InitV1ConfigHandler()` returns Config. | +| **models.go** | Request/response types, payloads, and shared structs. | +| **inferflow.go** | **InferFlow** struct, `InitV1ConfigHandler()`, and **public entrypoints** that implement Config (e.g. `Onboard`, `Review`, `Promote`, `Edit`, `Clone`, `Delete`, `ScaleUp`, `Cancel`, `GetAll`, `GetAllRequests`, `ValidateRequest`, `GenerateFunctionalTestRequest`, `ExecuteFuncitonalTestRequest`, `GetLatestRequest`, `GetLoggingTTL`, `GetFeatureSchema`). | +| **inferflow_constants.go** | All `const` values: request types, statuses, method names, defaults, delimiters, etc. | +| **inferflow_review.go** | Review/approval and DB/ETCD write flow: `handleRejectedRequest`, `handleApprovedRequest`, rollback helpers (`rollbackApprovedRequest`, `rollbackPromoteRequest`, `rollbackEditRequest`, `rollbackCreatedConfigs`, `rollbackDeletedConfigs`), and create/update helpers (`createOrUpdateDiscoveryConfig`, `createOrUpdateInferFlowConfig`, `createOrUpdateEtcdConfig`). | +| **inferflow_fetch.go** | Batch fetch helpers used by `GetAll`: `batchFetchDiscoveryConfigs`, `batchFetchRingMasterConfigs`. | +| **inferflow_validation.go** | Package-level `ValidateInferFlowConfig`; method `ValidateOnboardRequest` (used by Onboard/Edit/Clone). | +| **inferflow_functional_testing.go** | Functional test entrypoints: `GenerateFunctionalTestRequest`, `ExecuteFuncitonalTestRequest` (gRPC/proto and test-result update logic). 
| +| **inferflow_helpers.go** | Small helpers: `GetDerivedConfigID`, `GetLoggingTTL`. | +| **adaptor.go** | DB/ETCD payload adaptors: `AdaptToEtcdInferFlowConfig`, `AdaptOnboardRequestToDBPayload`, `AdaptFromDbToInferFlowConfig`, `AdaptFromDbToConfigMapping`, etc. | +| **config_builder.go** | Build inferflow config from request: `GetInferflowConfig`, component/ranker/reranker building, and related constants (e.g. `COLON_DELIMITER`, `PIPE_DELIMITER`, `MODEL_FEATURE`). | +| **schema_adapter.go** | Feature schema from inferflow config: `BuildFeatureSchemaFromInferflow`, `ProcessResponseConfigFromInferflow`. | +| **proto/** | Inferflow gRPC proto and generated Go. | +| **inferflow_test.go** | Unit tests for the handler (same package). | diff --git a/horizon/internal/inferflow/handler/inferflow.go b/horizon/internal/inferflow/handler/inferflow.go index 3f7ba534..b2333b0d 100644 --- a/horizon/internal/inferflow/handler/inferflow.go +++ b/horizon/internal/inferflow/handler/inferflow.go @@ -1,31 +1,21 @@ package handler import ( - "context" "errors" "fmt" "strconv" "strings" - "sync" - "time" mainHandler "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" - inferflowPkg "github.com/Meesho/BharatMLStack/horizon/internal/inferflow" etcd "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/etcd" - pb "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/handler/proto/protogen" infrastructurehandler "github.com/Meesho/BharatMLStack/horizon/internal/infrastructure/handler" discovery_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/discoveryconfig" inferflow "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" inferflow_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow/config" inferflow_request "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow/request" service_deployable_config 
"github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" - "github.com/Meesho/BharatMLStack/horizon/pkg/grpc" "github.com/Meesho/BharatMLStack/horizon/pkg/infra" - "github.com/Meesho/BharatMLStack/horizon/pkg/random" - mapset "github.com/deckarep/golang-set/v2" "github.com/rs/zerolog/log" - "google.golang.org/grpc/metadata" - "gorm.io/gorm" ) type InferFlow struct { @@ -38,31 +28,6 @@ type InferFlow struct { workingEnv string } -const ( - emptyResponse = "" - rejected = "REJECTED" - approved = "APPROVED" - pendingApproval = "PENDING APPROVAL" - promoteRequestType = "PROMOTE" - onboardRequestType = "ONBOARD" - editRequestType = "EDIT" - cloneRequestType = "CLONE" - scaleUpRequestType = "SCALE UP" - deleteRequestType = "DELETE" - cancelled = "CANCELLED" - adminRole = "ADMIN" - activeTrue = true - activeFalse = false - inferFlowRetrieveModelScoreMethod = "/Inferflow/RetrieveModelScore" - setFunctionalTest = "FunctionalTest" - defaultLoggingTTL = 30 - maxConfigVersion = 15 - defaultModelSchemaPerc = 0 - deployableTagDelimiter = "_" - scaleupTag = "scaleup" - defaultVersion = 1 -) - func InitV1ConfigHandler() Config { if config == nil { conn, err := infra.SQL.GetConnection() @@ -592,435 +557,6 @@ func (m *InferFlow) Review(request ReviewRequest) (Response, error) { return m.handleApprovedRequest(request) } -func (m *InferFlow) handleRejectedRequest(request ReviewRequest) (Response, error) { - requestEntry := &inferflow_request.Table{ - RequestID: request.RequestID, - Status: request.Status, - RejectReason: request.RejectReason, - Reviewer: request.Reviewer, - Active: activeFalse, - } - - if err := m.InferFlowRequestRepo.Update(requestEntry); err != nil { - return Response{}, errors.New("failed to update inferflow config request in db: " + err.Error()) - } - - return Response{ - Error: emptyResponse, - Data: Message{Message: fmt.Sprintf("inferflow config request rejected successfully for Request Id %d", request.RequestID)}, - }, nil -} 
- -func (m *InferFlow) handleApprovedRequest(request ReviewRequest) (Response, error) { - var requestEntry *inferflow_request.Table - var discoveryID int - var discoveryConfig *discovery_config.DiscoveryConfig - - tempRequest := inferflow_request.Table{} - tempRequest, err := m.InferFlowRequestRepo.GetRequestByID(request.RequestID) - if err != nil { - return Response{}, fmt.Errorf("failed to fetch latest unapproved request for request id: %d: %w", request.RequestID, err) - } - - var configExistedBeforeTx bool - if tempRequest.RequestType == promoteRequestType { - existingConfig, err := m.InferFlowConfigRepo.GetByID(tempRequest.ConfigID) - if err != nil { - return Response{}, fmt.Errorf("failed to check existing config for promote: %w", err) - } - configExistedBeforeTx = existingConfig != nil - } - - err = m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { - requestEntry = &inferflow_request.Table{ - RequestID: request.RequestID, - Status: request.Status, - RejectReason: request.RejectReason, - Reviewer: request.Reviewer, - } - if err := tx.First(requestEntry, request.RequestID).Error; err != nil { - return fmt.Errorf("failed to get request: %w", err) - } - requestEntry.Reviewer = request.Reviewer - requestEntry.RejectReason = request.RejectReason - - var err error - discoveryID, discoveryConfig, err = m.createOrUpdateDiscoveryConfig(tx, requestEntry, configExistedBeforeTx) - if err != nil { - return fmt.Errorf("failed to handle discovery config: %w", err) - } - - if err := m.createOrUpdateInferFlowConfig(tx, requestEntry, discoveryID, configExistedBeforeTx); err != nil { - return fmt.Errorf("failed to handle inferflow config: %w", err) - } - - requestEntry.Status = approved - err = m.InferFlowRequestRepo.UpdateTx(tx, requestEntry) - if err != nil { - return errors.New("failed to update inferflow config request in db: " + err.Error()) - } - - return nil - }) - - if err != nil { - return Response{}, fmt.Errorf("failed to review config (DB rolled back): %w", 
err) - } - - if err := m.createOrUpdateEtcdConfig(requestEntry, discoveryConfig, configExistedBeforeTx); err != nil { - if rollBackErr := m.rollbackApprovedRequest(request, requestEntry, discoveryID, configExistedBeforeTx); rollBackErr != nil { - log.Error().Err(rollBackErr).Msg("Failed to rollback DB changes after ETCD failure") - return Response{}, fmt.Errorf("ETCD sync failed and DB rollback also failed: etcd=%w, rollback=%v", err, rollBackErr) - } - log.Warn().Msgf("Successfully rolled back the request: %d", request.RequestID) - return Response{}, fmt.Errorf("ETCD sync failed: %w", err) - } - - return Response{ - Error: emptyResponse, - Data: Message{Message: "Inferflow Config reviewed successfully."}, - }, nil -} - -func (m *InferFlow) rollbackApprovedRequest(request ReviewRequest, fullTable *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { - return m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { - table := &inferflow_request.Table{ - RequestID: request.RequestID, - Status: pendingApproval, - Reviewer: emptyResponse, - } - if err := m.InferFlowRequestRepo.UpdateTx(tx, table); err != nil { - return fmt.Errorf("failed to revert request status: %w", err) - } - - switch fullTable.RequestType { - case onboardRequestType, cloneRequestType, scaleUpRequestType: - if err := m.rollbackCreatedConfigs(tx, fullTable.ConfigID, discoveryID); err != nil { - return err - } - - case editRequestType: - if err := m.rollbackEditRequest(tx, fullTable, discoveryID); err != nil { - return err - } - - case deleteRequestType: - updatedBy := fullTable.UpdatedBy - if updatedBy == "" { - updatedBy = fullTable.CreatedBy - } - if err := m.rollbackDeletedConfigs(tx, fullTable.ConfigID, discoveryID, updatedBy); err != nil { - return err - } - - case promoteRequestType: - if err := m.rollbackPromoteRequest(tx, fullTable, discoveryID, configExistedBeforeTx); err != nil { - return err - } - } - - return nil - }) -} - -func (m *InferFlow) 
rollbackPromoteRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { - if configExistedBeforeTx { - if err := m.rollbackEditRequest(tx, currentRequest, discoveryID); err != nil { - return err - } - } else { - if err := m.rollbackCreatedConfigs(tx, currentRequest.ConfigID, discoveryID); err != nil { - return err - } - } - return nil -} - -func (m *InferFlow) rollbackEditRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int) error { - approvedRequests, err := m.InferFlowRequestRepo.GetApprovedRequestsByConfigID(currentRequest.ConfigID) - if err != nil { - return fmt.Errorf("failed to retrieve approved requests: %w", err) - } - - var previousRequest *inferflow_request.Table - if len(approvedRequests) > 0 { - if approvedRequests[0].RequestID == currentRequest.RequestID { - if len(approvedRequests) > 1 { - previousRequest = &approvedRequests[1] - } else { - return fmt.Errorf("no other request to revert back to: Requires manual intervention") - } - } else { - previousRequest = &approvedRequests[0] - } - } else { - return fmt.Errorf("no other request to revert back to: Requires manual intervention") - } - - existingConfig, err := m.InferFlowConfigRepo.GetByID(currentRequest.ConfigID) - if err != nil { - return fmt.Errorf("failed to get inferflow config: %w", err) - } - if existingConfig == nil { - return errors.New("inferflow config not found") - } - - restoredConfig := &inferflow_config.Table{ - ConfigID: currentRequest.ConfigID, - DiscoveryID: discoveryID, - ConfigValue: previousRequest.Payload.ConfigValue, - Active: activeTrue, - UpdatedBy: currentRequest.UpdatedBy, - } - - if err := m.InferFlowConfigRepo.UpdateTx(tx, restoredConfig); err != nil { - return fmt.Errorf("failed to restore inferflow config: %w", err) - } - - restoredDiscovery := &discovery_config.DiscoveryConfig{ - ID: discoveryID, - ServiceDeployableID: previousRequest.Payload.ConfigMapping.DeployableID, - AppToken: 
previousRequest.Payload.ConfigMapping.AppToken, - ServiceConnectionID: previousRequest.Payload.ConfigMapping.ConnectionConfigID, - Active: activeTrue, - UpdatedBy: currentRequest.UpdatedBy, - } - if err := m.DiscoveryConfigRepo.UpdateTx(tx, restoredDiscovery); err != nil { - return fmt.Errorf("failed to restore discovery config: %w", err) - } - - return nil -} - -func (m *InferFlow) rollbackCreatedConfigs(tx *gorm.DB, configID string, discoveryID int) error { - if err := m.InferFlowConfigRepo.DeleteByConfigIDTx(tx, configID); err != nil { - return fmt.Errorf("failed to rollback inferflow config: %w", err) - } - - if err := m.DiscoveryConfigRepo.DeleteByIDTx(tx, discoveryID); err != nil { - return fmt.Errorf("failed to rollback discovery config: %w", err) - } - - return nil -} - -func (m *InferFlow) rollbackDeletedConfigs(tx *gorm.DB, configID string, discoveryID int, updatedby string) error { - latestConfig, err := m.InferFlowConfigRepo.GetLatestInactiveByConfigID(tx, configID) - if err != nil { - return fmt.Errorf("failed to find soft-deleted inferflow config: %w", err) - } - if latestConfig == nil { - return errors.New("no soft-deleted inferflow config found") - } - - if err := m.InferFlowConfigRepo.ReactivateByIDTx(tx, int(latestConfig.ID), updatedby); err != nil { - return fmt.Errorf("failed to reactivate inferflow config: %w", err) - } - - if err := m.DiscoveryConfigRepo.ReactivateByIDTx(tx, discoveryID); err != nil { - return fmt.Errorf("failed to reactivate discovery config: %w", err) - } - - return nil -} - -func (m *InferFlow) createOrUpdateDiscoveryConfig(tx *gorm.DB, requestEntry *inferflow_request.Table, configExistedBeforeTx bool) (int, *discovery_config.DiscoveryConfig, error) { - discovery := &discovery_config.DiscoveryConfig{ - ServiceDeployableID: requestEntry.Payload.ConfigMapping.DeployableID, - AppToken: requestEntry.Payload.ConfigMapping.AppToken, - ServiceConnectionID: requestEntry.Payload.ConfigMapping.ConnectionConfigID, - Active: 
activeTrue, - } - - switch requestEntry.RequestType { - case onboardRequestType, cloneRequestType, scaleUpRequestType: - if requestEntry.UpdatedBy != "" { - discovery.CreatedBy = requestEntry.UpdatedBy - } else { - discovery.CreatedBy = requestEntry.CreatedBy - } - err := m.DiscoveryConfigRepo.CreateTx(tx, discovery) - if err != nil { - return 0, nil, errors.New("failed to create discovery config: " + err.Error()) - } - case promoteRequestType: - if !configExistedBeforeTx { - if requestEntry.UpdatedBy != "" { - discovery.CreatedBy = requestEntry.UpdatedBy - } else { - discovery.CreatedBy = requestEntry.CreatedBy - } - err := m.DiscoveryConfigRepo.CreateTx(tx, discovery) - if err != nil { - return 0, nil, errors.New("failed to create discovery config: " + err.Error()) - } - } else { - existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) - if err != nil { - return 0, nil, errors.New("failed to query inferflow config repo: " + err.Error()) - } - if requestEntry.UpdatedBy != "" { - discovery.UpdatedBy = requestEntry.UpdatedBy - } else { - discovery.UpdatedBy = requestEntry.CreatedBy - } - discovery.ID = int(existingConfig.DiscoveryID) - err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) - if err != nil { - return 0, nil, errors.New("failed to update discovery config: " + err.Error()) - } - } - case editRequestType: - if requestEntry.UpdatedBy != "" { - discovery.UpdatedBy = requestEntry.UpdatedBy - } else { - discovery.UpdatedBy = requestEntry.CreatedBy - } - config, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) - if err != nil { - return 0, nil, errors.New("failed to get inferflow config by id: " + err.Error()) - } - if config == nil { - return 0, nil, errors.New("failed to get inferflow config by id") - } - discovery.ID = int(config.DiscoveryID) - err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) - if err != nil { - return 0, nil, errors.New("failed to update discovery config: " + err.Error()) - } - case deleteRequestType: - 
config, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) - if err != nil { - return 0, nil, errors.New("failed to get inferflow config by id: " + err.Error()) - } - if config == nil { - return 0, nil, errors.New("failed to get inferflow config by id") - } - if requestEntry.UpdatedBy != "" { - discovery.UpdatedBy = requestEntry.UpdatedBy - } else { - discovery.UpdatedBy = requestEntry.CreatedBy - } - discovery.ID = int(config.DiscoveryID) - discovery.Active = activeFalse - err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) - if err != nil { - return 0, nil, errors.New("failed to update discovery config: " + err.Error()) - } - default: - return 0, nil, errors.New("invalid request type") - } - - return discovery.ID, discovery, nil -} - -func (m *InferFlow) createOrUpdateInferFlowConfig(tx *gorm.DB, requestEntry *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { - newConfig := &inferflow_config.Table{ - DiscoveryID: discoveryID, - ConfigID: requestEntry.ConfigID, - Active: activeTrue, - ConfigValue: requestEntry.Payload.ConfigValue, - } - - switch requestEntry.RequestType { - case onboardRequestType, cloneRequestType: - if requestEntry.UpdatedBy != "" { - newConfig.CreatedBy = requestEntry.UpdatedBy - } else { - newConfig.CreatedBy = requestEntry.CreatedBy - } - return m.InferFlowConfigRepo.CreateTx(tx, newConfig) - case scaleUpRequestType: - if requestEntry.UpdatedBy != "" { - newConfig.CreatedBy = requestEntry.UpdatedBy - } else { - newConfig.CreatedBy = requestEntry.CreatedBy - } - newConfig.SourceConfigID = requestEntry.Payload.ConfigMapping.SourceConfigID - return m.InferFlowConfigRepo.CreateTx(tx, newConfig) - case promoteRequestType: - if !configExistedBeforeTx { - if requestEntry.UpdatedBy != "" { - newConfig.CreatedBy = requestEntry.UpdatedBy - } else { - newConfig.CreatedBy = requestEntry.CreatedBy - } - return m.InferFlowConfigRepo.CreateTx(tx, newConfig) - } else { - existingConfig, err := 
m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) - if err != nil { - return errors.New("failed to query inferflow config repo: " + err.Error()) - } - newConfig.ID = existingConfig.ID - if requestEntry.UpdatedBy != "" { - newConfig.UpdatedBy = requestEntry.UpdatedBy - } else { - newConfig.UpdatedBy = requestEntry.CreatedBy - } - return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) - } - case editRequestType: - existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) - if err != nil { - return errors.New("failed to get inferflow config by id: " + err.Error()) - } - if existingConfig == nil { - return errors.New("failed to get inferflow config by id") - } - newConfig.ID = existingConfig.ID - if requestEntry.UpdatedBy != "" { - newConfig.UpdatedBy = requestEntry.UpdatedBy - } else { - newConfig.UpdatedBy = requestEntry.CreatedBy - } - return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) - case deleteRequestType: - existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) - if err != nil { - return errors.New("failed to get inferflow config by id: " + err.Error()) - } - if existingConfig == nil { - return errors.New("failed to get inferflow config by id") - } - newConfig.ID = existingConfig.ID - if requestEntry.UpdatedBy != "" { - newConfig.UpdatedBy = requestEntry.UpdatedBy - } else { - newConfig.UpdatedBy = requestEntry.CreatedBy - } - newConfig.Active = activeFalse - return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) - default: - return errors.New("invalid request type") - } -} - -func (m *InferFlow) createOrUpdateEtcdConfig(table *inferflow_request.Table, discovery *discovery_config.DiscoveryConfig, configExistedBeforeTx bool) error { - serviceDeployableTable, err := m.ServiceDeployableConfigRepo.GetById(int(discovery.ServiceDeployableID)) - if err != nil { - return errors.New("failed to get service deployable config by id: " + err.Error()) - } - serviceName := strings.ToLower(serviceDeployableTable.Name) - configId := 
table.ConfigID - inferFlowConfig := AdaptToEtcdInferFlowConfig(table.Payload.ConfigValue) - - switch table.RequestType { - case onboardRequestType, cloneRequestType, scaleUpRequestType: - return m.EtcdConfig.CreateConfig(serviceName, configId, inferFlowConfig) - case promoteRequestType: - if !configExistedBeforeTx { - return m.EtcdConfig.CreateConfig(serviceName, configId, inferFlowConfig) - } - return m.EtcdConfig.UpdateConfig(serviceName, configId, inferFlowConfig) - case editRequestType: - return m.EtcdConfig.UpdateConfig(serviceName, configId, inferFlowConfig) - case deleteRequestType: - return m.EtcdConfig.DeleteConfig(serviceName, configId) - default: - return errors.New("invalid request type") - } -} - func (m *InferFlow) GetAllRequests(request GetAllRequestConfigsRequest) (GetAllRequestConfigsResponse, error) { var tables []inferflow_request.Table @@ -1166,81 +702,6 @@ func (m *InferFlow) GetAll() (GetAllResponse, error) { return response, nil } -func (m *InferFlow) batchFetchDiscoveryConfigs(discoveryIDs []int) ( - map[int]*discovery_config.DiscoveryConfig, - map[int]*service_deployable_config.ServiceDeployableConfig, - error, -) { - emptyDiscoveryMap := make(map[int]*discovery_config.DiscoveryConfig) - emptyServiceDeployableMap := make(map[int]*service_deployable_config.ServiceDeployableConfig) - - if len(discoveryIDs) == 0 { - return emptyDiscoveryMap, emptyServiceDeployableMap, nil - } - - discoveryConfigs, err := m.DiscoveryConfigRepo.GetByDiscoveryIDs(discoveryIDs) - if err != nil { - return nil, nil, fmt.Errorf("failed to get discovery configs: %w", err) - } - - discoveryMap := make(map[int]*discovery_config.DiscoveryConfig) - for i := range discoveryConfigs { - discoveryMap[discoveryConfigs[i].ID] = &discoveryConfigs[i] - } - - serviceDeployableIDsMap := make(map[int]bool) - for _, dc := range discoveryConfigs { - serviceDeployableIDsMap[dc.ServiceDeployableID] = true - } - - serviceDeployableIDs := make([]int, 0, len(serviceDeployableIDsMap)) - for 
id := range serviceDeployableIDsMap { - serviceDeployableIDs = append(serviceDeployableIDs, id) - } - - if len(serviceDeployableIDs) == 0 { - return discoveryMap, emptyServiceDeployableMap, nil - } - - serviceDeployables, err := m.ServiceDeployableConfigRepo.GetByIds(serviceDeployableIDs) - if err != nil { - return nil, nil, fmt.Errorf("failed to get service deployable configs: %w", err) - } - - serviceDeployableMap := make(map[int]*service_deployable_config.ServiceDeployableConfig) - for i := range serviceDeployables { - serviceDeployableMap[serviceDeployables[i].ID] = &serviceDeployables[i] - } - - return discoveryMap, serviceDeployableMap, nil -} - -func (m *InferFlow) batchFetchRingMasterConfigs(serviceDeployables map[int]*service_deployable_config.ServiceDeployableConfig) (map[int]infrastructurehandler.Config, error) { - ringMasterConfigs := make(map[int]infrastructurehandler.Config) - var mu sync.Mutex - var wg sync.WaitGroup - - semaphore := make(chan struct{}, 10) - - for id, deployable := range serviceDeployables { - wg.Add(1) - go func(deployableID int, sd *service_deployable_config.ServiceDeployableConfig) { - defer wg.Done() - semaphore <- struct{}{} - defer func() { <-semaphore }() - - config := m.infrastructureHandler.GetConfig(sd.Name, inferflowPkg.AppEnv) - - mu.Lock() - ringMasterConfigs[deployableID] = config - mu.Unlock() - }(id, deployable) - } - - wg.Wait() - return ringMasterConfigs, nil -} - func (m *InferFlow) ValidateRequest(request ValidateRequest, token string) (Response, error) { tables, err := m.InferFlowRequestRepo.GetAll() if err != nil { @@ -1262,353 +723,6 @@ func (m *InferFlow) ValidateRequest(request ValidateRequest, token string) (Resp return ValidateInferFlowConfig(configValue, token) } -func ValidateInferFlowConfig(config InferflowConfig, token string) (Response, error) { - ComponentConfig := config.ComponentConfig - if ComponentConfig != nil { - for _, featureComponent := range ComponentConfig.FeatureComponents { - entity := 
featureComponent.FSRequest.Label - if entity == "dummy" { - continue - } - response, err := mainHandler.Client.ValidateOnlineFeatures(entity, token) - if err != nil { - return Response{ - Error: "failed to validate feature exists: " + err.Error(), - Data: Message{Message: emptyResponse}, - }, err - } - for _, fg := range featureComponent.FSRequest.FeatureGroups { - featureMap := make(map[string]bool) - for _, feature := range fg.Features { - if _, exists := featureMap[feature]; exists { - return Response{ - Error: "feature " + feature + " is duplicated", - Data: Message{Message: emptyResponse}, - }, errors.New("feature " + feature + " is duplicated") - } - featureMap[feature] = true - if !mainHandler.ValidateFeatureExists(fg.Label+COLON_DELIMITER+feature, response) { - return Response{ - Error: "feature \"" + entity + COLON_DELIMITER + fg.Label + COLON_DELIMITER + feature + "\" does not exist", - Data: Message{Message: emptyResponse}, - }, errors.New("feature \"" + entity + COLON_DELIMITER + fg.Label + COLON_DELIMITER + feature + "\" does not exist") - } - } - } - } - - for _, predatorComponent := range config.ComponentConfig.PredatorComponents { - outputMap := make(map[string]bool) - for _, output := range predatorComponent.Outputs { - for _, modelScore := range output.ModelScores { - if _, exists := outputMap[modelScore]; exists { - return Response{ - Error: "model score " + modelScore + " is duplicated for component " + predatorComponent.Component, - Data: Message{Message: emptyResponse}, - }, errors.New("model score " + modelScore + " is duplicated for component " + predatorComponent.Component) - } - outputMap[modelScore] = true - } - } - } - } - - return Response{ - Error: emptyResponse, - Data: Message{Message: "Request validated successfully"}, - }, nil -} - -func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) (Response, error) { - outputs := mapset.NewSet[string]() - deployableConfig, err := 
m.ServiceDeployableConfigRepo.GetById(request.ConfigMapping.DeployableID) - if err != nil { - return Response{ - Error: "Failed to fetch deployable config for the request", - Data: Message{Message: emptyResponse}, - }, errors.New("failed to fetch deployable config for the request") - } - permissibleEndpoints := m.EtcdConfig.GetConfiguredEndpoints(deployableConfig.Name) - for _, ranker := range request.Rankers { - if len(ranker.EntityID) == 0 { - return Response{ - Error: "Entity ID is not set for model: " + ranker.ModelName, - Data: Message{Message: emptyResponse}, - }, errors.New("Entity ID is not set for model: " + ranker.ModelName) - } - if !permissibleEndpoints.Contains(ranker.EndPoint) { - errorMsg := fmt.Sprintf( - "invalid endpoint: %s chosen for service deployable: %s for model: %s", - ranker.EndPoint, deployableConfig.Name, ranker.ModelName, - ) - return Response{ - Error: errorMsg, - Data: Message{Message: emptyResponse}, - }, errors.New(errorMsg) - } - for _, output := range ranker.Outputs { - if len(output.ModelScores) != len(output.ModelScoresDims) { - return Response{ - Error: "model scores and model scores dims are not equal for model: " + ranker.ModelName, - Data: Message{Message: emptyResponse}, - }, errors.New("model scores and model scores dims are not equal for model: " + ranker.ModelName) - } - for _, modelScore := range output.ModelScores { - if outputs.Contains(modelScore) { - return Response{ - Error: "duplicate model scores: " + modelScore + " for model: " + ranker.ModelName, - Data: Message{Message: emptyResponse}, - }, errors.New("duplicate model scores: " + modelScore + " for model: " + ranker.ModelName) - } - outputs.Add(modelScore) - } - } - } - - for _, reRanker := range request.ReRankers { - if len(reRanker.EntityID) == 0 { - return Response{ - Error: "Entity ID is not set for re ranker: " + reRanker.Score, - Data: Message{Message: emptyResponse}, - }, errors.New("Entity ID is not set for re ranker: " + reRanker.Score) - } - for _, 
value := range reRanker.EqVariables { - parts := strings.Split(value, PIPE_DELIMITER) - if len(parts) != 2 { - return Response{ - Error: "invalid eq variable: " + value, - Data: Message{Message: emptyResponse}, - }, errors.New("invalid eq variable: " + value) - } - if parts[1] == "" { - return Response{ - Error: "invalid eq variable: " + value, - Data: Message{Message: emptyResponse}, - }, errors.New("invalid eq variable: " + value) - } - } - if outputs.Contains(reRanker.Score) { - return Response{ - Error: "duplicate score: " + reRanker.Score + " for reRanker: " + reRanker.Score, - Data: Message{Message: emptyResponse}, - }, errors.New("duplicate score: " + reRanker.Score + " for reRanker: " + reRanker.Score) - } - outputs.Add(reRanker.Score) - } - - // Validate MODEL_FEATURE list - for _, ranker := range request.Rankers { - for _, input := range ranker.Inputs { - for _, feature := range input.Features { - featureParts := strings.Split(feature, PIPE_DELIMITER) - if len(featureParts) != 2 { - return Response{ - Error: "invalid feature: " + feature + " in input features of ranker: " + ranker.ModelName, - Data: Message{Message: emptyResponse}, - }, errors.New("invalid feature: " + feature + " in input features of ranker: " + ranker.ModelName) - } - if strings.Contains(featureParts[0], MODEL_FEATURE) { - if !outputs.Contains(featureParts[1]) { - return Response{ - Error: "model score " + featureParts[1] + " is not found in other model scores of ranker: " + ranker.ModelName, - Data: Message{Message: emptyResponse}, - }, errors.New("model score " + featureParts[1] + " is not found in other model scores of ranker: " + ranker.ModelName) - } - } - } - } - } - - for _, reRanker := range request.ReRankers { - for _, feature := range reRanker.EqVariables { - featureParts := strings.Split(feature, PIPE_DELIMITER) - if len(featureParts) != 2 { - return Response{ - Error: "invalid feature: " + feature, - Data: Message{Message: emptyResponse}, - }, errors.New("invalid feature: " 
+ feature) - } - if strings.Contains(featureParts[0], MODEL_FEATURE) { - if !outputs.Contains(featureParts[1]) { - return Response{ - Error: "model score " + featureParts[1] + " is not found in other model scores of re ranker: " + strconv.Itoa(reRanker.EqID), - Data: Message{Message: emptyResponse}, - }, errors.New("model score " + featureParts[1] + " is not found in other model scores of re ranker: " + strconv.Itoa(reRanker.EqID)) - } - } - } - } - - return Response{ - Error: emptyResponse, - Data: Message{Message: "Request validated successfully"}, - }, nil -} - -func (m *InferFlow) GenerateFunctionalTestRequest(request GenerateRequestFunctionalTestingRequest) (GenerateRequestFunctionalTestingResponse, error) { - - response := GenerateRequestFunctionalTestingResponse{ - RequestBody: RequestBody{ - Entities: []Entity{ - { - Entity: request.Entity + "_id", - Ids: []string{}, - Features: []FeatureValue{}, - }, - }, - ModelConfigID: request.ModelConfigID, - }, - } - - batchSize, err := strconv.Atoi(request.BatchSize) - if err != nil { - response.Error = fmt.Errorf("invalid batch size: %w", err).Error() - return response, errors.New("invalid batch size: " + err.Error()) - } - - response.RequestBody.Entities[0].Entity = request.Entity + "_id" - response.RequestBody.Entities[0].Ids = random.GenerateRandomIntSliceWithRange(batchSize, 100000, 1000000) - - for feature, value := range request.DefaultFeatures { - featureValues := make([]string, batchSize) - for i := 0; i < batchSize; i++ { - featureValues[i] = value - } - response.RequestBody.Entities[0].Features = append(response.RequestBody.Entities[0].Features, FeatureValue{ - Name: feature, - IdsFeatureValue: featureValues, - }) - } - - response.MetaData = request.MetaData - - return response, nil -} - -func (m *InferFlow) ExecuteFuncitonalTestRequest(request ExecuteRequestFunctionalTestingRequest) (ExecuteRequestFunctionalTestingResponse, error) { - response := ExecuteRequestFunctionalTestingResponse{} - - 
normalizedEndpoint := func(raw string) string { - ep := strings.TrimSpace(raw) - if strings.HasPrefix(ep, "http://") { - ep = strings.TrimPrefix(ep, "http://") - } else if strings.HasPrefix(ep, "https://") { - ep = strings.TrimPrefix(ep, "https://") - } - ep = strings.TrimSuffix(ep, "/") - if idx := strings.LastIndex(ep, ":"); idx != -1 { - if idx < len(ep)-1 { - ep = ep[:idx] - } - } - - port := ":8080" - env := strings.ToLower(strings.TrimSpace(inferflowPkg.AppEnv)) - if env == "stg" || env == "int" { - port = ":80" - } - ep = ep + port - return ep - }(request.EndPoint) - - conn, err := grpc.GetConnection(normalizedEndpoint) - if err != nil { - response.Error = err.Error() - return response, errors.New("failed to get connection: " + err.Error()) - } - - protoRequest := &pb.InferflowRequestProto{} - protoRequest.ModelConfigId = request.RequestBody.ModelConfigID - - md := metadata.New(nil) - if len(request.MetaData) > 0 { - for key, value := range request.MetaData { - md.Set(key, value) - } - } - - md.Set(setFunctionalTest, "true") - - protoRequest.Entities = make([]*pb.InferflowRequestProto_Entity, len(request.RequestBody.Entities)) - for i, entity := range request.RequestBody.Entities { - - protoFeatures := make([]*pb.InferflowRequestProto_Entity_Feature, len(entity.Features)) - - for j, feature := range entity.Features { - protoFeatures[j] = &pb.InferflowRequestProto_Entity_Feature{ - Name: feature.Name, - IdsFeatureValue: feature.IdsFeatureValue, - } - } - - protoRequest.Entities[i] = &pb.InferflowRequestProto_Entity{ - Entity: entity.Entity, - Ids: entity.Ids, - Features: protoFeatures, - } - } - - protoResponse := &pb.InferflowResponseProto{} - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - - defer cancel() - err = grpc.SendGRPCRequest(ctx, conn, inferFlowRetrieveModelScoreMethod, protoRequest, protoResponse, md) - if err != nil { - response.Error = err.Error() - log.Error().Msgf("error: %v", err) - return response, 
errors.New("failed to send grpc request: " + err.Error()) - } - - for _, compData := range protoResponse.GetComponentData() { - response.ComponentData = append(response.ComponentData, ComponentData{ - Data: compData.GetData(), - }) - } - - for i, compData := range response.ComponentData { - if i == 0 { - continue - } - for j, data := range compData.Data { - if data == "" { - response.Error = fmt.Sprintf("response data is empty for field: %s", response.ComponentData[0].Data[j]) - break - } - } - } - - if protoResponse.GetError() != nil { - response.Error = protoResponse.GetError().GetMessage() - } - - inferFlowConfig, err := m.InferFlowConfigRepo.GetByID(request.RequestBody.ModelConfigID) - - if err != nil { - fmt.Println("Error getting inferflow config: ", err) - } else if inferFlowConfig == nil { - log.Error().Msgf("inferflow config '%s' does not exist in DB", request.RequestBody.ModelConfigID) - } else { - if response.Error != emptyResponse { - inferFlowConfig.TestResults = inferflow.TestResults{ - Tested: false, - Message: response.Error, - } - } else { - inferFlowConfig.TestResults = inferflow.TestResults{ - Tested: true, - Message: "Functional test request executed successfully", - } - } - err = m.InferFlowConfigRepo.Update(inferFlowConfig) - if err != nil { - fmt.Println("Error updating inferflow config: ", err) - } - } - - return response, nil -} - func (m *InferFlow) GetLatestRequest(requestID string) (GetLatestRequestResponse, error) { requests, err := m.InferFlowRequestRepo.GetApprovedRequestsByConfigID(requestID) if err != nil { @@ -1651,26 +765,6 @@ func (m *InferFlow) GetLatestRequest(requestID string) (GetLatestRequestResponse }, nil } -func (m *InferFlow) GetDerivedConfigID(configID string, deployableID int) (string, error) { - serviceDeployableConfig, err := m.ServiceDeployableConfigRepo.GetById(deployableID) - if err != nil { - return "", fmt.Errorf("failed to fetch service service deployable config for name generation: %w", err) - } - 
deployableTag := serviceDeployableConfig.DeployableTag - if deployableTag == "" { - return configID, nil - } - - derivedConfigID := configID + deployableTagDelimiter + deployableTag + deployableTagDelimiter + scaleupTag - return derivedConfigID, nil -} - -func (m *InferFlow) GetLoggingTTL() (GetLoggingTTLResponse, error) { - return GetLoggingTTLResponse{ - Data: []int{30, 60, 90}, - }, nil -} - func (m *InferFlow) GetFeatureSchema(request FeatureSchemaRequest) (FeatureSchemaResponse, error) { version, err := strconv.Atoi(request.Version) if err != nil { diff --git a/horizon/internal/inferflow/handler/inferflow_constants.go b/horizon/internal/inferflow/handler/inferflow_constants.go new file mode 100644 index 00000000..c1602d63 --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_constants.go @@ -0,0 +1,26 @@ +package handler + +const ( + emptyResponse = "" + rejected = "REJECTED" + approved = "APPROVED" + pendingApproval = "PENDING APPROVAL" + promoteRequestType = "PROMOTE" + onboardRequestType = "ONBOARD" + editRequestType = "EDIT" + cloneRequestType = "CLONE" + scaleUpRequestType = "SCALE UP" + deleteRequestType = "DELETE" + cancelled = "CANCELLED" + adminRole = "ADMIN" + activeTrue = true + activeFalse = false + inferFlowRetrieveModelScoreMethod = "/Inferflow/RetrieveModelScore" + setFunctionalTest = "FunctionalTest" + defaultLoggingTTL = 30 + maxConfigVersion = 15 + defaultModelSchemaPerc = 0 + deployableTagDelimiter = "_" + scaleupTag = "scaleup" + defaultVersion = 1 +) diff --git a/horizon/internal/inferflow/handler/inferflow_fetch.go b/horizon/internal/inferflow/handler/inferflow_fetch.go new file mode 100644 index 00000000..a1982dbe --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_fetch.go @@ -0,0 +1,86 @@ +package handler + +import ( + "fmt" + "sync" + + inferflowPkg "github.com/Meesho/BharatMLStack/horizon/internal/inferflow" + infrastructurehandler "github.com/Meesho/BharatMLStack/horizon/internal/infrastructure/handler" + 
discovery_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/discoveryconfig" + service_deployable_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" +) + +func (m *InferFlow) batchFetchDiscoveryConfigs(discoveryIDs []int) ( + map[int]*discovery_config.DiscoveryConfig, + map[int]*service_deployable_config.ServiceDeployableConfig, + error, +) { + emptyDiscoveryMap := make(map[int]*discovery_config.DiscoveryConfig) + emptyServiceDeployableMap := make(map[int]*service_deployable_config.ServiceDeployableConfig) + + if len(discoveryIDs) == 0 { + return emptyDiscoveryMap, emptyServiceDeployableMap, nil + } + + discoveryConfigs, err := m.DiscoveryConfigRepo.GetByDiscoveryIDs(discoveryIDs) + if err != nil { + return nil, nil, fmt.Errorf("failed to get discovery configs: %w", err) + } + + discoveryMap := make(map[int]*discovery_config.DiscoveryConfig) + for i := range discoveryConfigs { + discoveryMap[discoveryConfigs[i].ID] = &discoveryConfigs[i] + } + + serviceDeployableIDsMap := make(map[int]bool) + for _, dc := range discoveryConfigs { + serviceDeployableIDsMap[dc.ServiceDeployableID] = true + } + + serviceDeployableIDs := make([]int, 0, len(serviceDeployableIDsMap)) + for id := range serviceDeployableIDsMap { + serviceDeployableIDs = append(serviceDeployableIDs, id) + } + + if len(serviceDeployableIDs) == 0 { + return discoveryMap, emptyServiceDeployableMap, nil + } + + serviceDeployables, err := m.ServiceDeployableConfigRepo.GetByIds(serviceDeployableIDs) + if err != nil { + return nil, nil, fmt.Errorf("failed to get service deployable configs: %w", err) + } + + serviceDeployableMap := make(map[int]*service_deployable_config.ServiceDeployableConfig) + for i := range serviceDeployables { + serviceDeployableMap[serviceDeployables[i].ID] = &serviceDeployables[i] + } + + return discoveryMap, serviceDeployableMap, nil +} + +func (m *InferFlow) batchFetchRingMasterConfigs(serviceDeployables 
map[int]*service_deployable_config.ServiceDeployableConfig) (map[int]infrastructurehandler.Config, error) { + ringMasterConfigs := make(map[int]infrastructurehandler.Config) + var mu sync.Mutex + var wg sync.WaitGroup + + semaphore := make(chan struct{}, 10) + + for id, deployable := range serviceDeployables { + wg.Add(1) + go func(deployableID int, sd *service_deployable_config.ServiceDeployableConfig) { + defer wg.Done() + semaphore <- struct{}{} + defer func() { <-semaphore }() + + config := m.infrastructureHandler.GetConfig(sd.Name, inferflowPkg.AppEnv) + + mu.Lock() + ringMasterConfigs[deployableID] = config + mu.Unlock() + }(id, deployable) + } + + wg.Wait() + return ringMasterConfigs, nil +} diff --git a/horizon/internal/inferflow/handler/inferflow_functional_testing.go b/horizon/internal/inferflow/handler/inferflow_functional_testing.go new file mode 100644 index 00000000..9ca9839f --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_functional_testing.go @@ -0,0 +1,182 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + inferflowPkg "github.com/Meesho/BharatMLStack/horizon/internal/inferflow" + pb "github.com/Meesho/BharatMLStack/horizon/internal/inferflow/handler/proto/protogen" + inferflow "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow" + "github.com/Meesho/BharatMLStack/horizon/pkg/grpc" + "github.com/Meesho/BharatMLStack/horizon/pkg/random" + "github.com/rs/zerolog/log" + "google.golang.org/grpc/metadata" +) + +func (m *InferFlow) GenerateFunctionalTestRequest(request GenerateRequestFunctionalTestingRequest) (GenerateRequestFunctionalTestingResponse, error) { + + response := GenerateRequestFunctionalTestingResponse{ + RequestBody: RequestBody{ + Entities: []Entity{ + { + Entity: request.Entity + "_id", + Ids: []string{}, + Features: []FeatureValue{}, + }, + }, + ModelConfigID: request.ModelConfigID, + }, + } + + batchSize, err := strconv.Atoi(request.BatchSize) + 
if err != nil { + response.Error = fmt.Errorf("invalid batch size: %w", err).Error() + return response, errors.New("invalid batch size: " + err.Error()) + } + + response.RequestBody.Entities[0].Entity = request.Entity + "_id" + response.RequestBody.Entities[0].Ids = random.GenerateRandomIntSliceWithRange(batchSize, 100000, 1000000) + + for feature, value := range request.DefaultFeatures { + featureValues := make([]string, batchSize) + for i := 0; i < batchSize; i++ { + featureValues[i] = value + } + response.RequestBody.Entities[0].Features = append(response.RequestBody.Entities[0].Features, FeatureValue{ + Name: feature, + IdsFeatureValue: featureValues, + }) + } + + response.MetaData = request.MetaData + + return response, nil +} + +func (m *InferFlow) ExecuteFuncitonalTestRequest(request ExecuteRequestFunctionalTestingRequest) (ExecuteRequestFunctionalTestingResponse, error) { + response := ExecuteRequestFunctionalTestingResponse{} + + normalizedEndpoint := func(raw string) string { + ep := strings.TrimSpace(raw) + if strings.HasPrefix(ep, "http://") { + ep = strings.TrimPrefix(ep, "http://") + } else if strings.HasPrefix(ep, "https://") { + ep = strings.TrimPrefix(ep, "https://") + } + ep = strings.TrimSuffix(ep, "/") + if idx := strings.LastIndex(ep, ":"); idx != -1 { + if idx < len(ep)-1 { + ep = ep[:idx] + } + } + + port := ":8080" + env := strings.ToLower(strings.TrimSpace(inferflowPkg.AppEnv)) + if env == "stg" || env == "int" { + port = ":80" + } + ep = ep + port + return ep + }(request.EndPoint) + + conn, err := grpc.GetConnection(normalizedEndpoint) + if err != nil { + response.Error = err.Error() + return response, errors.New("failed to get connection: " + err.Error()) + } + + protoRequest := &pb.InferflowRequestProto{} + protoRequest.ModelConfigId = request.RequestBody.ModelConfigID + + md := metadata.New(nil) + if len(request.MetaData) > 0 { + for key, value := range request.MetaData { + md.Set(key, value) + } + } + + md.Set(setFunctionalTest, 
"true") + + protoRequest.Entities = make([]*pb.InferflowRequestProto_Entity, len(request.RequestBody.Entities)) + for i, entity := range request.RequestBody.Entities { + + protoFeatures := make([]*pb.InferflowRequestProto_Entity_Feature, len(entity.Features)) + + for j, feature := range entity.Features { + protoFeatures[j] = &pb.InferflowRequestProto_Entity_Feature{ + Name: feature.Name, + IdsFeatureValue: feature.IdsFeatureValue, + } + } + + protoRequest.Entities[i] = &pb.InferflowRequestProto_Entity{ + Entity: entity.Entity, + Ids: entity.Ids, + Features: protoFeatures, + } + } + + protoResponse := &pb.InferflowResponseProto{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + + defer cancel() + err = grpc.SendGRPCRequest(ctx, conn, inferFlowRetrieveModelScoreMethod, protoRequest, protoResponse, md) + if err != nil { + response.Error = err.Error() + log.Error().Msgf("error: %v", err) + return response, errors.New("failed to send grpc request: " + err.Error()) + } + + for _, compData := range protoResponse.GetComponentData() { + response.ComponentData = append(response.ComponentData, ComponentData{ + Data: compData.GetData(), + }) + } + + for i, compData := range response.ComponentData { + if i == 0 { + continue + } + for j, data := range compData.Data { + if data == "" { + response.Error = fmt.Sprintf("response data is empty for field: %s", response.ComponentData[0].Data[j]) + break + } + } + } + + if protoResponse.GetError() != nil { + response.Error = protoResponse.GetError().GetMessage() + } + + inferFlowConfig, err := m.InferFlowConfigRepo.GetByID(request.RequestBody.ModelConfigID) + + if err != nil { + fmt.Println("Error getting inferflow config: ", err) + } else if inferFlowConfig == nil { + log.Error().Msgf("inferflow config '%s' does not exist in DB", request.RequestBody.ModelConfigID) + } else { + if response.Error != emptyResponse { + inferFlowConfig.TestResults = inferflow.TestResults{ + Tested: false, + Message: 
response.Error, + } + } else { + inferFlowConfig.TestResults = inferflow.TestResults{ + Tested: true, + Message: "Functional test request executed successfully", + } + } + err = m.InferFlowConfigRepo.Update(inferFlowConfig) + if err != nil { + fmt.Println("Error updating inferflow config: ", err) + } + } + + return response, nil +} diff --git a/horizon/internal/inferflow/handler/inferflow_helpers.go b/horizon/internal/inferflow/handler/inferflow_helpers.go new file mode 100644 index 00000000..9680a9fb --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_helpers.go @@ -0,0 +1,25 @@ +package handler + +import ( + "fmt" +) + +func (m *InferFlow) GetDerivedConfigID(configID string, deployableID int) (string, error) { + serviceDeployableConfig, err := m.ServiceDeployableConfigRepo.GetById(deployableID) + if err != nil { + return "", fmt.Errorf("failed to fetch service service deployable config for name generation: %w", err) + } + deployableTag := serviceDeployableConfig.DeployableTag + if deployableTag == "" { + return configID, nil + } + + derivedConfigID := configID + deployableTagDelimiter + deployableTag + deployableTagDelimiter + scaleupTag + return derivedConfigID, nil +} + +func (m *InferFlow) GetLoggingTTL() (GetLoggingTTLResponse, error) { + return GetLoggingTTLResponse{ + Data: []int{30, 60, 90}, + }, nil +} diff --git a/horizon/internal/inferflow/handler/inferflow_review.go b/horizon/internal/inferflow/handler/inferflow_review.go new file mode 100644 index 00000000..fd33e7ba --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_review.go @@ -0,0 +1,442 @@ +package handler + +import ( + "errors" + "fmt" + "strings" + + discovery_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/discoveryconfig" + inferflow_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow/config" + inferflow_request "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/inferflow/request" + 
"github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +func (m *InferFlow) handleRejectedRequest(request ReviewRequest) (Response, error) { + requestEntry := &inferflow_request.Table{ + RequestID: request.RequestID, + Status: request.Status, + RejectReason: request.RejectReason, + Reviewer: request.Reviewer, + Active: activeFalse, + } + + if err := m.InferFlowRequestRepo.Update(requestEntry); err != nil { + return Response{}, errors.New("failed to update inferflow config request in db: " + err.Error()) + } + + return Response{ + Error: emptyResponse, + Data: Message{Message: fmt.Sprintf("inferflow config request rejected successfully for Request Id %d", request.RequestID)}, + }, nil +} + +func (m *InferFlow) handleApprovedRequest(request ReviewRequest) (Response, error) { + var requestEntry *inferflow_request.Table + var discoveryID int + var discoveryConfig *discovery_config.DiscoveryConfig + + tempRequest := inferflow_request.Table{} + tempRequest, err := m.InferFlowRequestRepo.GetRequestByID(request.RequestID) + if err != nil { + return Response{}, fmt.Errorf("failed to fetch latest unapproved request for request id: %d: %w", request.RequestID, err) + } + + var configExistedBeforeTx bool + if tempRequest.RequestType == promoteRequestType { + existingConfig, err := m.InferFlowConfigRepo.GetByID(tempRequest.ConfigID) + if err != nil { + return Response{}, fmt.Errorf("failed to check existing config for promote: %w", err) + } + configExistedBeforeTx = existingConfig != nil + } + + err = m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { + requestEntry = &inferflow_request.Table{ + RequestID: request.RequestID, + Status: request.Status, + RejectReason: request.RejectReason, + Reviewer: request.Reviewer, + } + if err := tx.First(requestEntry, request.RequestID).Error; err != nil { + return fmt.Errorf("failed to get request: %w", err) + } + requestEntry.Reviewer = request.Reviewer + requestEntry.RejectReason = request.RejectReason + + var err error + 
discoveryID, discoveryConfig, err = m.createOrUpdateDiscoveryConfig(tx, requestEntry, configExistedBeforeTx) + if err != nil { + return fmt.Errorf("failed to handle discovery config: %w", err) + } + + if err := m.createOrUpdateInferFlowConfig(tx, requestEntry, discoveryID, configExistedBeforeTx); err != nil { + return fmt.Errorf("failed to handle inferflow config: %w", err) + } + + requestEntry.Status = approved + err = m.InferFlowRequestRepo.UpdateTx(tx, requestEntry) + if err != nil { + return errors.New("failed to update inferflow config request in db: " + err.Error()) + } + + return nil + }) + + if err != nil { + return Response{}, fmt.Errorf("failed to review config (DB rolled back): %w", err) + } + + if err := m.createOrUpdateEtcdConfig(requestEntry, discoveryConfig, configExistedBeforeTx); err != nil { + if rollBackErr := m.rollbackApprovedRequest(request, requestEntry, discoveryID, configExistedBeforeTx); rollBackErr != nil { + log.Error().Err(rollBackErr).Msg("Failed to rollback DB changes after ETCD failure") + return Response{}, fmt.Errorf("ETCD sync failed and DB rollback also failed: etcd=%w, rollback=%v", err, rollBackErr) + } + log.Warn().Msgf("Successfully rolled back the request: %d", request.RequestID) + return Response{}, fmt.Errorf("ETCD sync failed: %w", err) + } + + return Response{ + Error: emptyResponse, + Data: Message{Message: "Inferflow Config reviewed successfully."}, + }, nil +} + +func (m *InferFlow) rollbackApprovedRequest(request ReviewRequest, fullTable *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { + return m.InferFlowRequestRepo.Transaction(func(tx *gorm.DB) error { + table := &inferflow_request.Table{ + RequestID: request.RequestID, + Status: pendingApproval, + Reviewer: emptyResponse, + } + if err := m.InferFlowRequestRepo.UpdateTx(tx, table); err != nil { + return fmt.Errorf("failed to revert request status: %w", err) + } + + switch fullTable.RequestType { + case onboardRequestType, 
cloneRequestType, scaleUpRequestType: + if err := m.rollbackCreatedConfigs(tx, fullTable.ConfigID, discoveryID); err != nil { + return err + } + + case editRequestType: + if err := m.rollbackEditRequest(tx, fullTable, discoveryID); err != nil { + return err + } + + case deleteRequestType: + updatedBy := fullTable.UpdatedBy + if updatedBy == "" { + updatedBy = fullTable.CreatedBy + } + if err := m.rollbackDeletedConfigs(tx, fullTable.ConfigID, discoveryID, updatedBy); err != nil { + return err + } + + case promoteRequestType: + if err := m.rollbackPromoteRequest(tx, fullTable, discoveryID, configExistedBeforeTx); err != nil { + return err + } + } + + return nil + }) +} + +func (m *InferFlow) rollbackPromoteRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { + if configExistedBeforeTx { + if err := m.rollbackEditRequest(tx, currentRequest, discoveryID); err != nil { + return err + } + } else { + if err := m.rollbackCreatedConfigs(tx, currentRequest.ConfigID, discoveryID); err != nil { + return err + } + } + return nil +} + +func (m *InferFlow) rollbackEditRequest(tx *gorm.DB, currentRequest *inferflow_request.Table, discoveryID int) error { + approvedRequests, err := m.InferFlowRequestRepo.GetApprovedRequestsByConfigID(currentRequest.ConfigID) + if err != nil { + return fmt.Errorf("failed to retrieve approved requests: %w", err) + } + + var previousRequest *inferflow_request.Table + if len(approvedRequests) > 0 { + if approvedRequests[0].RequestID == currentRequest.RequestID { + if len(approvedRequests) > 1 { + previousRequest = &approvedRequests[1] + } else { + return fmt.Errorf("no other request to revert back to: Requires manual intervention") + } + } else { + previousRequest = &approvedRequests[0] + } + } else { + return fmt.Errorf("no other request to revert back to: Requires manual intervention") + } + + existingConfig, err := m.InferFlowConfigRepo.GetByID(currentRequest.ConfigID) + if err != nil { 
+ return fmt.Errorf("failed to get inferflow config: %w", err) + } + if existingConfig == nil { + return errors.New("inferflow config not found") + } + + restoredConfig := &inferflow_config.Table{ + ConfigID: currentRequest.ConfigID, + DiscoveryID: discoveryID, + ConfigValue: previousRequest.Payload.ConfigValue, + Active: activeTrue, + UpdatedBy: currentRequest.UpdatedBy, + } + + if err := m.InferFlowConfigRepo.UpdateTx(tx, restoredConfig); err != nil { + return fmt.Errorf("failed to restore inferflow config: %w", err) + } + + restoredDiscovery := &discovery_config.DiscoveryConfig{ + ID: discoveryID, + ServiceDeployableID: previousRequest.Payload.ConfigMapping.DeployableID, + AppToken: previousRequest.Payload.ConfigMapping.AppToken, + ServiceConnectionID: previousRequest.Payload.ConfigMapping.ConnectionConfigID, + Active: activeTrue, + UpdatedBy: currentRequest.UpdatedBy, + } + if err := m.DiscoveryConfigRepo.UpdateTx(tx, restoredDiscovery); err != nil { + return fmt.Errorf("failed to restore discovery config: %w", err) + } + + return nil +} + +func (m *InferFlow) rollbackCreatedConfigs(tx *gorm.DB, configID string, discoveryID int) error { + if err := m.InferFlowConfigRepo.DeleteByConfigIDTx(tx, configID); err != nil { + return fmt.Errorf("failed to rollback inferflow config: %w", err) + } + + if err := m.DiscoveryConfigRepo.DeleteByIDTx(tx, discoveryID); err != nil { + return fmt.Errorf("failed to rollback discovery config: %w", err) + } + + return nil +} + +func (m *InferFlow) rollbackDeletedConfigs(tx *gorm.DB, configID string, discoveryID int, updatedby string) error { + latestConfig, err := m.InferFlowConfigRepo.GetLatestInactiveByConfigID(tx, configID) + if err != nil { + return fmt.Errorf("failed to find soft-deleted inferflow config: %w", err) + } + if latestConfig == nil { + return errors.New("no soft-deleted inferflow config found") + } + + if err := m.InferFlowConfigRepo.ReactivateByIDTx(tx, int(latestConfig.ID), updatedby); err != nil { + return 
fmt.Errorf("failed to reactivate inferflow config: %w", err) + } + + if err := m.DiscoveryConfigRepo.ReactivateByIDTx(tx, discoveryID); err != nil { + return fmt.Errorf("failed to reactivate discovery config: %w", err) + } + + return nil +} + +func (m *InferFlow) createOrUpdateDiscoveryConfig(tx *gorm.DB, requestEntry *inferflow_request.Table, configExistedBeforeTx bool) (int, *discovery_config.DiscoveryConfig, error) { + discovery := &discovery_config.DiscoveryConfig{ + ServiceDeployableID: requestEntry.Payload.ConfigMapping.DeployableID, + AppToken: requestEntry.Payload.ConfigMapping.AppToken, + ServiceConnectionID: requestEntry.Payload.ConfigMapping.ConnectionConfigID, + Active: activeTrue, + } + + switch requestEntry.RequestType { + case onboardRequestType, cloneRequestType, scaleUpRequestType: + if requestEntry.UpdatedBy != "" { + discovery.CreatedBy = requestEntry.UpdatedBy + } else { + discovery.CreatedBy = requestEntry.CreatedBy + } + err := m.DiscoveryConfigRepo.CreateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to create discovery config: " + err.Error()) + } + case promoteRequestType: + if !configExistedBeforeTx { + if requestEntry.UpdatedBy != "" { + discovery.CreatedBy = requestEntry.UpdatedBy + } else { + discovery.CreatedBy = requestEntry.CreatedBy + } + err := m.DiscoveryConfigRepo.CreateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to create discovery config: " + err.Error()) + } + } else { + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return 0, nil, errors.New("failed to query inferflow config repo: " + err.Error()) + } + if requestEntry.UpdatedBy != "" { + discovery.UpdatedBy = requestEntry.UpdatedBy + } else { + discovery.UpdatedBy = requestEntry.CreatedBy + } + discovery.ID = int(existingConfig.DiscoveryID) + err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to update discovery config: 
" + err.Error()) + } + } + case editRequestType: + if requestEntry.UpdatedBy != "" { + discovery.UpdatedBy = requestEntry.UpdatedBy + } else { + discovery.UpdatedBy = requestEntry.CreatedBy + } + config, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return 0, nil, errors.New("failed to get inferflow config by id: " + err.Error()) + } + if config == nil { + return 0, nil, errors.New("failed to get inferflow config by id") + } + discovery.ID = int(config.DiscoveryID) + err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to update discovery config: " + err.Error()) + } + case deleteRequestType: + config, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return 0, nil, errors.New("failed to get inferflow config by id: " + err.Error()) + } + if config == nil { + return 0, nil, errors.New("failed to get inferflow config by id") + } + if requestEntry.UpdatedBy != "" { + discovery.UpdatedBy = requestEntry.UpdatedBy + } else { + discovery.UpdatedBy = requestEntry.CreatedBy + } + discovery.ID = int(config.DiscoveryID) + discovery.Active = activeFalse + err = m.DiscoveryConfigRepo.UpdateTx(tx, discovery) + if err != nil { + return 0, nil, errors.New("failed to update discovery config: " + err.Error()) + } + default: + return 0, nil, errors.New("invalid request type") + } + + return discovery.ID, discovery, nil +} + +func (m *InferFlow) createOrUpdateInferFlowConfig(tx *gorm.DB, requestEntry *inferflow_request.Table, discoveryID int, configExistedBeforeTx bool) error { + newConfig := &inferflow_config.Table{ + DiscoveryID: discoveryID, + ConfigID: requestEntry.ConfigID, + Active: activeTrue, + ConfigValue: requestEntry.Payload.ConfigValue, + } + + switch requestEntry.RequestType { + case onboardRequestType, cloneRequestType: + if requestEntry.UpdatedBy != "" { + newConfig.CreatedBy = requestEntry.UpdatedBy + } else { + newConfig.CreatedBy = 
requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.CreateTx(tx, newConfig) + case scaleUpRequestType: + if requestEntry.UpdatedBy != "" { + newConfig.CreatedBy = requestEntry.UpdatedBy + } else { + newConfig.CreatedBy = requestEntry.CreatedBy + } + newConfig.SourceConfigID = requestEntry.Payload.ConfigMapping.SourceConfigID + return m.InferFlowConfigRepo.CreateTx(tx, newConfig) + case promoteRequestType: + if !configExistedBeforeTx { + if requestEntry.UpdatedBy != "" { + newConfig.CreatedBy = requestEntry.UpdatedBy + } else { + newConfig.CreatedBy = requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.CreateTx(tx, newConfig) + } else { + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return errors.New("failed to query inferflow config repo: " + err.Error()) + } + newConfig.ID = existingConfig.ID + if requestEntry.UpdatedBy != "" { + newConfig.UpdatedBy = requestEntry.UpdatedBy + } else { + newConfig.UpdatedBy = requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) + } + case editRequestType: + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return errors.New("failed to get inferflow config by id: " + err.Error()) + } + if existingConfig == nil { + return errors.New("failed to get inferflow config by id") + } + newConfig.ID = existingConfig.ID + if requestEntry.UpdatedBy != "" { + newConfig.UpdatedBy = requestEntry.UpdatedBy + } else { + newConfig.UpdatedBy = requestEntry.CreatedBy + } + return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) + case deleteRequestType: + existingConfig, err := m.InferFlowConfigRepo.GetByID(requestEntry.ConfigID) + if err != nil { + return errors.New("failed to get inferflow config by id: " + err.Error()) + } + if existingConfig == nil { + return errors.New("failed to get inferflow config by id") + } + newConfig.ID = existingConfig.ID + if requestEntry.UpdatedBy != "" { + newConfig.UpdatedBy = 
requestEntry.UpdatedBy + } else { + newConfig.UpdatedBy = requestEntry.CreatedBy + } + newConfig.Active = activeFalse + return m.InferFlowConfigRepo.UpdateTx(tx, newConfig) + default: + return errors.New("invalid request type") + } +} + +func (m *InferFlow) createOrUpdateEtcdConfig(table *inferflow_request.Table, discovery *discovery_config.DiscoveryConfig, configExistedBeforeTx bool) error { + serviceDeployableTable, err := m.ServiceDeployableConfigRepo.GetById(int(discovery.ServiceDeployableID)) + if err != nil { + return errors.New("failed to get service deployable config by id: " + err.Error()) + } + serviceName := strings.ToLower(serviceDeployableTable.Name) + configId := table.ConfigID + inferFlowConfig := AdaptToEtcdInferFlowConfig(table.Payload.ConfigValue) + + switch table.RequestType { + case onboardRequestType, cloneRequestType, scaleUpRequestType: + return m.EtcdConfig.CreateConfig(serviceName, configId, inferFlowConfig) + case promoteRequestType: + if !configExistedBeforeTx { + return m.EtcdConfig.CreateConfig(serviceName, configId, inferFlowConfig) + } + return m.EtcdConfig.UpdateConfig(serviceName, configId, inferFlowConfig) + case editRequestType: + return m.EtcdConfig.UpdateConfig(serviceName, configId, inferFlowConfig) + case deleteRequestType: + return m.EtcdConfig.DeleteConfig(serviceName, configId) + default: + return errors.New("invalid request type") + } +} diff --git a/horizon/internal/inferflow/handler/inferflow_validation.go b/horizon/internal/inferflow/handler/inferflow_validation.go new file mode 100644 index 00000000..a36ca2ac --- /dev/null +++ b/horizon/internal/inferflow/handler/inferflow_validation.go @@ -0,0 +1,194 @@ +package handler + +import ( + "errors" + "fmt" + "strconv" + "strings" + + mainHandler "github.com/Meesho/BharatMLStack/horizon/internal/externalcall" + mapset "github.com/deckarep/golang-set/v2" +) + +func ValidateInferFlowConfig(config InferflowConfig, token string) (Response, error) { + ComponentConfig := 
config.ComponentConfig + if ComponentConfig != nil { + for _, featureComponent := range ComponentConfig.FeatureComponents { + entity := featureComponent.FSRequest.Label + if entity == "dummy" { + continue + } + response, err := mainHandler.Client.ValidateOnlineFeatures(entity, token) + if err != nil { + return Response{ + Error: "failed to validate feature exists: " + err.Error(), + Data: Message{Message: emptyResponse}, + }, err + } + for _, fg := range featureComponent.FSRequest.FeatureGroups { + featureMap := make(map[string]bool) + for _, feature := range fg.Features { + if _, exists := featureMap[feature]; exists { + return Response{ + Error: "feature " + feature + " is duplicated", + Data: Message{Message: emptyResponse}, + }, errors.New("feature " + feature + " is duplicated") + } + featureMap[feature] = true + if !mainHandler.ValidateFeatureExists(fg.Label+COLON_DELIMITER+feature, response) { + return Response{ + Error: "feature \"" + entity + COLON_DELIMITER + fg.Label + COLON_DELIMITER + feature + "\" does not exist", + Data: Message{Message: emptyResponse}, + }, errors.New("feature \"" + entity + COLON_DELIMITER + fg.Label + COLON_DELIMITER + feature + "\" does not exist") + } + } + } + } + + for _, predatorComponent := range config.ComponentConfig.PredatorComponents { + outputMap := make(map[string]bool) + for _, output := range predatorComponent.Outputs { + for _, modelScore := range output.ModelScores { + if _, exists := outputMap[modelScore]; exists { + return Response{ + Error: "model score " + modelScore + " is duplicated for component " + predatorComponent.Component, + Data: Message{Message: emptyResponse}, + }, errors.New("model score " + modelScore + " is duplicated for component " + predatorComponent.Component) + } + outputMap[modelScore] = true + } + } + } + } + + return Response{ + Error: emptyResponse, + Data: Message{Message: "Request validated successfully"}, + }, nil +} + +func (m *InferFlow) ValidateOnboardRequest(request OnboardPayload) 
(Response, error) { + outputs := mapset.NewSet[string]() + deployableConfig, err := m.ServiceDeployableConfigRepo.GetById(request.ConfigMapping.DeployableID) + if err != nil { + return Response{ + Error: "Failed to fetch deployable config for the request", + Data: Message{Message: emptyResponse}, + }, errors.New("failed to fetch deployable config for the request") + } + permissibleEndpoints := m.EtcdConfig.GetConfiguredEndpoints(deployableConfig.Name) + for _, ranker := range request.Rankers { + if len(ranker.EntityID) == 0 { + return Response{ + Error: "Entity ID is not set for model: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("Entity ID is not set for model: " + ranker.ModelName) + } + if !permissibleEndpoints.Contains(ranker.EndPoint) { + errorMsg := fmt.Sprintf( + "invalid endpoint: %s chosen for service deployable: %s for model: %s", + ranker.EndPoint, deployableConfig.Name, ranker.ModelName, + ) + return Response{ + Error: errorMsg, + Data: Message{Message: emptyResponse}, + }, errors.New(errorMsg) + } + for _, output := range ranker.Outputs { + if len(output.ModelScores) != len(output.ModelScoresDims) { + return Response{ + Error: "model scores and model scores dims are not equal for model: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("model scores and model scores dims are not equal for model: " + ranker.ModelName) + } + for _, modelScore := range output.ModelScores { + if outputs.Contains(modelScore) { + return Response{ + Error: "duplicate model scores: " + modelScore + " for model: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("duplicate model scores: " + modelScore + " for model: " + ranker.ModelName) + } + outputs.Add(modelScore) + } + } + } + + for _, reRanker := range request.ReRankers { + if len(reRanker.EntityID) == 0 { + return Response{ + Error: "Entity ID is not set for re ranker: " + reRanker.Score, + Data: Message{Message: emptyResponse}, + 
}, errors.New("Entity ID is not set for re ranker: " + reRanker.Score) + } + for _, value := range reRanker.EqVariables { + parts := strings.Split(value, PIPE_DELIMITER) + if len(parts) != 2 { + return Response{ + Error: "invalid eq variable: " + value, + Data: Message{Message: emptyResponse}, + }, errors.New("invalid eq variable: " + value) + } + if parts[1] == "" { + return Response{ + Error: "invalid eq variable: " + value, + Data: Message{Message: emptyResponse}, + }, errors.New("invalid eq variable: " + value) + } + } + if outputs.Contains(reRanker.Score) { + return Response{ + Error: "duplicate score: " + reRanker.Score + " for reRanker: " + reRanker.Score, + Data: Message{Message: emptyResponse}, + }, errors.New("duplicate score: " + reRanker.Score + " for reRanker: " + reRanker.Score) + } + outputs.Add(reRanker.Score) + } + + // Validate MODEL_FEATURE list + for _, ranker := range request.Rankers { + for _, input := range ranker.Inputs { + for _, feature := range input.Features { + featureParts := strings.Split(feature, PIPE_DELIMITER) + if len(featureParts) != 2 { + return Response{ + Error: "invalid feature: " + feature + " in input features of ranker: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("invalid feature: " + feature + " in input features of ranker: " + ranker.ModelName) + } + if strings.Contains(featureParts[0], MODEL_FEATURE) { + if !outputs.Contains(featureParts[1]) { + return Response{ + Error: "model score " + featureParts[1] + " is not found in other model scores of ranker: " + ranker.ModelName, + Data: Message{Message: emptyResponse}, + }, errors.New("model score " + featureParts[1] + " is not found in other model scores of ranker: " + ranker.ModelName) + } + } + } + } + } + + for _, reRanker := range request.ReRankers { + for _, feature := range reRanker.EqVariables { + featureParts := strings.Split(feature, PIPE_DELIMITER) + if len(featureParts) != 2 { + return Response{ + Error: "invalid feature: " + 
feature, + Data: Message{Message: emptyResponse}, + }, errors.New("invalid feature: " + feature) + } + if strings.Contains(featureParts[0], MODEL_FEATURE) { + if !outputs.Contains(featureParts[1]) { + return Response{ + Error: "model score " + featureParts[1] + " is not found in other model scores of re ranker: " + strconv.Itoa(reRanker.EqID), + Data: Message{Message: emptyResponse}, + }, errors.New("model score " + featureParts[1] + " is not found in other model scores of re ranker: " + strconv.Itoa(reRanker.EqID)) + } + } + } + } + + return Response{ + Error: emptyResponse, + Data: Message{Message: "Request validated successfully"}, + }, nil +} From 1da7cebf65095b5e03642f6680d7c87599aed752 Mon Sep 17 00:00:00 2001 From: Paras Agarwal Date: Wed, 18 Feb 2026 14:09:45 +0530 Subject: [PATCH 23/24] feat: nodepool aware validation --- .pre-commit-config.yaml | 2 +- .../deployable/handler/modelhandler.go | 1 + .../deployable/handler/predatorhandler.go | 1 + .../inferflow/handler/inferflow_test.go | 4 + .../controller/controller_test.go | 5 + .../infrastructure/handler/handler.go | 24 ++++ horizon/internal/predator/handler/predator.go | 3 +- .../predator/handler/predator_test.go | 4 + .../predator/handler/predator_validation.go | 104 ++++++++++++++---- .../sql/servicedeployableconfig/sql.go | 25 +++++ .../sql/servicedeployableconfig/table.go | 6 + horizon/pkg/argocd/hpa.go | 60 ++++++++++ quick-start/db-init/scripts/init-mysql.sh | 7 +- 13 files changed, 220 insertions(+), 26 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c721100c..e1fccdbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,6 @@ repos: - id: trufflehog name: TruffleHog description: Detect secrets in your data. 
- entry: "trufflehog/trufflehog-hook.sh" + entry: "pre-commit-scripts/runner.sh" language: script stages: ["pre-commit", "pre-push"] diff --git a/horizon/internal/deployable/handler/modelhandler.go b/horizon/internal/deployable/handler/modelhandler.go index ac98b222..46ee4fbc 100644 --- a/horizon/internal/deployable/handler/modelhandler.go +++ b/horizon/internal/deployable/handler/modelhandler.go @@ -41,6 +41,7 @@ func (h *InferflowHandler) CreateDeployable(request *DeployableRequest) error { deployableConfig := &servicedeployableconfig.ServiceDeployableConfig{ Name: request.AppName, Service: request.ServiceName, + DeployableType: servicedeployableconfig.DeployableTypeTarget, Host: request.AppName + "." + hostUrlSuffix, Active: true, CreatedBy: request.CreatedBy, diff --git a/horizon/internal/deployable/handler/predatorhandler.go b/horizon/internal/deployable/handler/predatorhandler.go index 19c45f7a..8feae1fe 100644 --- a/horizon/internal/deployable/handler/predatorhandler.go +++ b/horizon/internal/deployable/handler/predatorhandler.go @@ -519,6 +519,7 @@ func (h *Handler) CreateDeployable(request *DeployableRequest, workingEnv string deployableConfig := &servicedeployableconfig.ServiceDeployableConfig{ Name: request.AppName, Service: request.ServiceName, + DeployableType: servicedeployableconfig.DeployableTypeTarget, Host: host, // Use environment-prefixed host for database uniqueness Active: true, CreatedBy: request.CreatedBy, diff --git a/horizon/internal/inferflow/handler/inferflow_test.go b/horizon/internal/inferflow/handler/inferflow_test.go index 8e5fc9f3..f1832491 100644 --- a/horizon/internal/inferflow/handler/inferflow_test.go +++ b/horizon/internal/inferflow/handler/inferflow_test.go @@ -7,6 +7,7 @@ import ( service_deployable_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" ) func TestInferFlow_GetLoggingTTL(t 
*testing.T) { @@ -151,3 +152,6 @@ func (m *mockServiceDeployableRepo) GetByNameAndService(_, _ string) (*service_d func (m *mockServiceDeployableRepo) GetByIds(_ []int) ([]service_deployable_config.ServiceDeployableConfig, error) { return nil, nil } +func (m *mockServiceDeployableRepo) GetTestDeployableIDByNodePool(_ string) (int, error) { + return 0, gorm.ErrRecordNotFound +} diff --git a/horizon/internal/infrastructure/controller/controller_test.go b/horizon/internal/infrastructure/controller/controller_test.go index 00daa41e..22822a81 100644 --- a/horizon/internal/infrastructure/controller/controller_test.go +++ b/horizon/internal/infrastructure/controller/controller_test.go @@ -47,6 +47,11 @@ func (m *MockInfrastructureHandler) RestartDeployment(appName, workingEnv string return args.Error(0) } +func (m *MockInfrastructureHandler) ScaleDeployable(appName, workingEnv string, minReplica, maxReplica int) error { + args := m.Called(appName, workingEnv, minReplica, maxReplica) + return args.Error(0) +} + func (m *MockInfrastructureHandler) UpdateCPUThreshold(appName, threshold, email, workingEnv string) error { args := m.Called(appName, threshold, email, workingEnv) return args.Error(0) diff --git a/horizon/internal/infrastructure/handler/handler.go b/horizon/internal/infrastructure/handler/handler.go index 5a3c962c..c3d86662 100644 --- a/horizon/internal/infrastructure/handler/handler.go +++ b/horizon/internal/infrastructure/handler/handler.go @@ -22,6 +22,7 @@ type InfrastructureHandler interface { GetConfig(serviceName, workingEnv string) Config GetResourceDetail(appName, workingEnv string) (*ResourceDetail, error) RestartDeployment(appName, workingEnv string, isCanary bool) error + ScaleDeployable(appName, workingEnv string, minReplica, maxReplica int) error UpdateCPUThreshold(appName, threshold, email, workingEnv string) error UpdateGPUThreshold(appName, threshold, email, workingEnv string) error UpdateSharedMemory(appName, size, email, workingEnv string) error 
@@ -238,6 +239,29 @@ func (h *infrastructureHandler) RestartDeployment(appName, workingEnv string, is return nil } +func (h *infrastructureHandler) ScaleDeployable(appName, workingEnv string, minReplica, maxReplica int) error { + if appName == "" { + return fmt.Errorf("appName is required") + } + if workingEnv == "" { + return fmt.Errorf("workingEnv is required") + } + argocdAppName := getArgoCDApplicationName(appName, workingEnv) + log.Info(). + Str("appName", appName). + Str("argocdAppName", argocdAppName). + Str("workingEnv", workingEnv). + Int("minReplica", minReplica). + Int("maxReplica", maxReplica). + Msg("ScaleDeployable: scaling deployable") + err := argocd.SetDeployableReplicas(argocdAppName, workingEnv, minReplica, maxReplica) + if err != nil { + log.Error().Err(err).Str("appName", appName).Msg("Failed to scale deployable") + return fmt.Errorf("failed to scale deployable: %w", err) + } + return nil +} + func (h *infrastructureHandler) UpdateCPUThreshold(appName, threshold, email, workingEnv string) error { log.Info().Str("appName", appName).Str("threshold", threshold).Str("workingEnv", workingEnv).Str("email", email).Msg("Updating CPU threshold") diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index be5f0b2d..1fc525c0 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -676,8 +676,7 @@ func (p *Predator) ValidateRequest(groupId string) (string, int) { lock, err := p.validationLockRepo.AcquireLock(lockKey, 30*time.Minute) if err != nil { log.Warn().Err(err).Msgf("Validation request for group ID %s rejected - failed to acquire lock for deployable %d", groupId, testDeployableID) - return fmt.Sprintf("Request Validation Failed. Another validation is already in progress for %s deployable. 
Please try again later.", - map[int]string{pred.TestDeployableID: "CPU", pred.TestGpuDeployableID: "GPU"}[testDeployableID]), http.StatusConflict + return fmt.Sprintf("Request Validation Failed. Another validation is already in progress for deployable %d. Please try again later.", testDeployableID), http.StatusConflict } log.Info().Msgf("Starting validation for group ID: %s on deployable %d (lock acquired by %s)", groupId, testDeployableID, lock.LockedBy) diff --git a/horizon/internal/predator/handler/predator_test.go b/horizon/internal/predator/handler/predator_test.go index 9ae81749..ae7876d2 100644 --- a/horizon/internal/predator/handler/predator_test.go +++ b/horizon/internal/predator/handler/predator_test.go @@ -8,6 +8,7 @@ import ( "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" ) func TestPredator_ValidateRequest_InvalidGroupIDFormat(t *testing.T) { @@ -188,3 +189,6 @@ func (m *predatorMockServiceDeployableRepo) GetByNameAndService(_, _ string) (*s func (m *predatorMockServiceDeployableRepo) GetByIds(_ []int) ([]servicedeployableconfig.ServiceDeployableConfig, error) { return nil, nil } +func (m *predatorMockServiceDeployableRepo) GetTestDeployableIDByNodePool(_ string) (int, error) { + return 0, gorm.ErrRecordNotFound +} diff --git a/horizon/internal/predator/handler/predator_validation.go b/horizon/internal/predator/handler/predator_validation.go index 15021baf..e809d71f 100644 --- a/horizon/internal/predator/handler/predator_validation.go +++ b/horizon/internal/predator/handler/predator_validation.go @@ -13,6 +13,7 @@ import ( "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorrequest" "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/validationjob" "github.com/rs/zerolog/log" + "gorm.io/gorm" ) func (p *Predator) ValidateDeleteRequest(predatorConfigList 
[]predatorconfig.PredatorConfig, ids []int) (bool, error) { @@ -142,36 +143,69 @@ func (p *Predator) releaseLockWithError(lockID uint, groupID, errorMsg string) { log.Error().Msgf("Validation failed for group ID %s: %s", groupID, errorMsg) } -// getTestDeployableID determines the appropriate test deployable ID based on machine type +// getTestDeployableID resolves the test deployable ID: tries DB lookup by node pool when available; +// if not found or no node pool, falls back to env-based ID by machine type (CPU: TEST_DEPLOYABLE_ID, GPU: TEST_GPU_DEPLOYABLE_ID). func (p *Predator) getTestDeployableID(payload *Payload) (int, error) { - // Get the target deployable ID from the request + if payload == nil { + return 0, fmt.Errorf("payload is required") + } + if payload.ConfigMapping.ServiceDeployableID == 0 { + return 0, fmt.Errorf("service_deployable_id is required in config_mapping") + } + targetDeployableID := int(payload.ConfigMapping.ServiceDeployableID) - // Fetch the service deployable config to check machine type serviceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) if err != nil { - return 0, fmt.Errorf("failed to fetch service deployable config: %w", err) + return 0, fmt.Errorf("failed to fetch service deployable %d: %w", targetDeployableID, err) + } + + if len(serviceDeployable.Config) == 0 { + return 0, fmt.Errorf("target deployable %d has no config; cannot determine machine type", targetDeployableID) } - // Parse the deployable config to extract machine type var deployableConfig PredatorDeployableConfig if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { - return 0, fmt.Errorf("failed to parse service deployable config: %w", err) - } - - // Select test deployable ID based on machine type - switch strings.ToUpper(deployableConfig.MachineType) { - case "CPU": - log.Info().Msgf("Using CPU test deployable ID: %d", pred.TestDeployableID) - return pred.TestDeployableID, nil - case "GPU": - 
log.Info().Msgf("Using GPU test deployable ID: %d", pred.TestGpuDeployableID) - return pred.TestGpuDeployableID, nil - default: - // Default to CPU if machine type is not specified or unknown - log.Warn().Msgf("Unknown machine type '%s', defaulting to CPU test deployable ID: %d", - deployableConfig.MachineType, pred.TestDeployableID) - return pred.TestDeployableID, nil + return 0, fmt.Errorf("failed to parse service deployable config %d: %w", targetDeployableID, err) + } + + return p.resolveTestDeployableID(deployableConfig) +} + +// resolveTestDeployableID resolves test deployable ID from deployableConfig: node-pool lookup first, then machine-type fallback. +func (p *Predator) resolveTestDeployableID(deployableConfig PredatorDeployableConfig) (int, error) { + var testID int + nodePool := strings.TrimSpace(deployableConfig.NodeSelectorValue) + if nodePool != "" { + id, lookupErr := p.ServiceDeployableRepo.GetTestDeployableIDByNodePool(nodePool) + if lookupErr == nil { + testID = id + log.Info().Msgf("Using test deployable ID %d for node pool %s", testID, nodePool) + } else if errors.Is(lookupErr, gorm.ErrRecordNotFound) { + log.Info().Str("nodePool", nodePool).Msgf("no test deployable for node pool %q (deployable_type=test, config.nodeSelectorValue=%q), using machine-type fallback", nodePool, nodePool) + } else { + log.Info().Err(lookupErr).Str("nodePool", nodePool).Msg("Test deployable lookup by node pool failed, using machine-type fallback") + } + } + + if testID == 0 { + switch strings.ToUpper(deployableConfig.MachineType) { + case "CPU": + testID = pred.TestDeployableID + log.Info().Msgf("Using CPU fallback test deployable ID: %d", testID) + case "GPU": + testID = pred.TestGpuDeployableID + log.Info().Msgf("Using GPU fallback test deployable ID: %d", testID) + default: + testID = pred.TestDeployableID + log.Warn().Msgf("Unknown machine type %q, defaulting to CPU fallback test deployable ID: %d", + deployableConfig.MachineType, testID) + } } + + if testID <= 0 { 
+ return 0, fmt.Errorf("invalid test deployable ID (not configured or not found); check TEST_DEPLOYABLE_ID (CPU), TEST_GPU_DEPLOYABLE_ID or deployable_type=test for node pool (GPU)") + } + return testID, nil } // getServiceNameFromDeployable extracts service name from deployable configuration @@ -183,6 +217,21 @@ func (p *Predator) getServiceNameFromDeployable(deployableID int) (string, error return serviceDeployable.Name, nil } +// scaleTestDeployable sets min/max replicas for the test deployable (by ID) via the infrastructure handler. +func (p *Predator) scaleTestDeployable(deployableID int, minReplica, maxReplica int) error { + if deployableID <= 0 { + return fmt.Errorf("invalid deployable ID for scaling: %d", deployableID) + } + sd, err := p.ServiceDeployableRepo.GetById(deployableID) + if err != nil { + return fmt.Errorf("failed to fetch test deployable: %w", err) + } + if strings.TrimSpace(sd.Name) == "" { + return fmt.Errorf("test deployable %d has no name; cannot scale", deployableID) + } + return p.infrastructureHandler.ScaleDeployable(sd.Name, p.workingEnv, minReplica, maxReplica) +} + // performAsyncValidation performs the actual validation process asynchronously func (p *Predator) performAsyncValidation(job *validationjob.Table, requests []predatorrequest.PredatorRequest, payload *Payload, testDeployableID int) { defer func() { @@ -192,9 +241,22 @@ func (p *Predator) performAsyncValidation(job *validationjob.Table, requests []p } log.Info().Msgf("Released validation lock for job %d", job.ID) }() + defer func() { + // Always scale test deployable back to 0 when validation finishes (success or failure) + if scaleErr := p.scaleTestDeployable(testDeployableID, 0, 0); scaleErr != nil { + log.Error().Err(scaleErr).Msgf("Failed to scale down test deployable %d after validation", testDeployableID) + } + }() log.Info().Msgf("Starting async validation for job %d, group %s", job.ID, job.GroupID) + // Scale up test deployable from 0 to 1 so validation can run + if 
err := p.scaleTestDeployable(testDeployableID, 1, 1); err != nil { + log.Error().Err(err).Msg("Failed to scale up test deployable") + p.failValidationJob(job.ID, "Failed to scale up test deployable: "+err.Error()) + return + } + // Step 1: Clear temporary deployable if err := p.clearTemporaryDeployable(testDeployableID); err != nil { log.Error().Err(err).Msg("Failed to clear temporary deployable") diff --git a/horizon/internal/repositories/sql/servicedeployableconfig/sql.go b/horizon/internal/repositories/sql/servicedeployableconfig/sql.go index 09744fe9..bab04688 100644 --- a/horizon/internal/repositories/sql/servicedeployableconfig/sql.go +++ b/horizon/internal/repositories/sql/servicedeployableconfig/sql.go @@ -19,6 +19,8 @@ type ServiceDeployableRepository interface { GetByDeployableHealth(health string) ([]ServiceDeployableConfig, error) GetByNameAndService(name, service string) (*ServiceDeployableConfig, error) GetByIds(ids []int) ([]ServiceDeployableConfig, error) + // GetTestDeployableIDByNodePool returns the ID of a test deployable whose config.nodeSelectorValue matches the node pool. + GetTestDeployableIDByNodePool(nodePool string) (int, error) } type serviceDeployableRepo struct { @@ -108,3 +110,26 @@ func (r *serviceDeployableRepo) GetByIds(ids []int) ([]ServiceDeployableConfig, err := r.db.Where("id IN ?", ids).Find(&deployables).Error return deployables, err } + +func (r *serviceDeployableRepo) GetTestDeployableIDByNodePool(nodePool string) (int, error) { + var id int + + tx := r.db. + Model(&ServiceDeployableConfig{}). + Select("id"). + Where("deployable_type = ?", DeployableTypeTest). + Where("JSON_UNQUOTE(JSON_EXTRACT(config, '$.nodeSelectorValue')) = ?", nodePool). + Limit(1). 
+ Scan(&id) + + if tx.Error != nil { + return 0, tx.Error + } + + if tx.RowsAffected == 0 || id == 0 { + return 0, gorm.ErrRecordNotFound + } + + + return id, nil +} diff --git a/horizon/internal/repositories/sql/servicedeployableconfig/table.go b/horizon/internal/repositories/sql/servicedeployableconfig/table.go index a9132d6e..1d76f3f2 100644 --- a/horizon/internal/repositories/sql/servicedeployableconfig/table.go +++ b/horizon/internal/repositories/sql/servicedeployableconfig/table.go @@ -10,11 +10,17 @@ import ( const ServiceDeployableTableName = "service_deployable_config" +const ( + DeployableTypeTest = "test" + DeployableTypeTarget = "target" +) + type ServiceDeployableConfig struct { ID int `gorm:"primaryKey,autoIncrement"` Name string Host string `gorm:"unique;not null"` Service string `gorm:"type:ENUM('inferflow', 'predator', 'numerix')"` + DeployableType string `gorm:"column:deployable_type;type:ENUM('test', 'target');default:'target';not null"` Active bool `gorm:"default:false"` // Port int `gorm:"default:8080"` // Port field for the deployable CreatedBy string diff --git a/horizon/pkg/argocd/hpa.go b/horizon/pkg/argocd/hpa.go index 61c5231f..f72d7283 100644 --- a/horizon/pkg/argocd/hpa.go +++ b/horizon/pkg/argocd/hpa.go @@ -1,6 +1,8 @@ package argocd import ( + "fmt" + "github.com/Meesho/BharatMLStack/horizon/pkg/kubernetes" "github.com/rs/zerolog/log" ) @@ -60,3 +62,61 @@ func GetScaledObjectProperties(applicationName string, workingEnv string) (kuber } return policy, nil } + +// SetDeployableReplicas sets min and max replicas for a deployable by patching HPA or KEDA ScaledObject. +// applicationName is the ArgoCD application name (e.g. from GetArgocdApplicationNameFromEnv). 
+func SetDeployableReplicas(applicationName, workingEnv string, minReplica, maxReplica int) error { + if minReplica < 0 || maxReplica < 0 { + return fmt.Errorf("replicas must be non-negative: minReplica=%d, maxReplica=%d", minReplica, maxReplica) + } + if minReplica > maxReplica { + return fmt.Errorf("minReplica (%d) must not exceed maxReplica (%d)", minReplica, maxReplica) + } + + log.Info(). + Str("applicationName", applicationName). + Str("workingEnv", workingEnv). + Int("minReplica", minReplica). + Int("maxReplica", maxReplica). + Msg("SetDeployableReplicas: patching scaling resource") + + isCanary := IsCanary(applicationName, workingEnv) + + // Try HPA first + var hpaErr error + hpaResource, err := GetArgoCDResource("HorizontalPodAutoscaler", applicationName, isCanary) + if err == nil { + _, patchErr := hpaResource.PatchArgoCDResource(map[string]interface{}{ + "spec": map[string]interface{}{ + "minReplicas": minReplica, + "maxReplicas": maxReplica, + }, + }, workingEnv) + if patchErr == nil { + log.Info().Str("applicationName", applicationName).Msg("SetDeployableReplicas: successfully patched HPA") + return nil + } + log.Error().Err(patchErr).Str("applicationName", applicationName).Msg("SetDeployableReplicas: failed to patch HPA, trying ScaledObject") + hpaErr = patchErr + } else { + hpaErr = err + } + + // HPA not found or patch failed; try KEDA ScaledObject + scaledObjResource, scaledErr := GetArgoCDResource("ScaledObject", applicationName, isCanary) + if scaledErr != nil { + return fmt.Errorf("neither HPA nor ScaledObject found for application %s (HPA: %v; ScaledObject: %v)", applicationName, hpaErr, scaledErr) + } + _, patchErr := scaledObjResource.PatchArgoCDResource(map[string]interface{}{ + "spec": map[string]interface{}{ + "minReplicaCount": minReplica, + "maxReplicaCount": maxReplica, + }, + }, workingEnv) + if patchErr != nil { + log.Error().Err(patchErr).Str("applicationName", applicationName).Msg("SetDeployableReplicas: failed to patch 
ScaledObject") + return fmt.Errorf("neither HPA nor ScaledObject succeeded for application %s (HPA: %v; ScaledObject patch: %w)", applicationName, hpaErr, patchErr) + } + log.Info().Str("applicationName", applicationName).Msg("SetDeployableReplicas: successfully patched ScaledObject") + return nil +} diff --git a/quick-start/db-init/scripts/init-mysql.sh b/quick-start/db-init/scripts/init-mysql.sh index e20e335a..8e0c19f9 100644 --- a/quick-start/db-init/scripts/init-mysql.sh +++ b/quick-start/db-init/scripts/init-mysql.sh @@ -391,6 +391,7 @@ mysql -hmysql -uroot -proot --skip-ssl -e " work_flow_status enum('WORKFLOW_COMPLETED','WORKFLOW_NOT_FOUND','WORKFLOW_RUNNING','WORKFLOW_FAILED','WORKFLOW_NOT_STARTED'), override_testing TINYINT(1) DEFAULT 0, deployable_tag varchar(255) NULL, + deployable_type enum('test', 'target') NOT NULL DEFAULT 'target', PRIMARY KEY (id), UNIQUE KEY host (host) ); @@ -459,7 +460,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " # VALUES (1, 1, NOW(), NOW()); INSERT IGNORE INTO service_deployable_config ( - id, name, host, service, active, created_by, updated_by, + id, name, host, service, deployable_type, active, created_by, updated_by, created_at, updated_at, config, monitoring_url, deployable_running_status, deployable_work_flow_id, deployment_run_id, deployable_health, work_flow_status ) VALUES ( @@ -467,6 +468,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " 'numerix', 'numerix:8083', 'numerix', + 'target', 1, 'admin@admin.com', NULL, @@ -482,7 +484,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " ); INSERT IGNORE INTO service_deployable_config ( - id, name, host, service, active, created_by, updated_by, + id, name, host, service, deployable_type, active, created_by, updated_by, created_at, updated_at, config, monitoring_url, deployable_running_status, deployable_work_flow_id, deployment_run_id, deployable_health, work_flow_status ) VALUES ( @@ -490,6 +492,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " 
'inferflow', 'inferflow:8085', 'inferflow', + 'target', 1, 'admin@admin.com', NULL, From 080fd4f1d580955a2965135faef4489163955bde Mon Sep 17 00:00:00 2001 From: Paras Agarwal Date: Wed, 18 Feb 2026 17:04:03 +0530 Subject: [PATCH 24/24] Merge Conflicts corrected --- inferflow/handlers/inferflow/inferflow.go | 3 --- quick-start/db-init/scripts/init-mysql.sh | 6 ------ 2 files changed, 9 deletions(-) diff --git a/inferflow/handlers/inferflow/inferflow.go b/inferflow/handlers/inferflow/inferflow.go index fb0837b4..bbd379e0 100644 --- a/inferflow/handlers/inferflow/inferflow.go +++ b/inferflow/handlers/inferflow/inferflow.go @@ -72,13 +72,10 @@ func InitInferflowHandler(configs *configs.AppConfigs) { }, }, } -<<<<<<< HEAD -======= // Initialize Kafka writers for inference logging kafkaLogger.InitKafkaLogger(configs) ->>>>>>> origin/develop logger.Info("Inferflow handler initialized") } diff --git a/quick-start/db-init/scripts/init-mysql.sh b/quick-start/db-init/scripts/init-mysql.sh index 8a615861..7c4d4f2a 100644 --- a/quick-start/db-init/scripts/init-mysql.sh +++ b/quick-start/db-init/scripts/init-mysql.sh @@ -520,15 +520,9 @@ mysql -hmysql -uroot -proot --skip-ssl -e " deployment_run_id varchar(255), deployable_health enum('DEPLOYMENT_REASON_ARGO_APP_HEALTH_DEGRADED', 'DEPLOYMENT_REASON_ARGO_APP_HEALTHY'), work_flow_status enum('WORKFLOW_COMPLETED','WORKFLOW_NOT_FOUND','WORKFLOW_RUNNING','WORKFLOW_FAILED','WORKFLOW_NOT_STARTED'), -<<<<<<< HEAD - override_testing TINYINT(1) DEFAULT 0, - deployable_tag varchar(255) NULL, - deployable_type enum('test', 'target') NOT NULL DEFAULT 'target', -======= override_testing tinyint(1) DEFAULT 0, deployable_tag varchar(255), bulk_delete_enabled tinyint(1) NOT NULL DEFAULT 0, ->>>>>>> origin/develop PRIMARY KEY (id), UNIQUE KEY host (host) );