From 6a8ca68205afdde89a189e83f6b682354986b8ce Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 15:46:14 +0000 Subject: [PATCH 1/7] Initial plan From 309b6d94da1078f087d7679c7b5c1a49861ce9c7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 16:06:49 +0000 Subject: [PATCH 2/7] Add --var template variable support to generate command Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- cmd/generate/generate.go | 77 +++++++++--- cmd/generate/generate_test.go | 218 ++++++++++++++++++++++++++++++++++ cmd/generate/pipeline.go | 8 ++ 3 files changed, 289 insertions(+), 14 deletions(-) diff --git a/cmd/generate/generate.go b/cmd/generate/generate.go index be2cf91f..964adc00 100644 --- a/cmd/generate/generate.go +++ b/cmd/generate/generate.go @@ -4,22 +4,25 @@ package generate import ( "context" "fmt" + "strings" "github.com/MakeNowJust/heredoc" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" + "github.com/spf13/pflag" ) type generateCommandHandler struct { - ctx context.Context - cfg *command.Config - client azuremodels.Client - options *PromptPexOptions - promptFile string - org string - sessionFile *string + ctx context.Context + cfg *command.Config + client azuremodels.Client + options *PromptPexOptions + promptFile string + org string + sessionFile *string + templateVars map[string]string } // NewGenerateCommand returns a new command to generate tests using PromptPex. @@ -37,6 +40,7 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { gh models generate prompt.yml gh models generate --org my-org --groundtruth-model "openai/gpt-4.1" prompt.yml gh models generate --session-file prompt.session.json prompt.yml + gh models generate --var name=Alice --var topic="machine learning" prompt.yml `), Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { @@ -50,6 +54,12 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { return fmt.Errorf("failed to parse flags: %w", err) } + // Parse template variables from flags + templateVars, err := parseTemplateVariables(cmd.Flags()) + if err != nil { + return err + } + // Get organization org, _ := cmd.Flags().GetString("org") @@ -67,13 +77,14 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { // Create the command handler handler := &generateCommandHandler{ - ctx: ctx, - cfg: cfg, - client: cfg.Client, - options: options, - promptFile: promptFile, - org: org, - sessionFile: util.Ptr(sessionFile), + ctx: ctx, + cfg: cfg, + client: cfg.Client, + options: options, + promptFile: promptFile, + org: org, + sessionFile: util.Ptr(sessionFile), + templateVars: templateVars, } // Create context @@ -105,6 +116,7 @@ func AddCommandLineFlags(cmd *cobra.Command) { flags.String("effort", "", "Effort level (low, medium, high)") flags.String("groundtruth-model", "", "Model to use for generating groundtruth outputs. Defaults to openai/gpt-4o. Use 'none' to disable groundtruth generation.") flags.String("session-file", "", "Session file to load existing context from") + flags.StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)") // Custom instruction flags for each phase flags.String("instruction-intent", "", "Custom system instruction for intent generation phase") @@ -162,3 +174,40 @@ func ParseFlags(cmd *cobra.Command, options *PromptPexOptions) error { return nil } + +// parseTemplateVariables parses template variables from the --var flags +func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { + varFlags, err := flags.GetStringSlice("var") + if err != nil { + return nil, err + } + + templateVars := make(map[string]string) + for _, varFlag := range varFlags { + // Handle empty strings + if strings.TrimSpace(varFlag) == "" { + continue + } + + parts := strings.SplitN(varFlag, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) + } + + key := strings.TrimSpace(parts[0]) + value := parts[1] // Don't trim value to preserve intentional whitespace + + if key == "" { + return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) + } + + // Check for duplicate keys + if _, exists := templateVars[key]; exists { + return nil, fmt.Errorf("duplicate variable key '%s'", key) + } + + templateVars[key] = value + } + + return templateVars, nil +} diff --git a/cmd/generate/generate_test.go b/cmd/generate/generate_test.go index 05e05cbd..0867941b 100644 --- a/cmd/generate/generate_test.go +++ b/cmd/generate/generate_test.go @@ -11,7 +11,9 @@ import ( "testing" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" + "github.com/spf13/pflag" "github.com/stretchr/testify/require" ) @@ -393,3 +395,219 @@ messages: require.Contains(t, err.Error(), "failed to load prompt file") }) } + +func TestParseTemplateVariables(t *testing.T) { + tests := []struct { + name string + varFlags []string + expected map[string]string + expectErr bool + }{ + { + name: "empty flags", + varFlags: []string{}, + expected: map[string]string{}, + }, + { + name: "single variable", + varFlags: []string{"name=Alice"}, + expected: map[string]string{"name": "Alice"}, + }, + { + name: "multiple variables", + varFlags: []string{"name=Alice", "age=30", "city=Boston"}, + expected: map[string]string{"name": "Alice", "age": "30", "city": "Boston"}, + }, + { + name: "variable with spaces in value", + varFlags: []string{"description=Hello World"}, + expected: map[string]string{"description": "Hello World"}, + }, + { + name: "variable with equals in value", + varFlags: []string{"equation=x=y+1"}, + expected: map[string]string{"equation": "x=y+1"}, + }, + { + name: "variable with empty value", + varFlags: []string{"empty="}, + expected: map[string]string{"empty": ""}, + }, + { + name: "variable with whitespace around key", + varFlags: []string{" name =Alice"}, + expected: map[string]string{"name": "Alice"}, + }, + { + name: "preserve whitespace in value", + varFlags: []string{"message= Hello World "}, + expected: map[string]string{"message": " Hello World "}, + }, + { + name: "empty string flag is ignored", + varFlags: []string{"", "name=Alice"}, + expected: map[string]string{"name": "Alice"}, + expectErr: false, + }, + { + name: "whitespace only flag is ignored", + varFlags: []string{" ", "name=Alice"}, + expected: map[string]string{"name": "Alice"}, + expectErr: false, + }, + { + name: "missing equals sign", + varFlags: []string{"name"}, + expectErr: true, + }, + { + name: "missing equals sign with multiple vars", + varFlags: []string{"name=Alice", "age"}, + expectErr: true, + }, + { + name: "empty key", + varFlags: []string{"=value"}, + expectErr: true, + }, + { + name: "whitespace only key", + varFlags: []string{" =value"}, + expectErr: true, + }, + { + name: "duplicate keys", + varFlags: []string{"name=Alice", "name=Bob"}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.StringSlice("var", tt.varFlags, "test flag") + + result, err := parseTemplateVariables(flags) + + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} + +func TestGenerateCommandWithTemplateVariables(t *testing.T) { + t.Run("parse template variables in command handler", func(t *testing.T) { + client := azuremodels.NewMockClient() + cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100) + + cmd := NewGenerateCommand(cfg) + args := []string{ + "--var", "name=Bob", + "--var", "location=Seattle", + "dummy.yml", + } + + // Parse flags without executing + err := cmd.ParseFlags(args[:len(args)-1]) // Exclude positional arg + require.NoError(t, err) + + // Test that the parseTemplateVariables function works correctly + templateVars, err := parseTemplateVariables(cmd.Flags()) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "name": "Bob", + "location": "Seattle", + }, templateVars) + }) + + t.Run("runSingleTestWithContext applies template variables", func(t *testing.T) { + // Create test prompt file with template variables + const yamlBody = ` +name: Template Variable Test +description: Test prompt with template variables +model: openai/gpt-4o-mini +messages: + - role: system + content: "You are a helpful assistant for {{name}}." + - role: user + content: "Tell me about {{topic}} in {{style}} style." +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + // Setup mock client to capture template-rendered messages + var capturedOptions azuremodels.ChatCompletionOptions + client := azuremodels.NewMockClient() + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { + capturedOptions = opt + + // Create a proper mock response with reader + mockResponse := "test response" + mockCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: &mockResponse, + }, + }, + }, + } + + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{mockCompletion}), + }, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + // Create handler with template variables + templateVars := map[string]string{ + "name": "Alice", + "topic": "machine learning", + "style": "academic", + } + + handler := &generateCommandHandler{ + ctx: context.Background(), + cfg: cfg, + client: client, + options: GetDefaultOptions(), + promptFile: promptFile, + org: "", + templateVars: templateVars, + } + + // Create context from prompt + promptCtx, err := handler.CreateContextFromPrompt() + require.NoError(t, err) + + // Call runSingleTestWithContext directly + _, err = handler.runSingleTestWithContext("test input", "openai/gpt-4o-mini", promptCtx) + require.NoError(t, err) + + // Verify that template variables were applied correctly + require.NotNil(t, capturedOptions.Messages) + require.Len(t, capturedOptions.Messages, 2) + + // Check system message + systemMsg := capturedOptions.Messages[0] + require.Equal(t, azuremodels.ChatMessageRoleSystem, systemMsg.Role) + require.NotNil(t, systemMsg.Content) + require.Contains(t, *systemMsg.Content, "helpful assistant for Alice") + + // Check user message + userMsg := capturedOptions.Messages[1] + require.Equal(t, azuremodels.ChatMessageRoleUser, userMsg.Role) + require.NotNil(t, userMsg.Content) + require.Contains(t, *userMsg.Content, "about machine learning") + require.Contains(t, *userMsg.Content, "academic style") + }) +} diff --git a/cmd/generate/pipeline.go b/cmd/generate/pipeline.go index 1a6615cd..f8ea7d6c 100644 --- a/cmd/generate/pipeline.go +++ b/cmd/generate/pipeline.go @@ -460,7 +460,15 @@ func (h *generateCommandHandler) runSingleTestWithContext(input string, modelNam openaiMessages := []azuremodels.ChatMessage{} for _, msg := range messages { templateData := make(map[string]interface{}) + + // Add the input variable (backward compatibility) templateData["input"] = input + + // Add custom variables + for key, value := range h.templateVars { + templateData[key] = value + } + // Replace template variables in content content, err := prompt.TemplateString(msg.Content, templateData) if err != nil { From e04b68610fea7f5c6384ed4e70fbb7f128441494 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 16:08:52 +0000 Subject: [PATCH 3/7] Apply gofmt formatting fixes --- cmd/generate/generate_test.go | 4 ++-- cmd/generate/pipeline.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/generate/generate_test.go b/cmd/generate/generate_test.go index 0867941b..6e9aa01f 100644 --- a/cmd/generate/generate_test.go +++ b/cmd/generate/generate_test.go @@ -547,7 +547,7 @@ messages: client := azuremodels.NewMockClient() client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { capturedOptions = opt - + // Create a proper mock response with reader mockResponse := "test response" mockCompletion := azuremodels.ChatCompletion{ @@ -559,7 +559,7 @@ messages: }, }, } - + return &azuremodels.ChatCompletionResponse{ Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{mockCompletion}), }, nil diff --git a/cmd/generate/pipeline.go b/cmd/generate/pipeline.go index f8ea7d6c..673782f9 100644 --- a/cmd/generate/pipeline.go +++ b/cmd/generate/pipeline.go @@ -460,15 +460,15 @@ func (h *generateCommandHandler) runSingleTestWithContext(input string, modelNam openaiMessages := []azuremodels.ChatMessage{} for _, msg := range messages { templateData := make(map[string]interface{}) - + // Add the input variable (backward compatibility) templateData["input"] = input - + // Add custom variables for key, value := range h.templateVars { templateData[key] = value } - + // Replace template variables in content content, err := prompt.TemplateString(msg.Content, templateData) if err != nil { From 808ccb4ff5947d6cb961a518584a7fba4c9bbf67 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 16:34:44 +0000 Subject: [PATCH 4/7] Refactor parseTemplateVariables into shared utility function Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- cmd/generate/generate.go | 41 +------------ cmd/generate/generate_test.go | 109 +-------------------------------- cmd/run/run.go | 39 +----------- cmd/run/run_test.go | 2 +- pkg/util/util.go | 40 ++++++++++++ pkg/util/util_test.go | 111 ++++++++++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 185 deletions(-) create mode 100644 pkg/util/util_test.go diff --git a/cmd/generate/generate.go b/cmd/generate/generate.go index 964adc00..260dfe7c 100644 --- a/cmd/generate/generate.go +++ b/cmd/generate/generate.go @@ -4,14 +4,12 @@ package generate import ( "context" "fmt" - "strings" "github.com/MakeNowJust/heredoc" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" - "github.com/spf13/pflag" ) type generateCommandHandler struct { @@ -55,7 +53,7 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { } // Parse template variables from flags - templateVars, err := parseTemplateVariables(cmd.Flags()) + templateVars, err := util.ParseTemplateVariables(cmd.Flags()) if err != nil { return err } @@ -174,40 +172,3 @@ func ParseFlags(cmd *cobra.Command, options *PromptPexOptions) error { return nil } - -// parseTemplateVariables parses template variables from the --var flags -func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { - varFlags, err := flags.GetStringSlice("var") - if err != nil { - return nil, err - } - - templateVars := make(map[string]string) - for _, varFlag := range varFlags { - // Handle empty strings - if strings.TrimSpace(varFlag) == "" { - continue - } - - parts := strings.SplitN(varFlag, "=", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) - } - - key := strings.TrimSpace(parts[0]) - value := parts[1] // Don't trim value to preserve intentional whitespace - - if key == "" { - return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) - } - - // Check for duplicate keys - if _, exists := templateVars[key]; exists { - return nil, fmt.Errorf("duplicate variable key '%s'", key) - } - - templateVars[key] = value - } - - return templateVars, nil -} diff --git a/cmd/generate/generate_test.go b/cmd/generate/generate_test.go index 6e9aa01f..da982505 100644 --- a/cmd/generate/generate_test.go +++ b/cmd/generate/generate_test.go @@ -13,7 +13,7 @@ import ( "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" - "github.com/spf13/pflag" + "github.com/github/gh-models/pkg/util" "github.com/stretchr/testify/require" ) @@ -396,109 +396,6 @@ messages: }) } -func TestParseTemplateVariables(t *testing.T) { - tests := []struct { - name string - varFlags []string - expected map[string]string - expectErr bool - }{ - { - name: "empty flags", - varFlags: []string{}, - expected: map[string]string{}, - }, - { - name: "single variable", - varFlags: []string{"name=Alice"}, - expected: map[string]string{"name": "Alice"}, - }, - { - name: "multiple variables", - varFlags: []string{"name=Alice", "age=30", "city=Boston"}, - expected: map[string]string{"name": "Alice", "age": "30", "city": "Boston"}, - }, - { - name: "variable with spaces in value", - varFlags: []string{"description=Hello World"}, - expected: map[string]string{"description": "Hello World"}, - }, - { - name: "variable with equals in value", - varFlags: []string{"equation=x=y+1"}, - expected: map[string]string{"equation": "x=y+1"}, - }, - { - name: "variable with empty value", - varFlags: []string{"empty="}, - expected: map[string]string{"empty": ""}, - }, - { - name: "variable with whitespace around key", - varFlags: []string{" name =Alice"}, - expected: map[string]string{"name": "Alice"}, - }, - { - name: "preserve whitespace in value", - varFlags: []string{"message= Hello World "}, - expected: map[string]string{"message": " Hello World "}, - }, - { - name: "empty string flag is ignored", - varFlags: []string{"", "name=Alice"}, - expected: map[string]string{"name": "Alice"}, - expectErr: false, - }, - { - name: "whitespace only flag is ignored", - varFlags: []string{" ", "name=Alice"}, - expected: map[string]string{"name": "Alice"}, - expectErr: false, - }, - { - name: "missing equals sign", - varFlags: []string{"name"}, - expectErr: true, - }, - { - name: "missing equals sign with multiple vars", - varFlags: []string{"name=Alice", "age"}, - expectErr: true, - }, - { - name: "empty key", - varFlags: []string{"=value"}, - expectErr: true, - }, - { - name: "whitespace only key", - varFlags: []string{" =value"}, - expectErr: true, - }, - { - name: "duplicate keys", - varFlags: []string{"name=Alice", "name=Bob"}, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - flags := pflag.NewFlagSet("test", pflag.ContinueOnError) - flags.StringSlice("var", tt.varFlags, "test flag") - - result, err := parseTemplateVariables(flags) - - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expected, result) - } - }) - } -} - func TestGenerateCommandWithTemplateVariables(t *testing.T) { t.Run("parse template variables in command handler", func(t *testing.T) { client := azuremodels.NewMockClient() @@ -515,8 +412,8 @@ func TestGenerateCommandWithTemplateVariables(t *testing.T) { err := cmd.ParseFlags(args[:len(args)-1]) // Exclude positional arg require.NoError(t, err) - // Test that the parseTemplateVariables function works correctly - templateVars, err := parseTemplateVariables(cmd.Flags()) + // Test that the util.ParseTemplateVariables function works correctly + templateVars, err := util.ParseTemplateVariables(cmd.Flags()) require.NoError(t, err) require.Equal(t, map[string]string{ "name": "Bob", diff --git a/cmd/run/run.go b/cmd/run/run.go index 2d90da4f..6a33218f 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -236,7 +236,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } // Parse template variables from flags - templateVars, err := parseTemplateVariables(cmd.Flags()) + templateVars, err := util.ParseTemplateVariables(cmd.Flags()) if err != nil { return err } @@ -427,43 +427,6 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { return cmd } -// parseTemplateVariables parses template variables from the --var flags -func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { - varFlags, err := flags.GetStringSlice("var") - if err != nil { - return nil, err - } - - templateVars := make(map[string]string) - for _, varFlag := range varFlags { - // Handle empty strings - if strings.TrimSpace(varFlag) == "" { - continue - } - - parts := strings.SplitN(varFlag, "=", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) - } - - key := strings.TrimSpace(parts[0]) - value := parts[1] // Don't trim value to preserve intentional whitespace - - if key == "" { - return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) - } - - // Check for duplicate keys - if _, exists := templateVars[key]; exists { - return nil, fmt.Errorf("duplicate variable key '%s'", key) - } - - templateVars[key] = value - } - - return templateVars, nil -} - type runCommandHandler struct { ctx context.Context cfg *command.Config diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 94db2b63..d104e365 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -477,7 +477,7 @@ func TestParseTemplateVariables(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags.StringSlice("var", tt.varFlags, "test flag") - result, err := parseTemplateVariables(flags) + result, err := util.ParseTemplateVariables(flags) if tt.expectErr { require.Error(t, err) diff --git a/pkg/util/util.go b/pkg/util/util.go index 1856f20b..c0005f21 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -4,6 +4,9 @@ package util import ( "fmt" "io" + "strings" + + "github.com/spf13/pflag" ) // WriteToOut writes a message to the given io.Writer. @@ -18,3 +21,40 @@ func WriteToOut(out io.Writer, message string) { func Ptr[T any](value T) *T { return &value } + +// ParseTemplateVariables parses template variables from the --var flags +func ParseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { + varFlags, err := flags.GetStringSlice("var") + if err != nil { + return nil, err + } + + templateVars := make(map[string]string) + for _, varFlag := range varFlags { + // Handle empty strings + if strings.TrimSpace(varFlag) == "" { + continue + } + + parts := strings.SplitN(varFlag, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) + } + + key := strings.TrimSpace(parts[0]) + value := parts[1] // Don't trim value to preserve intentional whitespace + + if key == "" { + return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) + } + + // Check for duplicate keys + if _, exists := templateVars[key]; exists { + return nil, fmt.Errorf("duplicate variable key '%s'", key) + } + + templateVars[key] = value + } + + return templateVars, nil +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go new file mode 100644 index 00000000..cb05decc --- /dev/null +++ b/pkg/util/util_test.go @@ -0,0 +1,111 @@ +package util + +import ( + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/require" +) + +func TestParseTemplateVariables(t *testing.T) { + tests := []struct { + name string + varFlags []string + expected map[string]string + expectErr bool + }{ + { + name: "empty flags", + varFlags: []string{}, + expected: map[string]string{}, + }, + { + name: "single variable", + varFlags: []string{"name=Alice"}, + expected: map[string]string{"name": "Alice"}, + }, + { + name: "multiple variables", + varFlags: []string{"name=Alice", "age=30", "city=Boston"}, + expected: map[string]string{"name": "Alice", "age": "30", "city": "Boston"}, + }, + { + name: "variable with spaces in value", + varFlags: []string{"description=Hello World"}, + expected: map[string]string{"description": "Hello World"}, + }, + { + name: "variable with equals in value", + varFlags: []string{"equation=x=y+1"}, + expected: map[string]string{"equation": "x=y+1"}, + }, + { + name: "variable with empty value", + varFlags: []string{"empty="}, + expected: map[string]string{"empty": ""}, + }, + { + name: "variable with whitespace around key", + varFlags: []string{" name =Alice"}, + expected: map[string]string{"name": "Alice"}, + }, + { + name: "preserve whitespace in value", + varFlags: []string{"message= Hello World "}, + expected: map[string]string{"message": " Hello World "}, + }, + { + name: "empty string flag is ignored", + varFlags: []string{"", "name=Alice"}, + expected: map[string]string{"name": "Alice"}, + expectErr: false, + }, + { + name: "whitespace only flag is ignored", + varFlags: []string{" ", "name=Alice"}, + expected: map[string]string{"name": "Alice"}, + expectErr: false, + }, + { + name: "missing equals sign", + varFlags: []string{"name"}, + expectErr: true, + }, + { + name: "missing equals sign with multiple vars", + varFlags: []string{"name=Alice", "age"}, + expectErr: true, + }, + { + name: "empty key", + varFlags: []string{"=value"}, + expectErr: true, + }, + { + name: "whitespace only key", + varFlags: []string{" =value"}, + expectErr: true, + }, + { + name: "duplicate keys", + varFlags: []string{"name=Alice", "name=Bob"}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.StringSlice("var", tt.varFlags, "test flag") + + result, err := ParseTemplateVariables(flags) + + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} \ No newline at end of file From f3a96f68d11eb272f70661a576d5390e093f7ac9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 16:35:29 +0000 Subject: [PATCH 5/7] Apply gofmt formatting fixes for util_test.go --- pkg/util/util_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index cb05decc..eef7cd88 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -108,4 +108,4 @@ func TestParseTemplateVariables(t *testing.T) { } }) } -} \ No newline at end of file +} From b01b3b7eba78c4a7bae04b1a7a6a22168f53b565 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 21:41:47 +0000 Subject: [PATCH 6/7] Prevent 'input' from being used as template variable key Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- pkg/util/util.go | 5 +++++ pkg/util/util_test.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/pkg/util/util.go b/pkg/util/util.go index c0005f21..575b0747 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -48,6 +48,11 @@ func ParseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) } + // Check for reserved keys + if key == "input" { + return nil, fmt.Errorf("'input' is a reserved variable name and cannot be used with --var") + } + // Check for duplicate keys if _, exists := templateVars[key]; exists { return nil, fmt.Errorf("duplicate variable key '%s'", key) diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index eef7cd88..380512e4 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -91,6 +91,11 @@ func TestParseTemplateVariables(t *testing.T) { varFlags: []string{"name=Alice", "name=Bob"}, expectErr: true, }, + { + name: "reserved input variable", + varFlags: []string{"input=test"}, + expectErr: true, + }, } for _, tt := range tests { From 7a60a818ed381b909a61c7a2a56ad46758de3415 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 30 Jul 2025 22:18:20 +0000 Subject: [PATCH 7/7] Move 'input' variable validation from shared utility to generate command Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- cmd/generate/generate.go | 5 +++++ cmd/generate/generate_test.go | 12 ++++++++++++ cmd/run/run_test.go | 5 +++++ pkg/util/util.go | 5 ----- pkg/util/util_test.go | 5 ----- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/cmd/generate/generate.go b/cmd/generate/generate.go index 260dfe7c..f5864227 100644 --- a/cmd/generate/generate.go +++ b/cmd/generate/generate.go @@ -58,6 +58,11 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { return err } + // Check for reserved keys specific to generate command + if _, exists := templateVars["input"]; exists { + return fmt.Errorf("'input' is a reserved variable name and cannot be used with --var") + } + // Get organization org, _ := cmd.Flags().GetString("org") diff --git a/cmd/generate/generate_test.go b/cmd/generate/generate_test.go index da982505..b0f81d47 100644 --- a/cmd/generate/generate_test.go +++ b/cmd/generate/generate_test.go @@ -507,4 +507,16 @@ messages: require.Contains(t, *userMsg.Content, "about machine learning") require.Contains(t, *userMsg.Content, "academic style") }) + + t.Run("rejects input as template variable", func(t *testing.T) { + client := azuremodels.NewMockClient() + cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100) + + cmd := NewGenerateCommand(cfg) + cmd.SetArgs([]string{"--var", "input=test", "dummy.yml"}) + + err := cmd.Execute() + require.Error(t, err) + require.Contains(t, err.Error(), "'input' is a reserved variable name and cannot be used with --var") + }) } diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index d104e365..7b21a06c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -470,6 +470,11 @@ func TestParseTemplateVariables(t *testing.T) { varFlags: []string{"name=John", "name=Jane"}, expectErr: true, }, + { + name: "input variable is allowed in run command", + varFlags: []string{"input=test value"}, + expected: map[string]string{"input": "test value"}, + }, } for _, tt := range tests { diff --git a/pkg/util/util.go b/pkg/util/util.go index 575b0747..c0005f21 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -48,11 +48,6 @@ func ParseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) } - // Check for reserved keys - if key == "input" { - return nil, fmt.Errorf("'input' is a reserved variable name and cannot be used with --var") - } - // Check for duplicate keys if _, exists := templateVars[key]; exists { return nil, fmt.Errorf("duplicate variable key '%s'", key) diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 380512e4..eef7cd88 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -91,11 +91,6 @@ func TestParseTemplateVariables(t *testing.T) { varFlags: []string{"name=Alice", "name=Bob"}, expectErr: true, }, - { - name: "reserved input variable", - varFlags: []string{"input=test"}, - expectErr: true, - }, } for _, tt := range tests {