diff --git a/horizon/internal/externalcall/gcs_client.go b/horizon/internal/externalcall/gcs_client.go index 08a38ef5..2eed96a3 100644 --- a/horizon/internal/externalcall/gcs_client.go +++ b/horizon/internal/externalcall/gcs_client.go @@ -33,6 +33,7 @@ type GCSClientInterface interface { GetFolderInfo(bucket, folderPrefix string) (*GCSFolderInfo, error) ListFoldersWithTimestamp(bucket, prefix string) ([]GCSFolderInfo, error) FindFileWithSuffix(bucket, folderPath, suffix string) (bool, string, error) + ListFilesWithSuffix(bucket, folderPath, suffix string) ([]string, error) } const ( @@ -771,6 +772,39 @@ func (g *GCSClient) FindFileWithSuffix(bucket, folderPath, suffix string) (bool, return true, foundFile, nil } +// ListFilesWithSuffix returns all object paths (full GCS object keys) ending with the given +// suffix under folderPath. Unlike FindFileWithSuffix which returns only the first match, +// this method collects every matching file path. Directory markers are skipped. +func (g *GCSClient) ListFilesWithSuffix(bucket, folderPath, suffix string) ([]string, error) { + if g.client == nil { + return nil, fmt.Errorf("GCS client not initialized properly") + } + + if !strings.HasSuffix(folderPath, "/") { + folderPath += "/" + } + + log.Info().Msgf("Listing files with suffix '%s' in GCS bucket %s with prefix %s", suffix, bucket, folderPath) + + var files []string + err := g.forEachObject(bucket, folderPath, func(attrs *storage.ObjectAttrs) error { + // Skip directory markers + if strings.HasSuffix(attrs.Name, "/") { + return nil + } + if strings.HasSuffix(attrs.Name, suffix) { + files = append(files, attrs.Name) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to list files with suffix '%s': %w", suffix, err) + } + + log.Info().Msgf("Found %d files with suffix '%s' under %s/%s", len(files), suffix, bucket, folderPath) + return files, nil +} + // forEachObject iterates over all objects with the given prefix and calls the visitor for each. 
func (g *GCSClient) forEachObject(bucket, prefix string, visitor ObjectVisitor) error { if g.client == nil { diff --git a/horizon/internal/predator/handler/model.go b/horizon/internal/predator/handler/model.go index c05cb329..5d750423 100644 --- a/horizon/internal/predator/handler/model.go +++ b/horizon/internal/predator/handler/model.go @@ -44,7 +44,7 @@ type IOField struct { } type ConfigMapping struct { - ServiceDeployableID uint `json:"service_deployable_id"` + ServiceDeployableID uint `json:"service_deployable_id"` SourceModelName string `json:"source_model_name,omitempty"` } @@ -270,3 +270,17 @@ type TritonInputTensor struct { type TritonOutputTensor struct { Name string `json:"name"` } + +// fileViolationInfo groups all detected violations for a single Python file. +// Used by validateNoLoggerOrPrintStatements and buildBalancedViolationSummary +// to produce balanced, capped error messages. +type fileViolationInfo struct { + fileName string + details []string +} + +// funcScopeEntry represents a Python function scope on the tracking stack. 
+type funcScopeEntry struct { + name string + indent int +} diff --git a/horizon/internal/predator/handler/predator_constants.go b/horizon/internal/predator/handler/predator_constants.go index 4d304dbb..b8830838 100644 --- a/horizon/internal/predator/handler/predator_constants.go +++ b/horizon/internal/predator/handler/predator_constants.go @@ -52,7 +52,7 @@ const ( successDeleteRequestMsg = "Model deletion request raised successfully" fieldModelSourcePath = "model_source_path" fieldMetaData = "meta_data" - fieldDiscoveryConfigID = "discovery_config_id" + fieldDiscoveryConfigID = "discovery_config_id" fieldConfigMapping = "config_mapping" errReadConfigFileFormat = "failed to read config.pbtxt: %v" errUnmarshalProtoFormat = "failed to unmarshal proto text: %v" @@ -63,7 +63,7 @@ const ( errMaxBatchSizeMissing = "max_batch_size is missing or zero in config" errBackendMissing = "backend is missing in config" errNoInputDefinitions = "no input definitions found in config" - errNoOutputDefinitions = "no output definitions found in config" + errNoOutputDefinitions = "no output definitions found in config" errInstanceGroupMissing = "instance group is missing in config" errInvalidRequestIDFormat = "invalid group ID format" errFailedToFetchRequest = "failed to fetch request for group id %s" @@ -112,4 +112,13 @@ const ( predatorInferMethod = "inference.GRPCInferenceService/ModelInfer" deployableTagDelimiter = "_" scaleupTag = "scaleup" + errWarmupConfigMissing = "model_warmup configuration is missing in config.pbtxt for non-ensemble model; warmup is required to pre-load models and avoid cold-start latency" + errLoggerOrPrintStatementsFound = "model validation rejected: Python model files contain logger/print statements outside of initialize() and finalize() functions; these must be removed or commented out before deployment to prevent excessive logging in production" + pythonBackend = "python" + ensembleBackend = "ensemble" + ensemblePlatform = "ensemble" + pyFileSuffix = 
".py" + allowedFuncInitialize = "initialize" + allowedFuncFinalize = "finalize" + maxDisplayedViolationsForPythonModel = 5 ) diff --git a/horizon/internal/predator/handler/predator_upload.go b/horizon/internal/predator/handler/predator_upload.go index 9015f352..5e48c29b 100644 --- a/horizon/internal/predator/handler/predator_upload.go +++ b/horizon/internal/predator/handler/predator_upload.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "net/http" "path" @@ -144,7 +145,11 @@ func (p *Predator) validateUploadPrerequisites(bucket, destPath string, isPartia return nil } -// validateSourceModel validates the source model structure and configuration +// validateSourceModel validates the source model structure and configuration. +// For full uploads it also validates: +// - Model configuration (config.pbtxt is valid and parseable) and warmup configuration is present for non-ensemble models +// - Complete model structure (version 1/ folder exists with files) +// - Python backend models do not contain logger/print statements outside initialize()/finalize() func (p *Predator) validateSourceModel(gcsPath string, isPartial bool) error { srcBucket, srcPath := extractGCSPath(gcsPath) if srcBucket == "" || srcPath == "" { @@ -159,6 +164,10 @@ func (p *Predator) validateSourceModel(gcsPath string, isPartial bool) error { if err := p.validateCompleteModelStructure(srcBucket, srcPath); err != nil { return fmt.Errorf("complete model structure validation failed: %w", err) } + + if err := p.validateNoLoggerOrPrintStatements(gcsPath); err != nil { + return fmt.Errorf("logger/print statement validation failed: %w", err) + } } return nil @@ -422,7 +431,9 @@ func (p *Predator) syncPartialFiles(gcsPath, destBucket, destPath, modelName str return nil } -// validateModelConfiguration validates the model configuration +// validateModelConfiguration validates the model configuration, including: +// 1. Parsing config.pbtxt as a valid protobuf ModelConfig +// 2. 
Checking that non-ensemble models have a warmup configuration defined func (p *Predator) validateModelConfiguration(gcsPath string) error { log.Info().Msgf("Validating model configuration for GCS path: %s", gcsPath) @@ -443,6 +454,16 @@ func (p *Predator) validateModelConfiguration(gcsPath string) error { } log.Info().Msgf("Parsed model config - Name: %s, Backend: %s", modelConfig.Name, modelConfig.Backend) + + if !isEnsembleModel(&modelConfig) { + if len(modelConfig.GetModelWarmup()) == 0 { + return errors.New(errWarmupConfigMissing) + } + log.Info().Msg("Warmup configuration validation passed for non-ensemble model") + } else { + log.Info().Msg("Skipping warmup validation for ensemble model") + } + return nil } @@ -528,3 +549,246 @@ func (p *Predator) replaceModelNameInConfigPreservingFormat(data []byte, destMod return []byte(strings.Join(lines, "\n")) } + +// validateNoLoggerOrPrintStatements checks that Python backend models do not contain +// uncommented logger.info, logger.debug, or print statements outside of initialize() +// and finalize() functions. +func (p *Predator) validateNoLoggerOrPrintStatements(gcsPath string) error { + srcBucket, srcPath := extractGCSPath(gcsPath) + if srcBucket == "" || srcPath == "" { + return fmt.Errorf("invalid GCS path format: %s", gcsPath) + } + + // Read and parse config.pbtxt to determine backend type. 
+ configPath := path.Join(srcPath, configFile) + configData, err := p.GcsClient.ReadFile(srcBucket, configPath) + if err != nil { + return fmt.Errorf("failed to read config.pbtxt for logger/print check: %w", err) + } + + var modelConfig ModelConfig + if err := prototext.Unmarshal(configData, &modelConfig); err != nil { + return fmt.Errorf("failed to parse config.pbtxt for logger/print check: %w", err) + } + + if !isPythonBackendModel(&modelConfig) { + log.Info().Msg("Skipping logger/print statement check for non-Python backend model") + return nil + } + + log.Info().Msg("Python backend model detected, scanning .py files for logger/print statements") + + // Enumerate all .py files in the model source directory. + pyFiles, err := p.GcsClient.ListFilesWithSuffix(srcBucket, srcPath, pyFileSuffix) + if err != nil { + return fmt.Errorf("failed to list Python files for logger/print check: %w", err) + } + + if len(pyFiles) == 0 { + log.Info().Msg("No Python files found in model directory, skipping logger/print check") + return nil + } + + var perFileViolations []fileViolationInfo + totalViolations := 0 + + for _, pyFile := range pyFiles { + content, err := p.GcsClient.ReadFile(srcBucket, pyFile) + if err != nil { + return fmt.Errorf("failed to read Python file %s for logger/print check: %w", pyFile, err) + } + + found, details := hasPythonLoggerOrPrintStatements(content) + if found { + perFileViolations = append(perFileViolations, fileViolationInfo{ + fileName: path.Base(pyFile), + details: details, + }) + totalViolations += len(details) + } + } + + if totalViolations == 0 { + log.Info().Msg("Logger/print statement validation passed for Python backend model") + return nil + } + + log.Warn().Msgf("Found %d logger/print statement violation(s) across %d Python model file(s)", + totalViolations, len(perFileViolations)) + + return fmt.Errorf("%s; %s", + errLoggerOrPrintStatementsFound, + buildBalancedViolationSummary(perFileViolations, totalViolations)) +} + +// 
buildBalancedViolationSummary produces a human-readable error string that shows +// at most maxDisplayedViolations individual violation details, distributed fairly +// across all files using round-robin allocation. Each file entry includes its +// total violation count. +func buildBalancedViolationSummary(perFileViolations []fileViolationInfo, totalViolations int) string { + const maxDisplayedViolations = maxDisplayedViolationsForPythonModel + + numFiles := len(perFileViolations) + + slots := make([]int, numFiles) + remaining := maxDisplayedViolations + if totalViolations < remaining { + remaining = totalViolations + } + + for remaining > 0 { + advanced := false + for i := range perFileViolations { + if remaining <= 0 { + break + } + if slots[i] < len(perFileViolations[i].details) { + slots[i]++ + remaining-- + advanced = true + } + } + if !advanced { + break + } + } + + // Build a per-file summary: shown violations + "...and N more" if truncated. + var summaries []string + for i, fv := range perFileViolations { + var parts []string + for j := 0; j < slots[i]; j++ { + parts = append(parts, fv.details[j]) + } + extra := len(fv.details) - slots[i] + if extra > 0 { + parts = append(parts, fmt.Sprintf("...and %d more", extra)) + } + + label := "violations" + if len(fv.details) == 1 { + label = "violation" + } + summaries = append(summaries, fmt.Sprintf("%s (%d %s): %s", + fv.fileName, len(fv.details), label, strings.Join(parts, ", "))) + } + + return strings.Join(summaries, "; ") +} + +// hasPythonLoggerOrPrintStatements scans Python source code for uncommented +// logger.info, logger.debug, or print() statements that are NOT inside +// initialize() or finalize() functions. 
+func hasPythonLoggerOrPrintStatements(content []byte) (found bool, details []string) { + lines := strings.Split(string(content), "\n") + + funcDefPattern := regexp.MustCompile(`^def\s+(\w+)\s*\(`) + loggerPattern := regexp.MustCompile(`(?i)logger\.(info|debug)\s*\(`) + printPattern := regexp.MustCompile(`\bprint\s*\(`) + + var functionStack []funcScopeEntry + + for lineNum, line := range lines { + strippedLine := strings.TrimSpace(line) + + if strippedLine == "" { + continue + } + + if strings.HasPrefix(strippedLine, "#") { + continue + } + + lineIndent := len(line) - len(strings.TrimLeft(line, " \t")) + + for len(functionStack) > 0 && functionStack[len(functionStack)-1].indent >= lineIndent { + functionStack = functionStack[:len(functionStack)-1] + } + + if funcMatch := funcDefPattern.FindStringSubmatch(strippedLine); len(funcMatch) > 1 { + functionStack = append(functionStack, funcScopeEntry{name: funcMatch[1], indent: lineIndent}) + continue + } + + if isInsideAllowedFunction(functionStack) { + continue + } + + codePortion := stripInlineComment(line) + + if loggerPattern.MatchString(codePortion) { + details = append(details, fmt.Sprintf("line %d: %s", lineNum+1, strippedLine)) + found = true + } else if printPattern.MatchString(codePortion) { + details = append(details, fmt.Sprintf("line %d: %s", lineNum+1, strippedLine)) + found = true + } + } + + return found, details +} + +// isInsideAllowedFunction returns true if any entry on the function scope stack +// is initialize() or finalize(). +func isInsideAllowedFunction(stack []funcScopeEntry) bool { + for _, entry := range stack { + if entry.name == allowedFuncInitialize || entry.name == allowedFuncFinalize { + return true + } + } + return false +} + +// stripInlineComment removes a trailing Python inline comment (# ...) from a +// line while preserving '#' characters that appear inside string literals. 
// stripInlineComment returns the portion of line before a trailing Python '#'
// comment, leaving '#' characters that appear inside single- or double-quoted
// string literals untouched. Backslash escapes are honored so escaped quotes
// do not terminate a string early. Lines without a comment are returned as-is.
func stripInlineComment(line string) string {
	const (
		stateCode = iota
		stateSingle
		stateDouble
	)
	state := stateCode
	skipNext := false
	for i := 0; i < len(line); i++ {
		if skipNext {
			skipNext = false
			continue
		}
		switch line[i] {
		case '\\':
			// The next byte is escaped regardless of state (matches Python string escapes).
			skipNext = true
		case '\'':
			if state == stateSingle {
				state = stateCode
			} else if state == stateCode {
				state = stateSingle
			}
		case '"':
			if state == stateDouble {
				state = stateCode
			} else if state == stateCode {
				state = stateDouble
			}
		case '#':
			if state == stateCode {
				return line[:i]
			}
		}
	}
	return line
}
+type mockGCSClient struct { + readFile func(bucket, objectPath string) ([]byte, error) + listFilesWithSuffix func(bucket, folderPath, suffix string) ([]string, error) +} + +func (m *mockGCSClient) ReadFile(bucket, objectPath string) ([]byte, error) { + if m.readFile != nil { + return m.readFile(bucket, objectPath) + } + return nil, fmt.Errorf("ReadFile not mocked") +} + +func (m *mockGCSClient) ListFilesWithSuffix(bucket, folderPath, suffix string) ([]string, error) { + if m.listFilesWithSuffix != nil { + return m.listFilesWithSuffix(bucket, folderPath, suffix) + } + return nil, fmt.Errorf("ListFilesWithSuffix not mocked") +} + +func (m *mockGCSClient) TransferFolder(_, _, _, _, _, _ string) error { return nil } +func (m *mockGCSClient) TransferAndDeleteFolder(_, _, _, _, _, _ string) error { return nil } +func (m *mockGCSClient) TransferFolderWithSplitSources(_, _, _, _, _, _, _, _ string) error { + return nil +} +func (m *mockGCSClient) DeleteFolder(_, _, _ string) error { return nil } +func (m *mockGCSClient) ListFolders(_, _ string) ([]string, error) { return nil, nil } +func (m *mockGCSClient) UploadFile(_, _ string, _ []byte) error { return nil } +func (m *mockGCSClient) CheckFileExists(_, _ string) (bool, error) { return false, nil } +func (m *mockGCSClient) CheckFolderExists(_, _ string) (bool, error) { return false, nil } +func (m *mockGCSClient) UploadFolderFromLocal(_, _, _ string) error { return nil } +func (m *mockGCSClient) GetFolderInfo(_, _ string) (*externalcall.GCSFolderInfo, error) { + return nil, nil +} +func (m *mockGCSClient) ListFoldersWithTimestamp(_, _ string) ([]externalcall.GCSFolderInfo, error) { + return nil, nil +} +func (m *mockGCSClient) FindFileWithSuffix(_, _, _ string) (bool, string, error) { + return false, "", nil +} + +// Tests for isEnsembleModel +func TestIsEnsembleModel(t *testing.T) { + tests := []struct { + name string + config *ModelConfig + want bool + }{ + { + name: "backend is ensemble", + config: 
&ModelConfig{Backend: "ensemble"}, + want: true, + }, + { + name: "platform is ensemble", + config: &ModelConfig{Platform: "ensemble"}, + want: true, + }, + { + name: "ensemble_scheduling is set", + config: &ModelConfig{ + Backend: "python", + SchedulingChoice: &ModelConfig_EnsembleScheduling{ + EnsembleScheduling: &ModelEnsembling{}, + }, + }, + want: true, + }, + { + name: "python backend - not ensemble", + config: &ModelConfig{Backend: "python"}, + want: false, + }, + { + name: "onnxruntime backend - not ensemble", + config: &ModelConfig{Backend: "onnxruntime"}, + want: false, + }, + { + name: "empty config - not ensemble", + config: &ModelConfig{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isEnsembleModel(tt.config) + assert.Equal(t, tt.want, got) + }) + } +} + +// Tests for isPythonBackendModel +func TestIsPythonBackendModel(t *testing.T) { + tests := []struct { + name string + config *ModelConfig + want bool + }{ + { + name: "python backend", + config: &ModelConfig{Backend: "python"}, + want: true, + }, + { + name: "onnxruntime backend", + config: &ModelConfig{Backend: "onnxruntime"}, + want: false, + }, + { + name: "tensorrt_llm backend", + config: &ModelConfig{Backend: "tensorrt_llm"}, + want: false, + }, + { + name: "ensemble backend", + config: &ModelConfig{Backend: "ensemble"}, + want: false, + }, + { + name: "empty backend", + config: &ModelConfig{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isPythonBackendModel(tt.config) + assert.Equal(t, tt.want, got) + }) + } +} + +// Tests for hasPythonLoggerOrPrintStatements +func TestHasPythonLoggerOrPrintStatements(t *testing.T) { + tests := []struct { + name string + content string + wantFound bool + wantDetails int // expected number of violations + }{ + { + name: "clean file with no violations", + content: "import numpy as np\n\ndef execute(requests):\n result = np.array([1, 2])\n return 
result\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "logger.info outside any function", + content: "import triton\nlogger.info('loading model')\n", + wantFound: true, + wantDetails: 1, + }, + { + name: "print statement outside any function", + content: "import triton\nprint('hello')\n", + wantFound: true, + wantDetails: 1, + }, + { + name: "logger.debug inside execute function", + content: "def execute(requests):\n logger.debug('processing')\n return []\n", + wantFound: true, + wantDetails: 1, + }, + { + name: "logger.info inside initialize is allowed", + content: "def initialize(self, args):\n logger.info('model loaded')\n self.model = load()\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "print inside finalize is allowed", + content: "def finalize(self):\n print('cleanup done')\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "logger in initialize allowed but in execute flagged", + content: `def initialize(self, args): + logger.info('init ok') + +def execute(self, requests): + logger.info('processing request') + return [] +`, + wantFound: true, + wantDetails: 1, + }, + { + name: "commented out logger.info is skipped", + content: "def execute(requests):\n # logger.info('debug')\n return []\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "commented out print is skipped", + content: "# print('debug')\ndef execute(requests):\n return []\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "fingerprint should not match print pattern", + content: "def execute(requests):\n fp = fingerprint(data)\n return fp\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "blueprint should not match print pattern", + content: "def execute(requests):\n bp = blueprint(config)\n return bp\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "multiple violations across functions", + content: `import os + +def helper(): + print('helper debug') + logger.info('helper info') + +def execute(requests): + print('exec debug') + 
return [] +`, + wantFound: true, + wantDetails: 3, + }, + { + name: "logger.INFO case insensitive match", + content: "def execute(requests):\n LOGGER.INFO('upper case')\n return []\n", + wantFound: true, + wantDetails: 1, + }, + { + name: "scope exit resets after initialize", + content: `def initialize(self, args): + logger.info('init ok') + +logger.info('module level - should flag') +`, + wantFound: true, + wantDetails: 1, + }, + { + name: "empty file", + content: "", + wantFound: false, + wantDetails: 0, + }, + { + name: "nested def inside initialize is allowed", + content: `def initialize(self, args): + def _load_model(): + logger.info('loading weights') + print('progress') + _load_model() +`, + wantFound: false, + wantDetails: 0, + }, + { + name: "nested def inside finalize is allowed", + content: `def finalize(self): + def _cleanup(): + logger.info('releasing resources') + _cleanup() +`, + wantFound: false, + wantDetails: 0, + }, + { + name: "nested def inside execute is still flagged", + content: `def execute(self, requests): + def _process(): + logger.info('processing batch') + _process() + return [] +`, + wantFound: true, + wantDetails: 1, + }, + { + name: "inline comment with logger is not flagged", + content: "def execute(requests):\n x = compute() # logger.info('debug')\n return x\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "inline comment with print is not flagged", + content: "def execute(requests):\n x = compute() # print('debug')\n return x\n", + wantFound: false, + wantDetails: 0, + }, + { + name: "hash inside string literal is not treated as comment", + content: "def execute(requests):\n print('value is # 42')\n return []\n", + wantFound: true, + wantDetails: 1, + }, + { + name: "line matching both logger and print reports once", + content: "def execute(requests):\n print(logger.info('x'))\n return []\n", + wantFound: true, + wantDetails: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + found, 
details := hasPythonLoggerOrPrintStatements([]byte(tt.content)) + assert.Equal(t, tt.wantFound, found) + assert.Equal(t, tt.wantDetails, len(details), "violation details: %v", details) + }) + } +} + +// Tests for buildBalancedViolationSummary +func TestBuildBalancedViolationSummary(t *testing.T) { + tests := []struct { + name string + violations []fileViolationInfo + totalViolations int + wantContains []string + wantNotContain []string + }{ + { + name: "single file single violation", + violations: []fileViolationInfo{ + {fileName: "model.py", details: []string{"line 5: print('x')"}}, + }, + totalViolations: 1, + wantContains: []string{"model.py (1 violation)", "line 5: print('x')"}, + }, + { + name: "single file multiple violations within cap", + violations: []fileViolationInfo{ + {fileName: "model.py", details: []string{ + "line 5: print('a')", + "line 10: print('b')", + "line 15: print('c')", + }}, + }, + totalViolations: 3, + wantContains: []string{"model.py (3 violations)", "line 5", "line 10", "line 15"}, + }, + { + name: "single file exceeds cap shows truncation", + violations: []fileViolationInfo{ + {fileName: "model.py", details: []string{ + "line 1: print('a')", + "line 2: print('b')", + "line 3: print('c')", + "line 4: print('d')", + "line 5: print('e')", + "line 6: print('f')", + "line 7: print('g')", + }}, + }, + totalViolations: 7, + wantContains: []string{"model.py (7 violations)", "...and 2 more"}, + }, + { + name: "two files with round-robin distribution", + violations: []fileViolationInfo{ + {fileName: "a.py", details: []string{"line 1: print('a1')", "line 2: print('a2')", "line 3: print('a3')"}}, + {fileName: "b.py", details: []string{"line 10: print('b1')", "line 20: print('b2')", "line 30: print('b3')"}}, + }, + totalViolations: 6, + wantContains: []string{"a.py", "b.py"}, + }, + { + name: "uses singular 'violation' for single-violation file", + violations: []fileViolationInfo{ + {fileName: "single.py", details: []string{"line 1: print('x')"}}, 
+ }, + totalViolations: 1, + wantContains: []string{"1 violation)"}, + wantNotContain: []string{"1 violations)"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildBalancedViolationSummary(tt.violations, tt.totalViolations) + for _, want := range tt.wantContains { + assert.Contains(t, got, want) + } + for _, notWant := range tt.wantNotContain { + assert.NotContains(t, got, notWant) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Tests for validateModelConfiguration (warmup check via GCS mock) +// --------------------------------------------------------------------------- + +func TestValidateModelConfiguration_WarmupCheck(t *testing.T) { + pythonWithWarmup := `name: "my_model" +backend: "python" +max_batch_size: 8 +model_warmup { + name: "warmup_request" + batch_size: 1 + inputs { + key: "input" + value: { + data_type: TYPE_FP32 + dims: [1, 3] + zero_data: true + } + } +} +` + + pythonWithoutWarmup := `name: "my_model" +backend: "python" +max_batch_size: 8 +` + + ensembleConfig := `name: "ensemble_model" +platform: "ensemble" +max_batch_size: 8 +ensemble_scheduling { + step { + model_name: "preprocess" + model_version: 1 + } +} +` + + ensembleBackendConfig := `name: "ensemble_model" +backend: "ensemble" +max_batch_size: 8 +` + + tests := []struct { + name string + configData string + wantErr bool + errContain string + }{ + { + name: "python model with warmup passes", + configData: pythonWithWarmup, + wantErr: false, + }, + { + name: "python model without warmup fails", + configData: pythonWithoutWarmup, + wantErr: true, + errContain: "model_warmup configuration is missing", + }, + { + name: "ensemble model skips warmup check", + configData: ensembleConfig, + wantErr: false, + }, + { + name: "ensemble backend model skips warmup check", + configData: ensembleBackendConfig, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gcs 
:= &mockGCSClient{ + readFile: func(bucket, objectPath string) ([]byte, error) { + return []byte(tt.configData), nil + }, + } + + p := &Predator{GcsClient: gcs} + err := p.validateModelConfiguration("gs://test-bucket/models/my_model") + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContain) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateModelConfiguration_InvalidGCSPath(t *testing.T) { + p := &Predator{} + err := p.validateModelConfiguration("invalid-path") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid GCS path format") +} + +func TestValidateModelConfiguration_ReadFileError(t *testing.T) { + gcs := &mockGCSClient{ + readFile: func(_, _ string) ([]byte, error) { + return nil, fmt.Errorf("bucket not found") + }, + } + p := &Predator{GcsClient: gcs} + err := p.validateModelConfiguration("gs://test-bucket/models/my_model") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read config.pbtxt") +} + +func TestValidateModelConfiguration_InvalidProto(t *testing.T) { + gcs := &mockGCSClient{ + readFile: func(_, _ string) ([]byte, error) { + return []byte("this is not valid proto text {{{"), nil + }, + } + p := &Predator{GcsClient: gcs} + err := p.validateModelConfiguration("gs://test-bucket/models/my_model") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse config.pbtxt") +} + +// Tests for validateNoLoggerOrPrintStatements (integration with GCS mock) +func TestValidateNoLoggerOrPrintStatements_NonPythonBackend(t *testing.T) { + configData := `name: "onnx_model" +backend: "onnxruntime" +max_batch_size: 8 +` + gcs := &mockGCSClient{ + readFile: func(_, objectPath string) ([]byte, error) { + if strings.HasSuffix(objectPath, "config.pbtxt") { + return []byte(configData), nil + } + return nil, fmt.Errorf("unexpected read: %s", objectPath) + }, + } + + p := &Predator{GcsClient: gcs} + err := 
p.validateNoLoggerOrPrintStatements("gs://test-bucket/models/onnx_model") + require.NoError(t, err) +} + +func TestValidateNoLoggerOrPrintStatements_PythonBackendClean(t *testing.T) { + configData := `name: "py_model" +backend: "python" +max_batch_size: 8 +` + cleanPython := `import numpy as np + +def initialize(self, args): + logger.info('init') + +def execute(self, requests): + result = np.array([1, 2, 3]) + return result + +def finalize(self): + print('done') +` + + gcs := &mockGCSClient{ + readFile: func(_, objectPath string) ([]byte, error) { + if strings.HasSuffix(objectPath, "config.pbtxt") { + return []byte(configData), nil + } + return []byte(cleanPython), nil + }, + listFilesWithSuffix: func(_, _, _ string) ([]string, error) { + return []string{"models/py_model/1/model.py"}, nil + }, + } + + p := &Predator{GcsClient: gcs} + err := p.validateNoLoggerOrPrintStatements("gs://test-bucket/models/py_model") + require.NoError(t, err) +} + +func TestValidateNoLoggerOrPrintStatements_PythonBackendWithViolations(t *testing.T) { + configData := `name: "py_model" +backend: "python" +max_batch_size: 8 +` + dirtyPython := `import numpy as np + +def execute(self, requests): + logger.info('processing') + print('debug output') + return [] +` + + gcs := &mockGCSClient{ + readFile: func(_, objectPath string) ([]byte, error) { + if strings.HasSuffix(objectPath, "config.pbtxt") { + return []byte(configData), nil + } + return []byte(dirtyPython), nil + }, + listFilesWithSuffix: func(_, _, _ string) ([]string, error) { + return []string{"models/py_model/1/model.py"}, nil + }, + } + + p := &Predator{GcsClient: gcs} + err := p.validateNoLoggerOrPrintStatements("gs://test-bucket/models/py_model") + require.Error(t, err) + assert.Contains(t, err.Error(), "logger/print statements") + assert.Contains(t, err.Error(), "model.py") +} + +func TestValidateNoLoggerOrPrintStatements_NoPythonFiles(t *testing.T) { + configData := `name: "py_model" +backend: "python" +max_batch_size: 8 +` + 
gcs := &mockGCSClient{ + readFile: func(_, objectPath string) ([]byte, error) { + if strings.HasSuffix(objectPath, "config.pbtxt") { + return []byte(configData), nil + } + return nil, fmt.Errorf("not found") + }, + listFilesWithSuffix: func(_, _, _ string) ([]string, error) { + return []string{}, nil + }, + } + + p := &Predator{GcsClient: gcs} + err := p.validateNoLoggerOrPrintStatements("gs://test-bucket/models/py_model") + require.NoError(t, err) +} + +func TestValidateNoLoggerOrPrintStatements_InvalidGCSPath(t *testing.T) { + p := &Predator{} + err := p.validateNoLoggerOrPrintStatements("invalid-path") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid GCS path format") +} + +func TestValidateNoLoggerOrPrintStatements_MultipleFiles(t *testing.T) { + configData := `name: "py_model" +backend: "python" +max_batch_size: 8 +` + fileContents := map[string]string{ + "models/py_model/config.pbtxt": configData, + "models/py_model/1/model.py": `def execute(self, requests): + print('debug in model') + return [] +`, + "models/py_model/1/utils.py": `def helper(): + logger.info('helper log') + return True +`, + } + + gcs := &mockGCSClient{ + readFile: func(_, objectPath string) ([]byte, error) { + if content, ok := fileContents[objectPath]; ok { + return []byte(content), nil + } + return nil, fmt.Errorf("file not found: %s", objectPath) + }, + listFilesWithSuffix: func(_, _, _ string) ([]string, error) { + return []string{"models/py_model/1/model.py", "models/py_model/1/utils.py"}, nil + }, + } + + p := &Predator{GcsClient: gcs} + err := p.validateNoLoggerOrPrintStatements("gs://test-bucket/models/py_model") + require.Error(t, err) + errMsg := err.Error() + assert.Contains(t, errMsg, "model.py") + assert.Contains(t, errMsg, "utils.py") +} + +func TestStripInlineComment(t *testing.T) { + tests := []struct { + name string + line string + want string + }{ + {name: "no comment", line: "x = 1", want: "x = 1"}, + {name: "trailing comment", line: "x = 1 # set x", 
want: "x = 1 "}, + {name: "hash in single quotes", line: "x = 'a#b'", want: "x = 'a#b'"}, + {name: "hash in double quotes", line: `x = "a#b"`, want: `x = "a#b"`}, + {name: "hash after quoted string", line: `x = "a#b" # comment`, want: `x = "a#b" `}, + {name: "escaped quote", line: `x = 'it\'s' # note`, want: `x = 'it\'s' `}, + {name: "full line comment", line: "# comment", want: ""}, + {name: "empty line", line: "", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripInlineComment(tt.line) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestIsInsideAllowedFunction(t *testing.T) { + tests := []struct { + name string + stack []funcScopeEntry + want bool + }{ + {name: "empty stack", stack: nil, want: false}, + {name: "non-allowed function", stack: []funcScopeEntry{{name: "execute", indent: 0}}, want: false}, + {name: "initialize on stack", stack: []funcScopeEntry{{name: "initialize", indent: 0}}, want: true}, + {name: "finalize on stack", stack: []funcScopeEntry{{name: "finalize", indent: 0}}, want: true}, + { + name: "nested inside initialize", + stack: []funcScopeEntry{{name: "initialize", indent: 0}, {name: "helper", indent: 4}}, + want: true, + }, + { + name: "nested inside non-allowed", + stack: []funcScopeEntry{{name: "execute", indent: 0}, {name: "helper", indent: 4}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isInsideAllowedFunction(tt.stack) + assert.Equal(t, tt.want, got) + }) + } +}