diff --git a/horizon/internal/deployable/handler/modelhandler.go b/horizon/internal/deployable/handler/modelhandler.go index ac98b222..46ee4fbc 100644 --- a/horizon/internal/deployable/handler/modelhandler.go +++ b/horizon/internal/deployable/handler/modelhandler.go @@ -41,6 +41,7 @@ func (h *InferflowHandler) CreateDeployable(request *DeployableRequest) error { deployableConfig := &servicedeployableconfig.ServiceDeployableConfig{ Name: request.AppName, Service: request.ServiceName, + DeployableType: servicedeployableconfig.DeployableTypeTarget, Host: request.AppName + "." + hostUrlSuffix, Active: true, CreatedBy: request.CreatedBy, diff --git a/horizon/internal/deployable/handler/predatorhandler.go b/horizon/internal/deployable/handler/predatorhandler.go index 39e1172d..956d8f52 100644 --- a/horizon/internal/deployable/handler/predatorhandler.go +++ b/horizon/internal/deployable/handler/predatorhandler.go @@ -519,6 +519,7 @@ func (h *Handler) CreateDeployable(request *DeployableRequest, workingEnv string deployableConfig := &servicedeployableconfig.ServiceDeployableConfig{ Name: request.AppName, Service: request.ServiceName, + DeployableType: servicedeployableconfig.DeployableTypeTarget, Host: host, // Use environment-prefixed host for database uniqueness Active: true, CreatedBy: request.CreatedBy, diff --git a/horizon/internal/inferflow/handler/inferflow_test.go b/horizon/internal/inferflow/handler/inferflow_test.go index 8e5fc9f3..f1832491 100644 --- a/horizon/internal/inferflow/handler/inferflow_test.go +++ b/horizon/internal/inferflow/handler/inferflow_test.go @@ -7,6 +7,7 @@ import ( service_deployable_config "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" ) func TestInferFlow_GetLoggingTTL(t *testing.T) { @@ -151,3 +152,6 @@ func (m *mockServiceDeployableRepo) GetByNameAndService(_, _ string) (*service_d func (m *mockServiceDeployableRepo) GetByIds(_ []int) ([]service_deployable_config.ServiceDeployableConfig, error) { return nil, nil } +func (m *mockServiceDeployableRepo) GetTestDeployableIDByNodePool(_ string) (int, error) { + return 0, gorm.ErrRecordNotFound +} diff --git a/horizon/internal/infrastructure/controller/controller_test.go b/horizon/internal/infrastructure/controller/controller_test.go index 00daa41e..22822a81 100644 --- a/horizon/internal/infrastructure/controller/controller_test.go +++ b/horizon/internal/infrastructure/controller/controller_test.go @@ -47,6 +47,11 @@ func (m *MockInfrastructureHandler) RestartDeployment(appName, workingEnv string return args.Error(0) } +func (m *MockInfrastructureHandler) ScaleDeployable(appName, workingEnv string, minReplica, maxReplica int) error { + args := m.Called(appName, workingEnv, minReplica, maxReplica) + return args.Error(0) +} + func (m *MockInfrastructureHandler) UpdateCPUThreshold(appName, threshold, email, workingEnv string) error { args := m.Called(appName, threshold, email, workingEnv) return args.Error(0) diff --git a/horizon/internal/infrastructure/handler/handler.go b/horizon/internal/infrastructure/handler/handler.go index 5a3c962c..c3d86662 100644 --- a/horizon/internal/infrastructure/handler/handler.go +++ b/horizon/internal/infrastructure/handler/handler.go @@ -22,6 +22,7 @@ type InfrastructureHandler interface { GetConfig(serviceName, workingEnv string) Config GetResourceDetail(appName, workingEnv string) (*ResourceDetail, error) RestartDeployment(appName, workingEnv string, isCanary bool) error + ScaleDeployable(appName, workingEnv string, minReplica, maxReplica int) error UpdateCPUThreshold(appName, threshold, email, workingEnv string) error UpdateGPUThreshold(appName, threshold, email, workingEnv string) error UpdateSharedMemory(appName, size, email, workingEnv string) error @@ -238,6 +239,29 @@ func (h *infrastructureHandler) RestartDeployment(appName, workingEnv string, is return nil } +func (h *infrastructureHandler) ScaleDeployable(appName, workingEnv string, minReplica, maxReplica int) error { + if appName == "" { + return fmt.Errorf("appName is required") + } + if workingEnv == "" { + return fmt.Errorf("workingEnv is required") + } + argocdAppName := getArgoCDApplicationName(appName, workingEnv) + log.Info(). + Str("appName", appName). + Str("argocdAppName", argocdAppName). + Str("workingEnv", workingEnv). + Int("minReplica", minReplica). + Int("maxReplica", maxReplica). + Msg("ScaleDeployable: scaling deployable") + err := argocd.SetDeployableReplicas(argocdAppName, workingEnv, minReplica, maxReplica) + if err != nil { + log.Error().Err(err).Str("appName", appName).Msg("Failed to scale deployable") + return fmt.Errorf("failed to scale deployable: %w", err) + } + return nil +} + func (h *infrastructureHandler) UpdateCPUThreshold(appName, threshold, email, workingEnv string) error { log.Info().Str("appName", appName).Str("threshold", threshold).Str("workingEnv", workingEnv).Str("email", email).Msg("Updating CPU threshold") diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index a2459cdf..8235fee0 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -676,8 +676,7 @@ func (p *Predator) ValidateRequest(groupId string) (string, int) { lock, err := p.validationLockRepo.AcquireLock(lockKey, 30*time.Minute) if err != nil { log.Warn().Err(err).Msgf("Validation request for group ID %s rejected - failed to acquire lock for deployable %d", groupId, testDeployableID) - return fmt.Sprintf("Request Validation Failed. Another validation is already in progress for %s deployable. Please try again later.", - map[int]string{pred.TestDeployableID: "CPU", pred.TestGpuDeployableID: "GPU"}[testDeployableID]), http.StatusConflict + return fmt.Sprintf("Request Validation Failed. Another validation is already in progress for deployable %d. Please try again later.", testDeployableID), http.StatusConflict } log.Info().Msgf("Starting validation for group ID: %s on deployable %d (lock acquired by %s)", groupId, testDeployableID, lock.LockedBy) diff --git a/horizon/internal/predator/handler/predator_test.go b/horizon/internal/predator/handler/predator_test.go index 9ae81749..ae7876d2 100644 --- a/horizon/internal/predator/handler/predator_test.go +++ b/horizon/internal/predator/handler/predator_test.go @@ -8,6 +8,7 @@ import ( "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/servicedeployableconfig" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" ) func TestPredator_ValidateRequest_InvalidGroupIDFormat(t *testing.T) { @@ -188,3 +189,6 @@ func (m *predatorMockServiceDeployableRepo) GetByNameAndService(_, _ string) (*s func (m *predatorMockServiceDeployableRepo) GetByIds(_ []int) ([]servicedeployableconfig.ServiceDeployableConfig, error) { return nil, nil } +func (m *predatorMockServiceDeployableRepo) GetTestDeployableIDByNodePool(_ string) (int, error) { + return 0, gorm.ErrRecordNotFound +} diff --git a/horizon/internal/predator/handler/predator_validation.go b/horizon/internal/predator/handler/predator_validation.go index 15021baf..c57f4bba 100644 --- a/horizon/internal/predator/handler/predator_validation.go +++ b/horizon/internal/predator/handler/predator_validation.go @@ -13,6 +13,7 @@ import ( "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/predatorrequest" "github.com/Meesho/BharatMLStack/horizon/internal/repositories/sql/validationjob" "github.com/rs/zerolog/log" + "gorm.io/gorm" ) func (p *Predator) ValidateDeleteRequest(predatorConfigList []predatorconfig.PredatorConfig, ids []int) (bool, error) { @@ -142,36 +143,69 @@ func (p *Predator) releaseLockWithError(lockID uint, groupID, errorMsg string) { log.Error().Msgf("Validation failed for group ID %s: %s", groupID, errorMsg) } -// getTestDeployableID determines the appropriate test deployable ID based on machine type +// getTestDeployableID resolves the test deployable ID: tries DB lookup by node pool when available; +// if not found or no node pool, falls back to env-based ID by machine type (CPU: TEST_DEPLOYABLE_ID, GPU: TEST_GPU_DEPLOYABLE_ID). func (p *Predator) getTestDeployableID(payload *Payload) (int, error) { - // Get the target deployable ID from the request + if payload == nil { + return 0, fmt.Errorf("payload is required") + } + if payload.ConfigMapping.ServiceDeployableID == 0 { + return 0, fmt.Errorf("service_deployable_id is required in config_mapping") + } + targetDeployableID := int(payload.ConfigMapping.ServiceDeployableID) - // Fetch the service deployable config to check machine type serviceDeployable, err := p.ServiceDeployableRepo.GetById(targetDeployableID) if err != nil { - return 0, fmt.Errorf("failed to fetch service deployable config: %w", err) + return 0, fmt.Errorf("failed to fetch service deployable %d: %w", targetDeployableID, err) + } + + if len(serviceDeployable.Config) == 0 { + return 0, fmt.Errorf("target deployable %d has no config; cannot determine machine type", targetDeployableID) } - // Parse the deployable config to extract machine type var deployableConfig PredatorDeployableConfig if err := json.Unmarshal(serviceDeployable.Config, &deployableConfig); err != nil { - return 0, fmt.Errorf("failed to parse service deployable config: %w", err) - } - - // Select test deployable ID based on machine type - switch strings.ToUpper(deployableConfig.MachineType) { - case "CPU": - log.Info().Msgf("Using CPU test deployable ID: %d", pred.TestDeployableID) - return pred.TestDeployableID, nil - case "GPU": - log.Info().Msgf("Using GPU test deployable ID: %d", pred.TestGpuDeployableID) - return pred.TestGpuDeployableID, nil - default: - // Default to CPU if machine type is not specified or unknown - log.Warn().Msgf("Unknown machine type '%s', defaulting to CPU test deployable ID: %d", - deployableConfig.MachineType, pred.TestDeployableID) - return pred.TestDeployableID, nil + return 0, fmt.Errorf("failed to parse service deployable config %d: %w", targetDeployableID, err) + } + + return p.resolveTestDeployableID(deployableConfig) +} + +// resolveTestDeployableID resolves test deployable ID from deployableConfig: node-pool lookup first, then machine-type fallback. +func (p *Predator) resolveTestDeployableID(deployableConfig PredatorDeployableConfig) (int, error) { + var testID int + nodePool := strings.TrimSpace(deployableConfig.NodeSelectorValue) + if nodePool != "" { + id, lookupErr := p.ServiceDeployableRepo.GetTestDeployableIDByNodePool(nodePool) + if lookupErr == nil { + testID = id + log.Info().Msgf("Using test deployable ID %d for node pool %s", testID, nodePool) + } else if errors.Is(lookupErr, gorm.ErrRecordNotFound) { + log.Info().Str("nodePool", nodePool).Msgf("no test deployable for node pool %q (deployable_type=test, config.nodeSelectorValue=%q), using machine-type fallback", nodePool, nodePool) + } else { + log.Info().Err(lookupErr).Str("nodePool", nodePool).Msg("Test deployable lookup by node pool failed, using machine-type fallback") + } + } + + if testID == 0 { + switch strings.ToUpper(deployableConfig.MachineType) { + case "CPU": + testID = pred.TestDeployableID + log.Info().Msgf("Using CPU fallback test deployable ID: %d", testID) + case "GPU": + testID = pred.TestGpuDeployableID + log.Info().Msgf("Using GPU fallback test deployable ID: %d", testID) + default: + testID = pred.TestDeployableID + log.Warn().Msgf("Unknown machine type %q, defaulting to CPU fallback test deployable ID: %d", + deployableConfig.MachineType, testID) + } + } + + if testID <= 0 { + return 0, fmt.Errorf("invalid test deployable ID (not configured or not found); check TEST_DEPLOYABLE_ID (CPU), TEST_GPU_DEPLOYABLE_ID or deployable_type=test for node pool (GPU)") } + return testID, nil } // getServiceNameFromDeployable extracts service name from deployable configuration diff --git a/horizon/internal/repositories/sql/servicedeployableconfig/sql.go b/horizon/internal/repositories/sql/servicedeployableconfig/sql.go index 09744fe9..bab04688 100644 --- a/horizon/internal/repositories/sql/servicedeployableconfig/sql.go +++ b/horizon/internal/repositories/sql/servicedeployableconfig/sql.go @@ -19,6 +19,8 @@ type ServiceDeployableRepository interface { GetByDeployableHealth(health string) ([]ServiceDeployableConfig, error) GetByNameAndService(name, service string) (*ServiceDeployableConfig, error) GetByIds(ids []int) ([]ServiceDeployableConfig, error) + // GetTestDeployableIDByNodePool returns the ID of a test deployable whose config.nodeSelectorValue matches the node pool. + GetTestDeployableIDByNodePool(nodePool string) (int, error) } type serviceDeployableRepo struct { @@ -108,3 +110,26 @@ func (r *serviceDeployableRepo) GetByIds(ids []int) ([]ServiceDeployableConfig, err := r.db.Where("id IN ?", ids).Find(&deployables).Error return deployables, err } + +func (r *serviceDeployableRepo) GetTestDeployableIDByNodePool(nodePool string) (int, error) { + var id int + + tx := r.db. + Model(&ServiceDeployableConfig{}). + Select("id"). + Where("deployable_type = ?", DeployableTypeTest). + Where("JSON_UNQUOTE(JSON_EXTRACT(config, '$.nodeSelectorValue')) = ?", nodePool). + Limit(1). + Scan(&id) + + if tx.Error != nil { + return 0, tx.Error + } + + if tx.RowsAffected == 0 || id == 0 { + return 0, gorm.ErrRecordNotFound + } + + + return id, nil +} diff --git a/horizon/internal/repositories/sql/servicedeployableconfig/table.go b/horizon/internal/repositories/sql/servicedeployableconfig/table.go index a9132d6e..1d76f3f2 100644 --- a/horizon/internal/repositories/sql/servicedeployableconfig/table.go +++ b/horizon/internal/repositories/sql/servicedeployableconfig/table.go @@ -10,11 +10,17 @@ import ( const ServiceDeployableTableName = "service_deployable_config" +const ( + DeployableTypeTest = "test" + DeployableTypeTarget = "target" +) + type ServiceDeployableConfig struct { ID int `gorm:"primaryKey,autoIncrement"` Name string Host string `gorm:"unique;not null"` Service string `gorm:"type:ENUM('inferflow', 'predator', 'numerix')"` + DeployableType string `gorm:"column:deployable_type;type:ENUM('test', 'target');default:'target';not null"` Active bool `gorm:"default:false"` // Port int `gorm:"default:8080"` // Port field for the deployable CreatedBy string diff --git a/horizon/pkg/argocd/hpa.go b/horizon/pkg/argocd/hpa.go index 61c5231f..f72d7283 100644 --- a/horizon/pkg/argocd/hpa.go +++ b/horizon/pkg/argocd/hpa.go @@ -1,6 +1,8 @@ package argocd import ( + "fmt" + "github.com/Meesho/BharatMLStack/horizon/pkg/kubernetes" "github.com/rs/zerolog/log" ) @@ -60,3 +62,61 @@ func GetScaledObjectProperties(applicationName string, workingEnv string) (kuber } return policy, nil } + +// SetDeployableReplicas sets min and max replicas for a deployable by patching HPA or KEDA ScaledObject. +// applicationName is the ArgoCD application name (e.g. from GetArgocdApplicationNameFromEnv). +func SetDeployableReplicas(applicationName, workingEnv string, minReplica, maxReplica int) error { + if minReplica < 0 || maxReplica < 0 { + return fmt.Errorf("replicas must be non-negative: minReplica=%d, maxReplica=%d", minReplica, maxReplica) + } + if minReplica > maxReplica { + return fmt.Errorf("minReplica (%d) must not exceed maxReplica (%d)", minReplica, maxReplica) + } + + log.Info(). + Str("applicationName", applicationName). + Str("workingEnv", workingEnv). + Int("minReplica", minReplica). + Int("maxReplica", maxReplica). + Msg("SetDeployableReplicas: patching scaling resource") + + isCanary := IsCanary(applicationName, workingEnv) + + // Try HPA first + var hpaErr error + hpaResource, err := GetArgoCDResource("HorizontalPodAutoscaler", applicationName, isCanary) + if err == nil { + _, patchErr := hpaResource.PatchArgoCDResource(map[string]interface{}{ + "spec": map[string]interface{}{ + "minReplicas": minReplica, + "maxReplicas": maxReplica, + }, + }, workingEnv) + if patchErr == nil { + log.Info().Str("applicationName", applicationName).Msg("SetDeployableReplicas: successfully patched HPA") + return nil + } + log.Error().Err(patchErr).Str("applicationName", applicationName).Msg("SetDeployableReplicas: failed to patch HPA, trying ScaledObject") + hpaErr = patchErr + } else { + hpaErr = err + } + + // HPA not found or patch failed; try KEDA ScaledObject + scaledObjResource, scaledErr := GetArgoCDResource("ScaledObject", applicationName, isCanary) + if scaledErr != nil { + return fmt.Errorf("neither HPA nor ScaledObject found for application %s (HPA: %v; ScaledObject: %v)", applicationName, hpaErr, scaledErr) + } + _, patchErr := scaledObjResource.PatchArgoCDResource(map[string]interface{}{ + "spec": map[string]interface{}{ + "minReplicaCount": minReplica, + "maxReplicaCount": maxReplica, + }, + }, workingEnv) + if patchErr != nil { + log.Error().Err(patchErr).Str("applicationName", applicationName).Msg("SetDeployableReplicas: failed to patch ScaledObject") + return fmt.Errorf("neither HPA nor ScaledObject succeeded for application %s (HPA: %v; ScaledObject patch: %w)", applicationName, hpaErr, patchErr) + } + log.Info().Str("applicationName", applicationName).Msg("SetDeployableReplicas: successfully patched ScaledObject") + return nil +} diff --git a/quick-start/db-init/scripts/init-mysql.sh b/quick-start/db-init/scripts/init-mysql.sh index 2e0f2be6..7c4d4f2a 100644 --- a/quick-start/db-init/scripts/init-mysql.sh +++ b/quick-start/db-init/scripts/init-mysql.sh @@ -359,6 +359,7 @@ mysql -hmysql -uroot -proot --skip-ssl -e " created_at datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, test_results json, + source_config_id varchar(255) NULL, PRIMARY KEY (id), UNIQUE KEY config_id (config_id) ); @@ -439,6 +440,7 @@ mysql -hmysql -uroot -proot --skip-ssl -e " updated_at datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, test_results json, has_nil_data boolean DEFAULT false, + source_model_name varchar(255) NULL, PRIMARY KEY (id) ); @@ -601,7 +603,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " # VALUES (1, 1, NOW(), NOW()); INSERT IGNORE INTO service_deployable_config ( - id, name, host, service, active, created_by, updated_by, + id, name, host, service, deployable_type, active, created_by, updated_by, created_at, updated_at, config, monitoring_url, deployable_running_status, deployable_work_flow_id, deployment_run_id, deployable_health, work_flow_status ) VALUES ( @@ -609,6 +611,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " 'numerix', 'numerix:8083', 'numerix', + 'target', 1, 'admin@admin.com', NULL, @@ -624,7 +627,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " ); INSERT IGNORE INTO service_deployable_config ( - id, name, host, service, active, created_by, updated_by, + id, name, host, service, deployable_type, active, created_by, updated_by, created_at, updated_at, config, monitoring_url, deployable_running_status, deployable_work_flow_id, deployment_run_id, deployable_health, work_flow_status ) VALUES ( @@ -632,6 +635,7 @@ mysql -hmysql -uroot -proot --skip-ssl testdb -e " 'inferflow', 'inferflow:8085', 'inferflow', + 'target', 1, 'admin@admin.com', NULL,