diff --git a/horizon/internal/predator/handler/model.go b/horizon/internal/predator/handler/model.go index c05cb329..08b0e50c 100644 --- a/horizon/internal/predator/handler/model.go +++ b/horizon/internal/predator/handler/model.go @@ -6,11 +6,13 @@ import ( ) type Payload struct { - ModelName string `json:"model_name"` - ModelSource string `json:"model_source_path,omitempty"` - MetaData MetaData `json:"meta_data"` - ConfigMapping ConfigMapping `json:"config_mapping"` - DiscoveryConfigID uint `json:"discovery_config_id"` + ModelName string `json:"model_name"` + ModelSource string `json:"model_source_path,omitempty"` + MetaData MetaData `json:"meta_data"` + ConfigMapping ConfigMapping `json:"config_mapping"` + DiscoveryConfigID uint `json:"discovery_config_id"` + IsLoadTested bool `json:"is_load_tested,omitempty"` + LoadTestResultsLink string `json:"load_test_results_link,omitempty"` } type MetaData struct { @@ -44,7 +46,7 @@ type IOField struct { } type ConfigMapping struct { - ServiceDeployableID uint `json:"service_deployable_id"` + ServiceDeployableID uint `json:"service_deployable_id"` SourceModelName string `json:"source_model_name,omitempty"` } diff --git a/horizon/internal/predator/handler/predator.go b/horizon/internal/predator/handler/predator.go index a2459cdf..0a369ddb 100644 --- a/horizon/internal/predator/handler/predator.go +++ b/horizon/internal/predator/handler/predator.go @@ -172,6 +172,13 @@ func (p *Predator) HandleModelRequest(req ModelRequest, requestType string) (str if err := json.Unmarshal(payloadBytes, &payloadObject); err != nil { return constant.EmptyString, http.StatusInternalServerError, errors.New(errMsgProcessPayload) } + // Validate load test fields for promote requests + if requestType == PromoteRequestType && payloadObject.IsLoadTested { + if payloadObject.LoadTestResultsLink == constant.EmptyString { + return constant.EmptyString, http.StatusBadRequest, errors.New("load test results link is required when load tested is true for the model requested") + } + } + derivedModelName, err := p.GetDerivedModelName(payloadObject, requestType) if err != nil { return constant.EmptyString, http.StatusInternalServerError, fmt.Errorf("failed to fetch derived model name: %w", err)