diff --git a/cmd/run/run.go b/cmd/run/run.go index d0f58991..fe2cf2e2 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -417,7 +417,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } cmd.Flags().String("file", "", "Path to a .prompt.yml file.") - cmd.Flags().StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)") + cmd.Flags().StringArray("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)") cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.") cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.") cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.") @@ -429,7 +429,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { // parseTemplateVariables parses template variables from the --var flags func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { - varFlags, err := flags.GetStringSlice("var") + varFlags, err := flags.GetStringArray("var") if err != nil { return nil, err } diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 94db2b63..f4b4233e 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -450,6 +450,11 @@ func TestParseTemplateVariables(t *testing.T) { varFlags: []string{"equation=x = y + 2"}, expected: map[string]string{"equation": "x = y + 2"}, }, + { + name: "value with commas", + varFlags: []string{"city=paris, milan", "countries=france, italy, spain"}, + expected: map[string]string{"city": "paris, milan", "countries": "france, italy, spain"}, + }, { name: "empty strings are skipped", varFlags: []string{"", "name=John", " "}, @@ -475,7 +480,7 @@ func TestParseTemplateVariables(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) - flags.StringSlice("var", tt.varFlags, "test flag") + flags.StringArray("var", tt.varFlags, "test flag") result, err := parseTemplateVariables(flags)