From 62d2e3e45f7aa37066ab4d8d5ec46dca87eb7576 Mon Sep 17 00:00:00 2001 From: Nikita Aleksandrov Date: Thu, 30 Apr 2026 02:45:55 +0300 Subject: [PATCH 1/8] fix(cli): run command args parsing corner cases --- cmd/stroppy/commands/run/run.go | 200 +++++++++++++++++++++------ cmd/stroppy/commands/run/run_test.go | 194 ++++++++++++++++++++++++++ internal/runner/driver_preset.go | 126 +++++++++++++++-- 3 files changed, 466 insertions(+), 54 deletions(-) diff --git a/cmd/stroppy/commands/run/run.go b/cmd/stroppy/commands/run/run.go index fbd7e6a0..79673304 100644 --- a/cmd/stroppy/commands/run/run.go +++ b/cmd/stroppy/commands/run/run.go @@ -16,13 +16,23 @@ import ( "github.com/stroppy-io/stroppy/pkg/common/logger" ) -const consumedPairFlag = 2 // number of tokens consumed for a two-token flag (e.g. "-d pg") +const ( + consumedPairFlag = 2 // number of tokens consumed for a two-token flag (e.g. "-d pg") + flagSteps = "--steps" + flagNoSteps = "--no-steps" +) var ( - errNoScript = errors.New("script argument is required") - errFlagRequiresValue = errors.New("flag requires a value") - errStepsMutExclusive = errors.New("--steps and --no-steps are mutually exclusive") - errBadKeyValue = errors.New("expected key=value format") + errNoScript = errors.New("script argument is required") + errFlagRequiresValue = errors.New("flag requires a value") + errStepsMutExclusive = errors.New("--steps and --no-steps are mutually exclusive") + errBadKeyValue = errors.New("expected key=value format") + errUnknownRunFlag = errors.New("unknown run flag") + errPositionalAfterOpt = errors.New("unexpected positional argument after options") + errKeyValuePositional = errors.New("unexpected key=value positional argument") + errTooManyPositionals = errors.New( + "too many positional arguments; expected script and optional sql_file before --", + ) ) var Cmd = &cobra.Command{ @@ -39,6 +49,7 @@ var Cmd = &cobra.Command{ Files are searched in: current directory → ~/.stroppy/ → built-in workloads. SQL files are auto-derived from the preset/script name unless specified explicitly. See 'stroppy help resolution' for details on how files are found. +The script and optional sql_file positionals must be adjacent before --. Environment flags: -e, --env KEY=VALUE Set env var for the script (lowercase auto-uppercased) @@ -143,7 +154,9 @@ Config file flags: for idx, opts := range parsed.driverOpts { for _, kv := range opts { - applyDriverOpt(driverConfigs, idx, kv[0], kv[1]) + if err := applyDriverOpt(driverConfigs, idx, kv[0], kv[1]); err != nil { + return err + } } } @@ -198,6 +211,14 @@ type runArgs struct { // Returns the number of tokens consumed, or 0 if the arg is not this flag. type flagParser func(args []string, i int, parsed *runArgs) (int, error) +type positionalState int + +const ( + beforePositionals positionalState = iota + inPositionals + afterPositionals +) + // parseRunArgs parses the raw CLI args (after cobra hands them to RunE) and // returns the structured result without performing any file or preset resolution. func parseRunArgs(args []string) (runArgs, error) { @@ -219,30 +240,68 @@ func parseRunArgs(args []string) (runArgs, error) { parseDriverFlags, } + if err := parseRunArgsBeforeDash(positional, parsers, &parsed); err != nil { + return runArgs{}, err + } + + if len(parsed.steps) > 0 && len(parsed.noSteps) > 0 { + return runArgs{}, errStepsMutExclusive + } + + return parsed, nil +} + +func parseRunArgsBeforeDash(positional []string, parsers []flagParser, parsed *runArgs) error { + state := beforePositionals + for i := 0; i < len(positional); i++ { - consumed, err := dispatchFlag(parsers, positional, i, &parsed) + consumed, err := dispatchFlag(parsers, positional, i, parsed) if err != nil { - return runArgs{}, err + return err } if consumed > 0 { + if state == inPositionals { + state = afterPositionals + } + i += consumed - 1 continue } - if parsed.scriptArg == "" { - parsed.scriptArg = positional[i] - } else { - parsed.sqlArg = positional[i] + state, err = applyPositionalArg(positional[i], state, parsed) + if err != nil { + return err } } - if len(parsed.steps) > 0 && len(parsed.noSteps) > 0 { - return runArgs{}, errStepsMutExclusive + return nil +} + +func applyPositionalArg(arg string, state positionalState, parsed *runArgs) (positionalState, error) { + if strings.HasPrefix(arg, "-") && arg != "-" { + return state, fmt.Errorf("%w %q; pass k6 flags after --", errUnknownRunFlag, arg) } - return parsed, nil + if state == afterPositionals { + return state, positionalAfterOptionsError(arg) + } + + if isKeyValuePositional(arg) { + return state, keyValuePositionalError(arg) + } + + switch { + case parsed.scriptArg == "": + parsed.scriptArg = arg + case parsed.sqlArg == "": + parsed.sqlArg = arg + default: + return state, fmt.Errorf("%w: %q", errTooManyPositionals, arg) + } + + return inPositionals, nil } // dispatchFlag tries each parser in order until one consumes the arg at @@ -268,13 +327,14 @@ func parseStepsFlag(args []string, i int, parsed *runArgs) (int, error) { arg := args[i] switch { - case arg == "--steps" || arg == "--no-steps": - if i+1 >= len(args) { - return 0, fmt.Errorf("%s: %w", arg, errFlagRequiresValue) + case arg == flagSteps || arg == flagNoSteps: + value, err := nextFlagValue(args, i) + if err != nil { + return 0, err } - vals := strings.Split(args[i+1], ",") - if arg == "--steps" { + vals := strings.Split(value, ",") + if arg == flagSteps { parsed.steps = append(parsed.steps, vals...) } else { parsed.noSteps = append(parsed.noSteps, vals...) @@ -282,13 +342,13 @@ func parseStepsFlag(args []string, i int, parsed *runArgs) (int, error) { return consumedPairFlag, nil - case strings.HasPrefix(arg, "--steps="): - parsed.steps = append(parsed.steps, strings.Split(strings.TrimPrefix(arg, "--steps="), ",")...) + case strings.HasPrefix(arg, flagSteps+"="): + parsed.steps = append(parsed.steps, strings.Split(strings.TrimPrefix(arg, flagSteps+"="), ",")...) return 1, nil - case strings.HasPrefix(arg, "--no-steps="): - parsed.noSteps = append(parsed.noSteps, strings.Split(strings.TrimPrefix(arg, "--no-steps="), ",")...) + case strings.HasPrefix(arg, flagNoSteps+"="): + parsed.noSteps = append(parsed.noSteps, strings.Split(strings.TrimPrefix(arg, flagNoSteps+"="), ",")...) return 1, nil } @@ -303,11 +363,12 @@ func parseFileFlag(args []string, i int, parsed *runArgs) (int, error) { switch { case arg == "-f" || arg == "--file": - if i+1 >= len(args) { - return 0, fmt.Errorf("%s: %w", arg, errFlagRequiresValue) + value, err := nextFlagValue(args, i) + if err != nil { + return 0, err } - parsed.fileArg = args[i+1] + parsed.fileArg = value return consumedPairFlag, nil @@ -332,11 +393,12 @@ func parseEnvFlag(args []string, i int, parsed *runArgs) (int, error) { switch { case arg == "-e" || arg == "--env": - if i+1 >= len(args) { - return 0, fmt.Errorf("%s: %w", arg, errFlagRequiresValue) + value, err := nextFlagValue(args, i) + if err != nil { + return 0, err } - parsed.envArgs = append(parsed.envArgs, args[i+1]) + parsed.envArgs = append(parsed.envArgs, value) return consumedPairFlag, nil @@ -397,11 +459,12 @@ func parseFlagNextArg( for _, prefix := range []string{shortPrefix, longPrefix} { if idx, ok := parseShortFlag(arg, prefix); ok { - if i+1 >= len(args) { - return 0, "", 0, fmt.Errorf("%s: %w", arg, errFlagRequiresValue) + next, err := nextFlagValue(args, i) + if err != nil { + return 0, "", 0, err } - return idx, args[i+1], consumedPairFlag, nil + return idx, next, consumedPairFlag, nil } } @@ -472,13 +535,14 @@ func parseDriverOptFlag(args []string, i int) (driverIndex int, key, value strin // -D / -D0 / -D1 (short form, two tokens) if idx, ok := parseShortFlag(arg, "-D"); ok { - if i+1 >= len(args) { - return 0, "", "", 0, fmt.Errorf("%s: %w", arg, errFlagRequiresValue) + raw, err := nextFlagValue(args, i) + if err != nil { + return 0, "", "", 0, err } - key, value, err = splitKeyValue(args[i+1]) + key, value, err = splitKeyValue(raw) if err != nil { - return 0, "", "", 0, fmt.Errorf("%s %s: %w", arg, args[i+1], err) + return 0, "", "", 0, fmt.Errorf("%s %s: %w", arg, raw, err) } return idx, key, value, consumedPairFlag, nil @@ -486,13 +550,14 @@ func parseDriverOptFlag(args []string, i int) (driverIndex int, key, value strin // --driver-opt / --driver1-opt / --driver0-opt (long form, two tokens) if idx, ok := parseIndexedInfixFlag(arg, "--driver", "-opt"); ok { - if i+1 >= len(args) { - return 0, "", "", 0, fmt.Errorf("%s: %w", arg, errFlagRequiresValue) + raw, err := nextFlagValue(args, i) + if err != nil { + return 0, "", "", 0, err } - key, value, err = splitKeyValue(args[i+1]) + key, value, err = splitKeyValue(raw) if err != nil { - return 0, "", "", 0, fmt.Errorf("%s %s: %w", arg, args[i+1], err) + return 0, "", "", 0, fmt.Errorf("%s %s: %w", arg, raw, err) } return idx, key, value, consumedPairFlag, nil @@ -619,6 +684,55 @@ func splitKeyValue(s string) (key, val string, err error) { return key, val, nil } +func positionalAfterOptionsError(arg string) error { + message := "script and sql_file must be adjacent before --" + if strings.Contains(arg, "=") { + message += "; quote driver/env values that contain spaces" + } + + return fmt.Errorf("%w: %q; %s", errPositionalAfterOpt, arg, message) +} + +func keyValuePositionalError(arg string) error { + return fmt.Errorf( + "%w: %q; key=value arguments must follow -D/--driver-opt or -e/--env; quote values that contain spaces", + errKeyValuePositional, + arg, + ) +} + +func isKeyValuePositional(arg string) bool { + key, _, ok := strings.Cut(arg, "=") + if !ok || key == "" || strings.ContainsAny(arg, " \t\n") { + return false + } + + for _, r := range key { + if r == '_' || r == '-' || r == '.' || r >= '0' && r <= '9' || + r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' { + continue + } + + return false + } + + return true +} + +func nextFlagValue(args []string, i int) (string, error) { + flag := args[i] + if i+1 >= len(args) { + return "", fmt.Errorf("%s: %w", flag, errFlagRequiresValue) + } + + next := args[i+1] + if strings.HasPrefix(next, "-") && next != "-" { + return "", fmt.Errorf("%s: %w", flag, errFlagRequiresValue) + } + + return next, nil +} + // applyDriverPreset loads a preset or parses raw JSON and sets it on the config map. // If the value starts with '{', it's treated as a JSON driver config; otherwise as a preset name. func applyDriverPreset(configs runner.DriverCLIConfigs, idx int, value string) error { @@ -645,12 +759,12 @@ func applyDriverPreset(configs runner.DriverCLIConfigs, idx int, value string) e } // applyDriverOpt applies a -D key=value override to the driver at the given index. -func applyDriverOpt(configs runner.DriverCLIConfigs, idx int, key, value string) { +func applyDriverOpt(configs runner.DriverCLIConfigs, idx int, key, value string) error { cfg, ok := configs[idx] if !ok { cfg = &runner.DriverCLIConfig{} configs[idx] = cfg } - cfg.ApplyOverride(key, value) + return cfg.ApplyOverride(key, value) } diff --git a/cmd/stroppy/commands/run/run_test.go b/cmd/stroppy/commands/run/run_test.go index 40c9b80c..fbdd06ad 100644 --- a/cmd/stroppy/commands/run/run_test.go +++ b/cmd/stroppy/commands/run/run_test.go @@ -1,6 +1,7 @@ package run import ( + "encoding/json" "errors" "os" "testing" @@ -50,6 +51,22 @@ func TestParseRunArgs(t *testing.T) { wantScript: "./benchmarks/custom.ts", wantSQL: "data.sql", }, + { + name: "third positional returns error", + args: []string{"tpcc", "pg.sql", "extra.sql"}, + wantErrStr: "too many positional arguments", + }, + { + name: "unknown flag before separator returns error", + args: []string{"tpcc", "--vus", "10"}, + wantErrStr: "pass k6 flags after --", + }, + { + name: "inline SQL query with spaces and equals is single positional", + args: []string{"select a=1", "-d", "pg"}, + wantScript: "select a=1", + wantPresets: map[int]string{0: "pg"}, + }, // ── Missing script ───────────────────────────────────────────────── { @@ -82,6 +99,23 @@ func TestParseRunArgs(t *testing.T) { args: []string{"-f", "myconfig.json"}, wantFile: "myconfig.json", }, + { + name: "-f followed by driver flag returns missing value", + args: []string{"-f", "-d", "pg"}, + wantErrStr: "-f: flag requires a value", + }, + + // ── -e / --env ───────────────────────────────────────────────────── + { + name: "-e accepts values starting with dash after equals", + args: []string{"tpcc", "-e", "TOKEN=-abc"}, + wantScript: "tpcc", + }, + { + name: "-e followed by steps flag returns missing value", + args: []string{"tpcc", "-e", "--steps", "load"}, + wantErrStr: "-e: flag requires a value", + }, // ── --steps / --no-steps ─────────────────────────────────────────── { @@ -123,6 +157,16 @@ func TestParseRunArgs(t *testing.T) { args: []string{"tpcc", "--no-steps"}, wantErrStr: "flag requires a value", }, + { + name: "--steps followed by known flag returns missing value", + args: []string{"tpcc", "--steps", "-d", "pg"}, + wantErrStr: "--steps: flag requires a value", + }, + { + name: "--steps followed by unknown flag returns missing value", + args: []string{"tpcc", "--steps", "--vus", "10"}, + wantErrStr: "--steps: flag requires a value", + }, // ── Driver preset flags ──────────────────────────────────────────── { @@ -183,6 +227,16 @@ func TestParseRunArgs(t *testing.T) { args: []string{"tpcc", "--driver"}, wantErrStr: "flag requires a value", }, + { + name: "-d followed by driver option flag returns missing value", + args: []string{"tpcc", "-d", "-D", "url=postgres://prod"}, + wantErrStr: "-d: flag requires a value", + }, + { + name: "--driver followed by steps flag returns missing value", + args: []string{"tpcc", "--driver", "--steps", "load"}, + wantErrStr: "--driver: flag requires a value", + }, { name: "two drivers -d and -d1", args: []string{"tpcc", "-d", "pg", "-d1", "mysql"}, @@ -197,6 +251,16 @@ func TestParseRunArgs(t *testing.T) { wantScript: "tpcc", wantOpts: map[int][][2]string{0: {{"url", "postgres://prod:5432"}}}, }, + { + name: "unquoted driver value fragment returns quote hint", + args: []string{"tpcc", "-D", "url=host=localhost", "user=postgres"}, + wantErrStr: "quote driver/env values", + }, + { + name: "unquoted driver value fragment before script returns key value hint", + args: []string{"-D", "url=host=localhost", "user=postgres", "tpcc"}, + wantErrStr: "key=value arguments must follow", + }, { name: "-D1 key=value", args: []string{"tpcc", "-D1", "url=mysql://prod:3306"}, @@ -257,6 +321,11 @@ func TestParseRunArgs(t *testing.T) { args: []string{"tpcc", "--driver-opt"}, wantErrStr: "flag requires a value", }, + { + name: "--driver-opt followed by steps flag returns missing value", + args: []string{"tpcc", "--driver-opt", "--steps", "load"}, + wantErrStr: "--driver-opt: flag requires a value", + }, { name: "-D value without = returns error", args: []string{"tpcc", "-D", "noequals"}, @@ -298,6 +367,19 @@ func TestParseRunArgs(t *testing.T) { wantSteps: []string{"load", "run"}, wantAfterDash: []string{"--duration", "5m"}, }, + { + name: "flags may wrap adjacent script sql block", + args: []string{"-f", "prod.json", "tpcc", "tpcc/pico", "-d", "pico"}, + wantScript: "tpcc", + wantSQL: "tpcc/pico", + wantFile: "prod.json", + wantPresets: map[int]string{0: "pico"}, + }, + { + name: "positional after option following script returns adjacency error", + args: []string{"tpcc", "-d", "pg", "tpcc/pico"}, + wantErrStr: "script and sql_file must be adjacent", + }, { name: "script + sql + two drivers + driver opt", args: []string{"tpcc", "tpcc-scale-100", "-d", "pg", "-d1", "mysql", "-D1", "url=mysql://prod"}, @@ -496,6 +578,86 @@ func TestApplyDriverPresetInvalidJSON(t *testing.T) { } } +func TestApplyDriverOptDottedPool(t *testing.T) { + t.Parallel() + + configs := runner.DriverCLIConfigs{} + + if err := applyDriverOpt(configs, 0, "pool.maxConns", "20"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err := applyDriverOpt(configs, 0, "pool.maxConnLifetime", "30m"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got := marshalDriverConfig(t, configs[0]) + pool := objectField(t, got, "pool") + + if pool["maxConns"] != float64(20) { + t.Errorf("pool.maxConns: got %v, want 20", pool["maxConns"]) + } + + if pool["maxConnLifetime"] != "30m" { + t.Errorf("pool.maxConnLifetime: got %v, want 30m", pool["maxConnLifetime"]) + } +} + +func TestApplyDriverOptDottedPoolMergesJSONPreset(t *testing.T) { + t.Parallel() + + configs := runner.DriverCLIConfigs{} + + if err := applyDriverPreset(configs, 0, `{"driverType":"postgres","pool":{"minConns":5}}`); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err := applyDriverOpt(configs, 0, "pool.maxConns", "20"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got := marshalDriverConfig(t, configs[0]) + pool := objectField(t, got, "pool") + + if pool["minConns"] != float64(5) { + t.Errorf("pool.minConns: got %v, want 5", pool["minConns"]) + } + + if pool["maxConns"] != float64(20) { + t.Errorf("pool.maxConns: got %v, want 20", pool["maxConns"]) + } +} + +func TestApplyDriverOptDottedPoolUnknownField(t *testing.T) { + t.Parallel() + + configs := runner.DriverCLIConfigs{} + + err := applyDriverOpt(configs, 0, "pool.maximum", "20") + if err == nil { + t.Fatal("expected error for unknown pool field") + } + + if !contains(err.Error(), "unknown pool option") { + t.Fatalf("got error %q, want it to contain unknown pool option", err.Error()) + } +} + +func TestApplyDriverOptTLSAliases(t *testing.T) { + t.Parallel() + + configs := runner.DriverCLIConfigs{} + + if err := applyDriverOpt(configs, 0, "tls.insecureSkipVerify", "true"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got := marshalDriverConfig(t, configs[0]) + if got["tlsInsecureSkipVerify"] != true { + t.Errorf("tlsInsecureSkipVerify: got %v, want true", got["tlsInsecureSkipVerify"]) + } +} + func TestToEnvVarsRespectsExistingEnv(t *testing.T) { t.Setenv("STROPPY_DRIVER_0", `{"url":"from-env"}`) @@ -533,6 +695,38 @@ func TestToEnvVarsSetsWhenNotInEnv(t *testing.T) { } } +func marshalDriverConfig(t *testing.T, cfg *runner.DriverCLIConfig) map[string]any { + t.Helper() + + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal driver config: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal driver config: %v", err) + } + + return got +} + +func objectField(t *testing.T, m map[string]any, key string) map[string]any { + t.Helper() + + raw, ok := m[key] + if !ok { + t.Fatalf("missing object field %q in %#v", key, m) + } + + obj, ok := raw.(map[string]any) + if !ok { + t.Fatalf("field %q has type %T, want object", key, raw) + } + + return obj +} + func driverOptMapsEqual(a, b map[int][][2]string) bool { if len(a) == 0 && len(b) == 0 { return true diff --git a/internal/runner/driver_preset.go b/internal/runner/driver_preset.go index 7f773477..5c1053a1 100644 --- a/internal/runner/driver_preset.go +++ b/internal/runner/driver_preset.go @@ -24,7 +24,11 @@ var pathFields = map[string]bool{ "cacertfile": true, } -var errUnknownDriver = errors.New("unknown driver") +var ( + errUnknownDriver = errors.New("unknown driver") + errInvalidDriverOverride = errors.New("invalid driver override") + errDriverOverrideConflict = errors.New("driver override conflicts with existing non-object value") +) // inferType converts a CLI string value to its most specific Go type // so that JSON serialization emits a number/bool instead of a quoted string. @@ -151,7 +155,15 @@ func (d DriverCLIConfig) MarshalJSON() ([]byte, error) { // ApplyOverride sets a field by key=value. Known fields are set on the struct, // unknown fields go into Extra for pass-through to TS. -func (d *DriverCLIConfig) ApplyOverride(key, value string) { +func (d *DriverCLIConfig) ApplyOverride(key, value string) error { + if key == "" { + return fmt.Errorf("%w: empty key", errInvalidDriverOverride) + } + + if parent, child, ok := strings.Cut(key, "."); ok { + return d.applyDottedOverride(parent, child, value) + } + switch strings.ToLower(key) { case "drivertype": d.DriverType = value @@ -160,18 +172,110 @@ func (d *DriverCLIConfig) ApplyOverride(key, value string) { case "defaultinsertmethod": d.DefaultInsertMethod = value default: - if d.Extra == nil { - d.Extra = make(map[string]any) - } + d.setExtra(key, convertOverrideValue(key, value)) + } - if pathFields[strings.ToLower(key)] { - if abs, err := filepath.Abs(value); err == nil { - value = abs - } + return nil +} + +func (d *DriverCLIConfig) applyDottedOverride(parent, child, value string) error { + if child == "" { + return fmt.Errorf("%w: empty nested field in %q", errInvalidDriverOverride, parent) + } + + switch strings.ToLower(parent) { + case "pool": + field, err := canonicalPoolField(child) + if err != nil { + return err } - d.Extra[key] = inferType(value) + return d.setNestedExtra("pool", field, inferType(value)) + case "postgres", "sql": + return d.setNestedExtra(strings.ToLower(parent), child, inferType(value)) + case "tls": + return d.applyTLSOverride(child, value) + default: + // Preserve unknown dotted keys for forward-compatible top-level TS fields. + d.setExtra(parent+"."+child, inferType(value)) + + return nil + } +} + +func canonicalPoolField(field string) (string, error) { + switch normalizeKey(field) { + case "maxconns": + return "maxConns", nil + case "minconns": + return "minConns", nil + case "maxconnlifetime": + return "maxConnLifetime", nil + case "maxconnidletime": + return "maxConnIdleTime", nil + default: + return "", fmt.Errorf("%w: unknown pool option %q", errInvalidDriverOverride, field) + } +} + +func (d *DriverCLIConfig) applyTLSOverride(field, value string) error { + switch normalizeKey(field) { + case "cacert", "cacertfile": + d.setExtra("caCertFile", convertOverrideValue("caCertFile", value)) + case "insecureskipverify", "tlsinsecureskipverify": + d.setExtra("tlsInsecureSkipVerify", inferType(value)) + default: + return fmt.Errorf("%w: unknown tls option %q", errInvalidDriverOverride, field) + } + + return nil +} + +func (d *DriverCLIConfig) setNestedExtra(parent, child string, value any) error { + if d.Extra == nil { + d.Extra = make(map[string]any) + } + + existing, ok := d.Extra[parent] + if !ok { + nested := map[string]any{child: value} + d.Extra[parent] = nested + + return nil + } + + nested, ok := existing.(map[string]any) + if !ok { + return fmt.Errorf("%w: %q", errDriverOverrideConflict, parent) + } + + nested[child] = value + + return nil +} + +func (d *DriverCLIConfig) setExtra(key string, value any) { + if d.Extra == nil { + d.Extra = make(map[string]any) + } + + d.Extra[key] = value +} + +func convertOverrideValue(key, value string) any { + if pathFields[normalizeKey(key)] { + if abs, err := filepath.Abs(value); err == nil { + return abs + } } + + return inferType(value) +} + +func normalizeKey(key string) string { + replacer := strings.NewReplacer("_", "", "-", "") + + return strings.ToLower(replacer.Replace(key)) } // NewDriverCLIConfigFromPreset creates a DriverCLIConfig from a preset. @@ -208,7 +312,7 @@ func NewDriverCLIConfigFromJSON(raw string) (DriverCLIConfig, error) { cfg.Extra = make(map[string]any) } - if pathFields[strings.ToLower(field)] { + if pathFields[normalizeKey(field)] { if s, ok := val.(string); ok { if abs, err := filepath.Abs(s); err == nil { val = abs From f3d3e2dc757e4d5f00b92eaed1e50f7412b3ba59 Mon Sep 17 00:00:00 2001 From: Nikita Aleksandrov Date: Thu, 30 Apr 2026 02:46:47 +0300 Subject: [PATCH 2/8] refactor(cli,ts): driver options validation in ts --- cmd/stroppy/commands/run/run_test.go | 42 ++++-- internal/runner/driver_preset.go | 129 ++++++++--------- internal/static/helpers.ts | 139 ++++++++++++++++--- internal/static/tests/helpers.test.ts | 55 ++++++++ internal/static/tests/stubs/k6_execution.ts | 8 ++ internal/static/tests/stubs/k6_metrics.ts | 14 ++ internal/static/tests/stubs/k6_x_encoding.ts | 4 + internal/static/tests/stubs/k6_x_stroppy.ts | 40 ++++++ internal/static/tests/tsconfig.json | 9 +- internal/static/tests/vitest.config.ts | 5 +- 10 files changed, 339 insertions(+), 106 deletions(-) create mode 100644 internal/static/tests/helpers.test.ts create mode 100644 internal/static/tests/stubs/k6_execution.ts create mode 100644 internal/static/tests/stubs/k6_metrics.ts create mode 100644 internal/static/tests/stubs/k6_x_encoding.ts create mode 100644 internal/static/tests/stubs/k6_x_stroppy.ts diff --git a/cmd/stroppy/commands/run/run_test.go b/cmd/stroppy/commands/run/run_test.go index fbdd06ad..794ec511 100644 --- a/cmd/stroppy/commands/run/run_test.go +++ b/cmd/stroppy/commands/run/run_test.go @@ -633,28 +633,52 @@ func TestApplyDriverOptDottedPoolUnknownField(t *testing.T) { configs := runner.DriverCLIConfigs{} - err := applyDriverOpt(configs, 0, "pool.maximum", "20") - if err == nil { - t.Fatal("expected error for unknown pool field") + if err := applyDriverOpt(configs, 0, "pool.maximum", "20"); err != nil { + t.Fatalf("unexpected error: %v", err) } - if !contains(err.Error(), "unknown pool option") { - t.Fatalf("got error %q, want it to contain unknown pool option", err.Error()) + got := marshalDriverConfig(t, configs[0]) + pool := objectField(t, got, "pool") + + if pool["maximum"] != float64(20) { + t.Errorf("pool.maximum: got %v, want 20", pool["maximum"]) } } -func TestApplyDriverOptTLSAliases(t *testing.T) { +func TestApplyDriverOptDottedPathIsGeneric(t *testing.T) { t.Parallel() configs := runner.DriverCLIConfigs{} - if err := applyDriverOpt(configs, 0, "tls.insecureSkipVerify", "true"); err != nil { + if err := applyDriverOpt(configs, 0, "custom.deep.value", "1"); err != nil { t.Fatalf("unexpected error: %v", err) } got := marshalDriverConfig(t, configs[0]) - if got["tlsInsecureSkipVerify"] != true { - t.Errorf("tlsInsecureSkipVerify: got %v, want true", got["tlsInsecureSkipVerify"]) + custom := objectField(t, got, "custom") + deep := objectField(t, custom, "deep") + + if deep["value"] != float64(1) { + t.Errorf("custom.deep.value: got %v, want 1", deep["value"]) + } +} + +func TestApplyDriverOptDottedPathConflict(t *testing.T) { + t.Parallel() + + configs := runner.DriverCLIConfigs{} + + if err := applyDriverOpt(configs, 0, "pool", "not-object"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + err := applyDriverOpt(configs, 0, "pool.maxConns", "20") + if err == nil { + t.Fatal("expected structural conflict error") + } + + if !contains(err.Error(), "conflicts") { + t.Fatalf("got error %q, want conflict", err.Error()) } } diff --git a/internal/runner/driver_preset.go b/internal/runner/driver_preset.go index 5c1053a1..6edaaa6b 100644 --- a/internal/runner/driver_preset.go +++ b/internal/runner/driver_preset.go @@ -30,6 +30,12 @@ var ( errDriverOverrideConflict = errors.New("driver override conflicts with existing non-object value") ) +const ( + driverTypeKey = "drivertype" + urlKey = "url" + defaultInsertMethodKey = "defaultinsertmethod" +) + // inferType converts a CLI string value to its most specific Go type // so that JSON serialization emits a number/bool instead of a quoted string. // This is required because protobuf (TS side) rejects "20" for int32 fields. @@ -160,106 +166,81 @@ func (d *DriverCLIConfig) ApplyOverride(key, value string) error { return fmt.Errorf("%w: empty key", errInvalidDriverOverride) } - if parent, child, ok := strings.Cut(key, "."); ok { - return d.applyDottedOverride(parent, child, value) - } - - switch strings.ToLower(key) { - case "drivertype": + switch normalizeKey(key) { + case driverTypeKey: d.DriverType = value - case "url": + case urlKey: d.URL = value - case "defaultinsertmethod": + case defaultInsertMethodKey: d.DefaultInsertMethod = value default: - d.setExtra(key, convertOverrideValue(key, value)) + return d.setExtraPath(strings.Split(key, "."), convertOverrideValue(key, value)) } return nil } -func (d *DriverCLIConfig) applyDottedOverride(parent, child, value string) error { - if child == "" { - return fmt.Errorf("%w: empty nested field in %q", errInvalidDriverOverride, parent) +func (d *DriverCLIConfig) setExtraPath(path []string, value any) error { + if err := validateOverridePath(path); err != nil { + return err } - switch strings.ToLower(parent) { - case "pool": - field, err := canonicalPoolField(child) - if err != nil { - return err + if d.Extra == nil { + d.Extra = make(map[string]any) + } + + target := d.Extra + for _, part := range path[:len(path)-1] { + next, ok := target[part] + if !ok { + nested := make(map[string]any) + target[part] = nested + target = nested + + continue } - return d.setNestedExtra("pool", field, inferType(value)) - case "postgres", "sql": - return d.setNestedExtra(strings.ToLower(parent), child, inferType(value)) - case "tls": - return d.applyTLSOverride(child, value) - default: - // Preserve unknown dotted keys for forward-compatible top-level TS fields. - d.setExtra(parent+"."+child, inferType(value)) + nested, ok := next.(map[string]any) + if !ok { + return fmt.Errorf("%w: %q", errDriverOverrideConflict, part) + } - return nil + target = nested } -} -func canonicalPoolField(field string) (string, error) { - switch normalizeKey(field) { - case "maxconns": - return "maxConns", nil - case "minconns": - return "minConns", nil - case "maxconnlifetime": - return "maxConnLifetime", nil - case "maxconnidletime": - return "maxConnIdleTime", nil - default: - return "", fmt.Errorf("%w: unknown pool option %q", errInvalidDriverOverride, field) + last := path[len(path)-1] + if existing, exists := target[last]; exists { + if _, isObject := existing.(map[string]any); isObject { + return fmt.Errorf("%w: %q", errDriverOverrideConflict, last) + } } -} -func (d *DriverCLIConfig) applyTLSOverride(field, value string) error { - switch normalizeKey(field) { - case "cacert", "cacertfile": - d.setExtra("caCertFile", convertOverrideValue("caCertFile", value)) - case "insecureskipverify", "tlsinsecureskipverify": - d.setExtra("tlsInsecureSkipVerify", inferType(value)) - default: - return fmt.Errorf("%w: unknown tls option %q", errInvalidDriverOverride, field) - } + target[last] = value return nil } -func (d *DriverCLIConfig) setNestedExtra(parent, child string, value any) error { - if d.Extra == nil { - d.Extra = make(map[string]any) - } - - existing, ok := d.Extra[parent] - if !ok { - nested := map[string]any{child: value} - d.Extra[parent] = nested - - return nil +func validateOverridePath(path []string) error { + for _, part := range path { + if part == "" { + return fmt.Errorf("%w: empty dotted path segment", errInvalidDriverOverride) + } } - nested, ok := existing.(map[string]any) - if !ok { - return fmt.Errorf("%w: %q", errDriverOverrideConflict, parent) + if len(path) > 1 && isDriverCLIField(path[0]) { + return fmt.Errorf("%w: %q", errDriverOverrideConflict, path[0]) } - nested[child] = value - return nil } -func (d *DriverCLIConfig) setExtra(key string, value any) { - if d.Extra == nil { - d.Extra = make(map[string]any) +func isDriverCLIField(key string) bool { + switch normalizeKey(key) { + case driverTypeKey, urlKey, defaultInsertMethodKey: + return true + default: + return false } - - d.Extra[key] = value } func convertOverrideValue(key, value string) any { @@ -300,12 +281,12 @@ func NewDriverCLIConfigFromJSON(raw string) (DriverCLIConfig, error) { for field, val := range m { str, _ := val.(string) - switch strings.ToLower(field) { - case "drivertype": + switch normalizeKey(field) { + case driverTypeKey: cfg.DriverType = str - case "url": + case urlKey: cfg.URL = str - case "defaultinsertmethod": + case defaultInsertMethodKey: cfg.DefaultInsertMethod = str default: if cfg.Extra == nil { diff --git a/internal/static/helpers.ts b/internal/static/helpers.ts index 503711ee..1a056a34 100644 --- a/internal/static/helpers.ts +++ b/internal/static/helpers.ts @@ -333,6 +333,116 @@ export type DriverSetup = Omit, "errorMode" | "driverType" sql?: Partial; } +type ScalarKind = "string" | "number" | "boolean"; +type FieldRule = ScalarKind | { enum: Set } | { object: Schema }; +type Schema = Record; + +const poolSchema: Schema = { + maxConns: "number", + minConns: "number", + maxConnLifetime: "string", + maxConnIdleTime: "string", +}; + +const postgresSchema: Schema = { + traceLogLevel: "string", + maxConnLifetime: "string", + maxConnIdleTime: "string", + maxConns: "number", + minConns: "number", + minIdleConns: "number", + defaultQueryExecMode: "string", + descriptionCacheCapacity: "number", + statementCacheCapacity: "number", +}; + +const sqlSchema: Schema = { + maxOpenConns: "number", + maxIdleConns: "number", + connMaxLifetime: "string", + connMaxIdleTime: "string", +}; + +const driverSetupSchema: Schema = { + url: "string", + driverType: { enum: new Set(Object.keys(driverTypeMap)) }, + defaultTxIsolation: { enum: new Set(Object.keys(txIsolationMap)) }, + defaultInsertMethod: { enum: new Set(Object.keys(insertMethodMap)) }, + errorMode: { enum: new Set(Object.keys(errorModeMap)) }, + pool: { object: poolSchema }, + postgres: { object: postgresSchema }, + sql: { object: sqlSchema }, + bulkSize: "number", + caCertFile: "string", + authToken: "string", + authUser: "string", + authPassword: "string", + tlsInsecureSkipVerify: "boolean", +}; + +function isPlainObject(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function schemaKeys(schema: Schema): string { + return Object.keys(schema).sort().join(", "); +} + +function validateSchema(source: string, path: string, value: unknown, schema: Schema): void { + if (!isPlainObject(value)) { + throw new Error(`${source}: ${path} must be object`); + } + + for (const [key, nested] of Object.entries(value)) { + if (nested === undefined) continue; + + const rule = schema[key]; + if (!rule) { + throw new Error(`${source}: unknown ${path} option "${key}" (available: ${schemaKeys(schema)})`); + } + + const fieldPath = path === "driver" ? key : `${path}.${key}`; + if (typeof rule === "string") { + validateKind(source, fieldPath, nested, rule); + } else if ("enum" in rule) { + validateEnum(source, fieldPath, nested, rule.enum); + } else { + validateSchema(source, fieldPath, nested, rule.object); + } + } +} + +function validateKind(source: string, path: string, value: unknown, kind: ScalarKind): void { + if (typeof value !== kind) { + throw new Error(`${source}: ${path} must be ${kind}, got ${typeof value}`); + } +} + +function validateEnum(source: string, path: string, value: unknown, allowed: Set): void { + validateKind(source, path, value, "string"); + if (!allowed.has(value as string)) { + throw new Error(`${source}: ${path} must be one of: ${[...allowed].sort().join(", ")}`); + } +} + +export function validateDriverSetup(source: string, setup: unknown): asserts setup is DriverSetup { + validateSchema(source, "driver", setup, driverSetupSchema); +} + +function mergeDriverSetup(defaults: DriverSetup, cli: Partial): DriverSetup { + const merged: Record = { ...defaults }; + for (const [key, value] of Object.entries(cli)) { + if (value === undefined) continue; + if ((key === "pool" || key === "postgres" || key === "sql") && isPlainObject(value)) { + merged[key] = { ...((defaults as Record)[key] as object | undefined), ...value }; + } else { + merged[key] = value; + } + } + + return merged as DriverSetup; +} + /** Resolve pool sugar into the appropriate driver-specific config. */ function resolvePoolConfig(config: DriverSetup): { postgres?: Partial; @@ -390,29 +500,18 @@ export function declareDriverSetup(index: number, defaults: DriverSetup): Driver const raw = __ENV[envKey]; if (!raw || raw === "") return defaults; + let cli: unknown; try { - const cli = JSON.parse(raw) as Partial; - // Deep merge: CLI fields override defaults, but only if actually set. - const merged: DriverSetup = { ...defaults }; - if (cli.driverType !== undefined) merged.driverType = cli.driverType as DriverTypeName; - if (cli.url !== undefined) merged.url = cli.url; - if (cli.defaultTxIsolation !== undefined) merged.defaultTxIsolation = cli.defaultTxIsolation as TxIsolationName; - if (cli.defaultInsertMethod !== undefined) merged.defaultInsertMethod = cli.defaultInsertMethod as InsertMethodName; - if (cli.errorMode !== undefined) merged.errorMode = cli.errorMode as ErrorModeName; - if (cli.pool !== undefined) merged.pool = cli.pool; - if (cli.postgres !== undefined) merged.postgres = cli.postgres; - if (cli.sql !== undefined) merged.sql = cli.sql; - if ((cli as any).bulkSize !== undefined) merged.bulkSize = (cli as any).bulkSize; - if (cli.caCertFile !== undefined) merged.caCertFile = cli.caCertFile; - if (cli.authToken !== undefined) merged.authToken = cli.authToken; - if (cli.authUser !== undefined) merged.authUser = cli.authUser; - if (cli.authPassword !== undefined) merged.authPassword = cli.authPassword; - if (cli.tlsInsecureSkipVerify !== undefined) merged.tlsInsecureSkipVerify = cli.tlsInsecureSkipVerify; - return merged; + cli = JSON.parse(raw); } catch (e) { - console.error(`[stroppy] failed to parse ${envKey}: ${e}`); - return defaults; + throw new Error(`[stroppy] failed to parse ${envKey}: ${e}`); } + + validateDriverSetup(envKey, cli); + const merged = mergeDriverSetup(defaults, cli); + validateDriverSetup(`${envKey} merged`, merged); + + return merged; } export class DriverX implements QueryAPI { diff --git a/internal/static/tests/helpers.test.ts b/internal/static/tests/helpers.test.ts new file mode 100644 index 00000000..48ecd812 --- /dev/null +++ b/internal/static/tests/helpers.test.ts @@ -0,0 +1,55 @@ +import { describe, expect, it, vi } from "vitest"; + +async function loadHelpers(env: Record = {}) { + vi.resetModules(); + (globalThis as unknown as { __ENV: Record }).__ENV = env; + + return import("../helpers.ts"); +} + +describe("DriverSetup validation", () => { + it("accepts valid pool overrides and deep-merges them over defaults", async () => { + const { declareDriverSetup } = await loadHelpers({ + STROPPY_DRIVER_0: JSON.stringify({ pool: { maxConns: 20 } }), + }); + + const setup = declareDriverSetup(0, { + driverType: "postgres", + pool: { maxConns: 5, minConns: 5 }, + }); + + expect(setup.pool).toEqual({ maxConns: 20, minConns: 5 }); + }); + + it("rejects unknown pool keys transported by the CLI", async () => { + const { declareDriverSetup } = await loadHelpers({ + STROPPY_DRIVER_0: JSON.stringify({ pool: { maximum: 20 } }), + }); + + expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/unknown pool option "maximum"/); + }); + + it("rejects unknown top-level keys", async () => { + const { declareDriverSetup } = await loadHelpers({ + STROPPY_DRIVER_0: JSON.stringify({ tls: { insecureSkipVerify: true } }), + }); + + expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/unknown driver option "tls"/); + }); + + it("rejects invalid primitive types", async () => { + const { declareDriverSetup } = await loadHelpers({ + STROPPY_DRIVER_0: JSON.stringify({ pool: { maxConns: "20" } }), + }); + + expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/pool\.maxConns must be number/); + }); + + it("rejects invalid enum values", async () => { + const { declareDriverSetup } = await loadHelpers({ + STROPPY_DRIVER_0: JSON.stringify({ driverType: "oracle" }), + }); + + expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/driverType must be one of/); + }); +}); diff --git a/internal/static/tests/stubs/k6_execution.ts b/internal/static/tests/stubs/k6_execution.ts new file mode 100644 index 00000000..5790dac5 --- /dev/null +++ b/internal/static/tests/stubs/k6_execution.ts @@ -0,0 +1,8 @@ +export const test = { + fail(message?: string): never { + throw new Error(message ?? "test.fail"); + }, + abort(message?: string): never { + throw new Error(message ?? "test.abort"); + }, +}; diff --git a/internal/static/tests/stubs/k6_metrics.ts b/internal/static/tests/stubs/k6_metrics.ts new file mode 100644 index 00000000..9a39135b --- /dev/null +++ b/internal/static/tests/stubs/k6_metrics.ts @@ -0,0 +1,14 @@ +export class Counter { + constructor(_name: string) {} + add(_value: number, _tags?: Record): void {} +} + +export class Rate { + constructor(_name: string) {} + add(_value: number, _tags?: Record): void {} +} + +export class Trend { + constructor(_name: string, _isTime?: boolean) {} + add(_value: number, _tags?: Record): void {} +} diff --git a/internal/static/tests/stubs/k6_x_encoding.ts b/internal/static/tests/stubs/k6_x_encoding.ts new file mode 100644 index 00000000..6bde309c --- /dev/null +++ b/internal/static/tests/stubs/k6_x_encoding.ts @@ -0,0 +1,4 @@ +export default { + TextEncoder: globalThis.TextEncoder, + TextDecoder: globalThis.TextDecoder, +}; diff --git a/internal/static/tests/stubs/k6_x_stroppy.ts b/internal/static/tests/stubs/k6_x_stroppy.ts new file mode 100644 index 00000000..595b570a --- /dev/null +++ b/internal/static/tests/stubs/k6_x_stroppy.ts @@ -0,0 +1,40 @@ +const stats = { + elapsed: { + milliseconds: () => 0, + seconds: () => 0, + microseconds: () => 0, + nanoseconds: () => 0, + string: () => "0s", + }, +}; + +const result = { + stats, + rows: { + columns: () => [], + close: () => null, + err: () => null, + next: () => false, + values: () => [], + readAll: () => [], + }, +}; + +export function NewDriver() { + return { + setup(): void {}, + runQuery: () => result, + insertSpecBin: () => stats, + begin: () => ({ + runQuery: () => result, + commit(): void {}, + rollback(): void {}, + }), + }; +} + +export function NotifyStep(_name: string, _status?: number): void {} +export function DeclareEnv(_names: string[], _defaultValue: string, _description: string): void {} +export function Once any>(fn?: F): F { + return (fn ?? ((() => undefined) as F)); +} diff --git a/internal/static/tests/tsconfig.json b/internal/static/tests/tsconfig.json index d2cf8a01..c0480128 100644 --- a/internal/static/tests/tsconfig.json +++ b/internal/static/tests/tsconfig.json @@ -14,9 +14,14 @@ "allowImportingTsExtensions": true, "noEmit": true, "rootDir": "..", - "baseUrl": ".." + "baseUrl": "..", + "paths": { + "k6/metrics": ["tests/stubs/k6_metrics.ts"], + "k6/execution": ["tests/stubs/k6_execution.ts"], + "k6/x/encoding": ["tests/stubs/k6_x_encoding.ts"], + "k6/x/stroppy": ["tests/stubs/k6_x_stroppy.ts"] + } }, "include": ["../**/*.ts", "**/*.ts"], "exclude": ["node_modules"] } - diff --git a/internal/static/tests/vitest.config.ts b/internal/static/tests/vitest.config.ts index e9cbdede..18b8df40 100644 --- a/internal/static/tests/vitest.config.ts +++ b/internal/static/tests/vitest.config.ts @@ -11,7 +11,10 @@ export default defineConfig({ alias: { // Ensure proper resolution of .js imports from .ts files './stroppy.pb.js': path.resolve(__dirname, '../stroppy.pb.ts'), + 'k6/metrics': path.resolve(__dirname, 'stubs/k6_metrics.ts'), + 'k6/execution': path.resolve(__dirname, 'stubs/k6_execution.ts'), + 'k6/x/encoding': path.resolve(__dirname, 'stubs/k6_x_encoding.ts'), + 'k6/x/stroppy': path.resolve(__dirname, 'stubs/k6_x_stroppy.ts'), }, }, }); - From d81745275a7f28844e57a7d042098d156f46c482 Mon Sep 17 00:00:00 2001 From: Nikita Aleksandrov Date: Thu, 30 Apr 2026 02:56:35 +0300 Subject: [PATCH 3/8] refactor(ts): remove extra --- internal/static/helpers.ts | 137 +++++++++---------- internal/static/tests/helpers.test.ts | 55 -------- internal/static/tests/stubs/k6_execution.ts | 8 -- internal/static/tests/stubs/k6_metrics.ts | 14 -- internal/static/tests/stubs/k6_x_encoding.ts | 4 - internal/static/tests/stubs/k6_x_stroppy.ts | 40 ------ internal/static/tests/tsconfig.json | 6 +- internal/static/tests/vitest.config.ts | 4 - 8 files changed, 63 insertions(+), 205 deletions(-) delete mode 100644 internal/static/tests/helpers.test.ts delete mode 100644 internal/static/tests/stubs/k6_execution.ts delete mode 100644 internal/static/tests/stubs/k6_metrics.ts delete mode 100644 internal/static/tests/stubs/k6_x_encoding.ts delete mode 100644 internal/static/tests/stubs/k6_x_stroppy.ts diff --git a/internal/static/helpers.ts b/internal/static/helpers.ts index 1a056a34..98780a4c 100644 --- a/internal/static/helpers.ts +++ b/internal/static/helpers.ts @@ -333,102 +333,87 @@ export type DriverSetup = Omit, "errorMode" | "driverType" sql?: Partial; } -type ScalarKind = "string" | "number" | "boolean"; -type FieldRule = ScalarKind | { enum: Set } | { object: Schema }; -type Schema = Record; - -const poolSchema: Schema = { - maxConns: "number", - minConns: "number", - maxConnLifetime: "string", - maxConnIdleTime: "string", +const driverSetupKeys = new Set([ + "url", + "driverType", + "defaultTxIsolation", + "defaultInsertMethod", + "errorMode", + "pool", + "postgres", + "sql", + "bulkSize", + "caCertFile", + "authToken", + "authUser", + "authPassword", + "tlsInsecureSkipVerify", +]); + +const nestedDriverSetupKeys: Record> = { + pool: new Set(["maxConns", "minConns", "maxConnLifetime", "maxConnIdleTime"]), + postgres: new Set([ + "traceLogLevel", + "maxConnLifetime", + "maxConnIdleTime", + "maxConns", + "minConns", + "minIdleConns", + "defaultQueryExecMode", + "descriptionCacheCapacity", + "statementCacheCapacity", + ]), + sql: new Set(["maxOpenConns", "maxIdleConns", "connMaxLifetime", "connMaxIdleTime"]), }; -const postgresSchema: Schema = { - traceLogLevel: "string", - maxConnLifetime: "string", - maxConnIdleTime: "string", - maxConns: "number", - minConns: "number", - minIdleConns: "number", - defaultQueryExecMode: "string", - descriptionCacheCapacity: "number", - statementCacheCapacity: "number", -}; - -const sqlSchema: Schema = { - maxOpenConns: "number", - maxIdleConns: "number", - connMaxLifetime: "string", - connMaxIdleTime: "string", -}; - -const driverSetupSchema: Schema = { - url: "string", - driverType: { enum: new Set(Object.keys(driverTypeMap)) }, - defaultTxIsolation: { enum: new Set(Object.keys(txIsolationMap)) }, - defaultInsertMethod: { enum: new Set(Object.keys(insertMethodMap)) }, - errorMode: { enum: new Set(Object.keys(errorModeMap)) }, - pool: { object: poolSchema }, - postgres: { object: postgresSchema }, - sql: { object: sqlSchema }, - bulkSize: "number", - caCertFile: "string", - authToken: "string", - authUser: "string", - authPassword: "string", - tlsInsecureSkipVerify: "boolean", +const driverSetupEnums: Record> = { + driverType: new Set(Object.keys(driverTypeMap)), + defaultTxIsolation: new Set(Object.keys(txIsolationMap)), + defaultInsertMethod: new Set(Object.keys(insertMethodMap)), + errorMode: new Set(Object.keys(errorModeMap)), }; function isPlainObject(value: unknown): value is Record { return typeof value === "object" && value !== null && !Array.isArray(value); } -function schemaKeys(schema: Schema): string { - return Object.keys(schema).sort().join(", "); +function allowedList(keys: Set): string { + return [...keys].sort().join(", "); } -function validateSchema(source: string, path: string, value: unknown, schema: Schema): void { - if (!isPlainObject(value)) { - throw new Error(`${source}: ${path} must be object`); +function validateEnum(source: string, path: string, value: unknown, allowed: Set): void { + if (value !== undefined && (typeof value !== "string" || !allowed.has(value))) { + throw new Error(`${source}: ${path} must be one of: ${allowedList(allowed)}`); } +} - for (const [key, nested] of Object.entries(value)) { - if (nested === undefined) continue; +export function validateDriverSetup(source: string, setup: unknown): asserts setup is DriverSetup { + if (!isPlainObject(setup)) { + throw new Error(`${source}: driver setup must be object`); + } - const rule = schema[key]; - if (!rule) { - throw new Error(`${source}: unknown ${path} option "${key}" (available: ${schemaKeys(schema)})`); + for (const [key, value] of Object.entries(setup)) { + if (!driverSetupKeys.has(key)) { + throw new Error(`${source}: unknown driver option "${key}" (available: ${allowedList(driverSetupKeys)})`); } - const fieldPath = path === "driver" ? key : `${path}.${key}`; - if (typeof rule === "string") { - validateKind(source, fieldPath, nested, rule); - } else if ("enum" in rule) { - validateEnum(source, fieldPath, nested, rule.enum); - } else { - validateSchema(source, fieldPath, nested, rule.object); - } - } -} + const enumValues = driverSetupEnums[key]; + if (enumValues) validateEnum(source, key, value, enumValues); -function validateKind(source: string, path: string, value: unknown, kind: ScalarKind): void { - if (typeof value !== kind) { - throw new Error(`${source}: ${path} must be ${kind}, got ${typeof value}`); - } -} + const nestedKeys = nestedDriverSetupKeys[key]; + if (!nestedKeys || value === undefined) continue; + if (!isPlainObject(value)) { + throw new Error(`${source}: ${key} must be object`); + } -function validateEnum(source: string, path: string, value: unknown, allowed: Set): void { - validateKind(source, path, value, "string"); - if (!allowed.has(value as string)) { - throw new Error(`${source}: ${path} must be one of: ${[...allowed].sort().join(", ")}`); + for (const nestedKey of Object.keys(value)) { + if (!nestedKeys.has(nestedKey)) { + throw new Error(`${source}: unknown ${key} option "${nestedKey}" (available: ${allowedList(nestedKeys)})`); + } + } } } -export function validateDriverSetup(source: string, setup: unknown): asserts setup is DriverSetup { - validateSchema(source, "driver", setup, driverSetupSchema); -} - function mergeDriverSetup(defaults: DriverSetup, cli: Partial): DriverSetup { const merged: Record = { ...defaults }; for (const [key, value] of Object.entries(cli)) { diff --git a/internal/static/tests/helpers.test.ts b/internal/static/tests/helpers.test.ts deleted file mode 100644 index 48ecd812..00000000 --- a/internal/static/tests/helpers.test.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { describe, expect, it, vi } from "vitest"; - -async function loadHelpers(env: Record = {}) { - vi.resetModules(); - (globalThis as unknown as { __ENV: Record }).__ENV = env; - - return import("../helpers.ts"); -} - -describe("DriverSetup validation", () => { - it("accepts valid pool overrides and deep-merges them over defaults", async () => { - const { declareDriverSetup } = await loadHelpers({ - STROPPY_DRIVER_0: JSON.stringify({ pool: { maxConns: 20 } }), - }); - - const setup = declareDriverSetup(0, { - driverType: "postgres", - pool: { maxConns: 5, minConns: 5 }, - }); - - expect(setup.pool).toEqual({ maxConns: 20, minConns: 5 }); - }); - - it("rejects unknown pool keys transported by the CLI", async () => { - const { declareDriverSetup } = await loadHelpers({ - STROPPY_DRIVER_0: JSON.stringify({ pool: { maximum: 20 } }), - }); - - expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/unknown pool option "maximum"/); - }); - - it("rejects unknown top-level keys", async () => { - const { declareDriverSetup } = await loadHelpers({ - STROPPY_DRIVER_0: JSON.stringify({ tls: { insecureSkipVerify: true } }), - }); - - expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/unknown driver option "tls"/); - }); - - it("rejects invalid primitive types", async () => { - const { declareDriverSetup } = await loadHelpers({ - STROPPY_DRIVER_0: JSON.stringify({ pool: { maxConns: "20" } }), - }); - - expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/pool\.maxConns must be number/); - }); - - it("rejects invalid enum values", async () => { - const { declareDriverSetup } = await loadHelpers({ - STROPPY_DRIVER_0: JSON.stringify({ driverType: "oracle" }), - }); - - expect(() => declareDriverSetup(0, { driverType: "postgres" })).toThrow(/driverType must be one of/); - }); -}); diff --git a/internal/static/tests/stubs/k6_execution.ts b/internal/static/tests/stubs/k6_execution.ts deleted file mode 100644 index 5790dac5..00000000 --- a/internal/static/tests/stubs/k6_execution.ts +++ /dev/null @@ -1,8 +0,0 @@ -export const test = { - fail(message?: string): never { - throw new Error(message ?? "test.fail"); - }, - abort(message?: string): never { - throw new Error(message ?? "test.abort"); - }, -}; diff --git a/internal/static/tests/stubs/k6_metrics.ts b/internal/static/tests/stubs/k6_metrics.ts deleted file mode 100644 index 9a39135b..00000000 --- a/internal/static/tests/stubs/k6_metrics.ts +++ /dev/null @@ -1,14 +0,0 @@ -export class Counter { - constructor(_name: string) {} - add(_value: number, _tags?: Record): void {} -} - -export class Rate { - constructor(_name: string) {} - add(_value: number, _tags?: Record): void {} -} - -export class Trend { - constructor(_name: string, _isTime?: boolean) {} - add(_value: number, _tags?: Record): void {} -} diff --git a/internal/static/tests/stubs/k6_x_encoding.ts b/internal/static/tests/stubs/k6_x_encoding.ts deleted file mode 100644 index 6bde309c..00000000 --- a/internal/static/tests/stubs/k6_x_encoding.ts +++ /dev/null @@ -1,4 +0,0 @@ -export default { - TextEncoder: globalThis.TextEncoder, - TextDecoder: globalThis.TextDecoder, -}; diff --git a/internal/static/tests/stubs/k6_x_stroppy.ts b/internal/static/tests/stubs/k6_x_stroppy.ts deleted file mode 100644 index 595b570a..00000000 --- a/internal/static/tests/stubs/k6_x_stroppy.ts +++ /dev/null @@ -1,40 +0,0 @@ -const stats = { - elapsed: { - milliseconds: () => 0, - seconds: () => 0, - microseconds: () => 0, - nanoseconds: () => 0, - string: () => "0s", - }, -}; - -const result = { - stats, - rows: { - columns: () => [], - close: () => null, - err: () => null, - next: () => false, - values: () => [], - readAll: () => [], - }, -}; - -export function NewDriver() { - return { - setup(): void {}, - runQuery: () => result, - insertSpecBin: () => stats, - begin: () => ({ - runQuery: () => result, - commit(): void {}, - rollback(): void {}, - }), - }; -} - -export function NotifyStep(_name: string, _status?: number): void {} -export function DeclareEnv(_names: string[], _defaultValue: string, _description: string): void {} -export function Once any>(fn?: F): F { - return (fn ?? ((() => undefined) as F)); -} diff --git a/internal/static/tests/tsconfig.json b/internal/static/tests/tsconfig.json index c0480128..4ecff88e 100644 --- a/internal/static/tests/tsconfig.json +++ b/internal/static/tests/tsconfig.json @@ -16,10 +16,8 @@ "rootDir": "..", "baseUrl": "..", "paths": { - "k6/metrics": ["tests/stubs/k6_metrics.ts"], - "k6/execution": ["tests/stubs/k6_execution.ts"], - "k6/x/encoding": ["tests/stubs/k6_x_encoding.ts"], - "k6/x/stroppy": ["tests/stubs/k6_x_stroppy.ts"] + "k6/x/stroppy": ["stroppy.d.ts"], + "k6/x/encoding": ["encoding.d.ts"] } }, "include": ["../**/*.ts", "**/*.ts"], diff --git a/internal/static/tests/vitest.config.ts b/internal/static/tests/vitest.config.ts index 18b8df40..7926bd16 100644 --- a/internal/static/tests/vitest.config.ts +++ b/internal/static/tests/vitest.config.ts @@ -11,10 +11,6 @@ export default defineConfig({ alias: { // Ensure proper resolution of .js imports from .ts files './stroppy.pb.js': path.resolve(__dirname, '../stroppy.pb.ts'), - 'k6/metrics': path.resolve(__dirname, 'stubs/k6_metrics.ts'), - 'k6/execution': path.resolve(__dirname, 'stubs/k6_execution.ts'), - 'k6/x/encoding': path.resolve(__dirname, 'stubs/k6_x_encoding.ts'), - 'k6/x/stroppy': path.resolve(__dirname, 'stubs/k6_x_stroppy.ts'), }, }, }); From a02e1578bf53e962b5b48dba43572013b1a86b45 Mon Sep 17 00:00:00 2001 From: Nikita Aleksandrov Date: Thu, 30 Apr 2026 03:24:46 +0300 Subject: [PATCH 4/8] fix(cli): exit codes capturing --- cmd/stroppy/commands/k6signals_test.go | 20 ++++++++++++++++ cmd/stroppy/commands/root.go | 4 +++- cmd/stroppy/commands/run/run.go | 22 ++++++++++------- internal/runner/k6_exit.go | 33 ++++++++++++++++++++++---- internal/runner/script_runner.go | 2 ++ 5 files changed, 67 insertions(+), 14 deletions(-) diff --git a/cmd/stroppy/commands/k6signals_test.go b/cmd/stroppy/commands/k6signals_test.go index e9d38031..01207f40 100644 --- a/cmd/stroppy/commands/k6signals_test.go +++ b/cmd/stroppy/commands/k6signals_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/require" "go.k6.io/k6/cmd/state" + + "github.com/stroppy-io/stroppy/internal/runner" ) func sendInt(t *testing.T) { @@ -69,3 +71,21 @@ func TestInterceptMiddleware(_ *testing.T) { // SignalStop should not panic and should clean up. gs.SignalStop(ch) } + +func TestK6SubcommandScopesExitCapture(t *testing.T) { + called := false + gs := &state.GlobalState{OSExit: func(int) { called = true }} + + K6Subcommand(gs) + gs.OSExit(1) + require.True(t, called) + + called = false + stop := runner.BeginK6ExitCapture() + defer stop() + + gs = &state.GlobalState{OSExit: func(int) { called = true }} + K6Subcommand(gs) + gs.OSExit(1) + require.False(t, called) +} diff --git a/cmd/stroppy/commands/root.go b/cmd/stroppy/commands/root.go index d619d183..3b411130 100644 --- a/cmd/stroppy/commands/root.go +++ b/cmd/stroppy/commands/root.go @@ -85,7 +85,9 @@ func Root() *cobra.Command { func K6Subcommand(gs *state.GlobalState) *cobra.Command { inteceptInteruptSignals(gs) - gs.OSExit = runner.OSExit // handles exit code + if runner.K6ExitCaptureEnabled() { + gs.OSExit = runner.OSExit + } return rootCmd } diff --git a/cmd/stroppy/commands/run/run.go b/cmd/stroppy/commands/run/run.go index 79673304..3472784c 100644 --- a/cmd/stroppy/commands/run/run.go +++ b/cmd/stroppy/commands/run/run.go @@ -10,6 +10,8 @@ import ( "strings" "github.com/spf13/cobra" + "go.k6.io/k6/errext" + "go.k6.io/k6/errext/exitcodes" "go.uber.org/zap" "github.com/stroppy-io/stroppy/internal/runner" @@ -92,13 +94,13 @@ Config file flags: parsed, err := parseRunArgs(args) if err != nil { - return err + return invalidConfig(err) } // Load config file if -f is specified or stroppy-config.json exists. fileConfig, _, err := runner.LoadRunConfig(parsed.fileArg) if err != nil { - return fmt.Errorf("failed to load config file: %w", err) + return invalidConfig(fmt.Errorf("failed to load config file: %w", err)) } // Apply effective values: CLI overrides config file. @@ -109,7 +111,7 @@ Config file flags: k6RunArgs := runner.EffectiveK6Args(parsed.afterDash, fileConfig) if scriptArg == "" { - return errNoScript + return invalidConfig(errNoScript) } // Log override decisions when both CLI and file config are present. @@ -141,21 +143,21 @@ Config file flags: // Resolve -e overrides (uppercase keys, validate format). envOverrides, err := runner.ResolveEnvOverrides(parsed.envArgs) if err != nil { - return err + return invalidConfig(err) } driverConfigs := runner.DriverCLIConfigs{} for idx, presetName := range parsed.driverPresets { if err := applyDriverPreset(driverConfigs, idx, presetName); err != nil { - return err + return invalidConfig(err) } } for idx, opts := range parsed.driverOpts { for _, kv := range opts { if err := applyDriverOpt(driverConfigs, idx, kv[0], kv[1]); err != nil { - return err + return invalidConfig(err) } } } @@ -163,7 +165,7 @@ Config file flags: // Resolve files through search path. input, err := runner.ResolveInput(scriptArg, sqlArg) if err != nil { - return fmt.Errorf("failed to resolve input: %w", err) + return invalidConfig(fmt.Errorf("failed to resolve input: %w", err)) } scriptRunner, err := runner.NewScriptRunner( @@ -176,7 +178,7 @@ Config file flags: fileConfig, ) if err != nil { - return fmt.Errorf("failed to create runner: %w", err) + return invalidConfig(fmt.Errorf("failed to create runner: %w", err)) } err = scriptRunner.Run(context.Background()) @@ -194,6 +196,10 @@ Config file flags: }, } +func invalidConfig(err error) error { + return errext.WithExitCodeIfNone(err, exitcodes.InvalidConfig) +} + // runArgs holds the result of parseRunArgs. type runArgs struct { scriptArg string diff --git a/internal/runner/k6_exit.go b/internal/runner/k6_exit.go index 930c4aa9..26b68732 100644 --- a/internal/runner/k6_exit.go +++ b/internal/runner/k6_exit.go @@ -1,6 +1,9 @@ package runner -import "fmt" +import ( + "fmt" + "sync/atomic" +) type ExitError struct { Code int @@ -8,19 +11,39 @@ type ExitError struct { func (e *ExitError) Error() string { return fmt.Sprintf("k6 exited with code %d", e.Code) } -var exitCode int = -1 +var ( + captureK6Exit atomic.Bool + exitCode atomic.Int32 +) + +func init() { + exitCode.Store(-1) +} + +func BeginK6ExitCapture() func() { + exitCode.Store(-1) + captureK6Exit.Store(true) + + return func() { + captureK6Exit.Store(false) + } +} + +func K6ExitCaptureEnabled() bool { + return captureK6Exit.Load() +} func exitCodeToError() error { - switch exitCode { + switch code := int(exitCode.Load()); code { case -1: panic("unreachable; k6 must to set an exit code") case 0: // do nothing, all correct return nil default: - return &ExitError{Code: exitCode} + return &ExitError{Code: code} } } func OSExit(i int) { - exitCode = i + exitCode.Store(int32(i)) } diff --git a/internal/runner/script_runner.go b/internal/runner/script_runner.go index 74e74845..488a6e7b 100644 --- a/internal/runner/script_runner.go +++ b/internal/runner/script_runner.go @@ -474,6 +474,8 @@ func (r *ScriptRunner) runK6( r.logger.Info("Running k6", zap.Strings("args", os.Args)) // run the test + stopExitCapture := BeginK6ExitCapture() + defer stopExitCapture() k6cmd.Execute() defer func() { err = errors.Join(err, exitCodeToError()) }() From 60f896f708d8760f01eef6f14e05839513e6360d Mon Sep 17 00:00:00 2001 From: Nikita Aleksandrov Date: Thu, 30 Apr 2026 03:51:10 +0300 Subject: [PATCH 5/8] feat(logging): extensive envs piplene logging --- cmd/stroppy/commands/help/topic_envs.go | 13 +- internal/runner/env_log.go | 151 ++++++++++++++++++++++++ internal/runner/env_log_test.go | 65 ++++++++++ internal/runner/script_runner.go | 20 +++- pkg/common/logger/logger.go | 2 + pkg/common/logger/logger_test.go | 1 + 6 files changed, 249 insertions(+), 3 deletions(-) create mode 100644 internal/runner/env_log.go create mode 100644 internal/runner/env_log_test.go diff --git a/cmd/stroppy/commands/help/topic_envs.go b/cmd/stroppy/commands/help/topic_envs.go index f89477bd..f6b027b7 100644 --- a/cmd/stroppy/commands/help/topic_envs.go +++ b/cmd/stroppy/commands/help/topic_envs.go @@ -150,7 +150,18 @@ CONFIG FILE ALTERNATIVE DEBUG: TRACING ENV RESOLUTION - To see which env vars are applied and where they came from: + Normal run logs show -e/config env entries that were applied: + + stroppy run tpcc/tx -e load_workers=8 + + INFO script_runner Applied script env {"source":"cli","env":["LOAD_WORKERS=8"]} + + If an applied env key is not declared by the workload's ENV() calls, Stroppy + warns so typos are visible: + + WARN script_runner Script env is not declared by workload {"keys":["LOAD_WROKERS"]} + + To inspect skipped env entries and precedence decisions: LOG_LEVEL=debug stroppy run