From 6c80f238285d94fe58aa52a455917d10d2b31b7b Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 11 Jun 2026 18:52:04 +0000 Subject: [PATCH 01/18] experimental/air: scaffold AI runtime CLI command package Add the experimental `air` command group as the Go port surface for the Python `air` CLI. Every subcommand (run, status, list, logs, cancel, register-image) is registered as a stub that returns a not-implemented error; the real implementations land in later milestones. The package lives under experimental/air/cmd (imported as aircmd), matching the layout of the other experimental features (aitools, genie, postgres); cmd/experimental/ keeps only the dispatcher. TEST_PACKAGES in Taskfile.yml gains ./experimental/air/... so the unit tests keep running after the move. Includes unit tests for the command-tree wiring and the not-implemented stubs, plus an acceptance test exercising the stubs end-to-end. Co-authored-by: Isaac --- Taskfile.yml | 2 +- .../experimental/air/help/out.test.toml | 3 ++ acceptance/experimental/air/help/output.txt | 29 ++++++++++++++ acceptance/experimental/air/help/script | 5 +++ acceptance/experimental/air/help/test.toml | 3 ++ .../air/unimplemented/out.test.toml | 3 ++ .../experimental/air/unimplemented/output.txt | 36 +++++++++++++++++ .../experimental/air/unimplemented/script | 19 +++++++++ .../experimental/air/unimplemented/test.toml | 3 ++ cmd/experimental/experimental.go | 2 + experimental/air/cmd/air.go | 36 +++++++++++++++++ experimental/air/cmd/air_test.go | 22 +++++++++++ experimental/air/cmd/cancel.go | 39 +++++++++++++++++++ experimental/air/cmd/list.go | 31 +++++++++++++++ experimental/air/cmd/logs.go | 34 ++++++++++++++++ experimental/air/cmd/register_image.go | 33 ++++++++++++++++ experimental/air/cmd/run.go | 36 +++++++++++++++++ experimental/air/cmd/status.go | 19 +++++++++ experimental/air/cmd/stubs_test.go | 31 +++++++++++++++ 19 files changed, 385 insertions(+), 1 deletion(-) create mode 100644 acceptance/experimental/air/help/out.test.toml create mode 100644 acceptance/experimental/air/help/output.txt create mode 100644 acceptance/experimental/air/help/script create mode 100644 acceptance/experimental/air/help/test.toml create mode 100644 acceptance/experimental/air/unimplemented/out.test.toml create mode 100644 acceptance/experimental/air/unimplemented/output.txt create mode 100644 acceptance/experimental/air/unimplemented/script create mode 100644 acceptance/experimental/air/unimplemented/test.toml create mode 100644 experimental/air/cmd/air.go create mode 100644 experimental/air/cmd/air_test.go create mode 100644 experimental/air/cmd/cancel.go create mode 100644 experimental/air/cmd/list.go create mode 100644 experimental/air/cmd/logs.go create mode 100644 experimental/air/cmd/register_image.go create mode 100644 experimental/air/cmd/run.go create mode 100644 experimental/air/cmd/status.go create mode 100644 experimental/air/cmd/stubs_test.go diff --git a/Taskfile.yml b/Taskfile.yml index d72140290e2..32cb14d0c43 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -4,7 +4,7 @@ vars: # Absolute path so tasks with `dir:` (lint-go-tools, lint-go-codegen) can use it. GO_TOOL: go tool -modfile={{.ROOT_DIR}}/tools/go.mod EXE_EXT: '{{if eq OS "windows"}}.exe{{end}}' - TEST_PACKAGES: ./acceptance/internal ./libs/... ./internal/... ./cmd/... ./bundle/... ./experimental/ssh/... . + TEST_PACKAGES: ./acceptance/internal ./libs/... ./internal/... ./cmd/... ./bundle/... ./experimental/air/... ./experimental/ssh/... . ACCEPTANCE_TEST_FILTER: "" # Single brace-expansion glob covering every //go:embed target in the repo, # computed by grepping `//go:embed` directives. Evaluated lazily by Task so diff --git a/acceptance/experimental/air/help/out.test.toml b/acceptance/experimental/air/help/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/help/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/help/output.txt b/acceptance/experimental/air/help/output.txt new file mode 100644 index 00000000000..cf7e5af634c --- /dev/null +++ b/acceptance/experimental/air/help/output.txt @@ -0,0 +1,29 @@ + +=== help +>>> [CLI] experimental air --help +Run and manage AI runtime training workloads on Databricks serverless GPU compute. + +This command set is the Go port of the standalone Python "air" CLI. It is +experimental and may change in future versions. + +Usage: + databricks experimental air [command] + +Available Commands: + cancel Cancel one or more runs + list List recent runs + logs Stream or fetch logs for a run + register-image Mirror a Docker image into the workspace registry + run Submit a training workload from a YAML config + status Show status and configuration for a run + +Flags: + -h, --help help for air + +Global Flags: + --debug enable debug logging + -o, --output type output type: text or json (default text) + -p, --profile string ~/.databrickscfg profile + -t, --target string bundle target to use (if applicable) + +Use "databricks experimental air [command] --help" for more information about a command. diff --git a/acceptance/experimental/air/help/script b/acceptance/experimental/air/help/script new file mode 100644 index 00000000000..cd67a6fc1b1 --- /dev/null +++ b/acceptance/experimental/air/help/script @@ -0,0 +1,5 @@ +# Pin the command tree so any change to a subcommand or its short description +# shows up as a diff here. + +title "help" +trace $CLI experimental air --help diff --git a/acceptance/experimental/air/help/test.toml b/acceptance/experimental/air/help/test.toml new file mode 100644 index 00000000000..49709b578ef --- /dev/null +++ b/acceptance/experimental/air/help/test.toml @@ -0,0 +1,3 @@ +# --help prints without authenticating, so no server stubs are needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/unimplemented/out.test.toml b/acceptance/experimental/air/unimplemented/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/unimplemented/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt new file mode 100644 index 00000000000..3dc88de3b77 --- /dev/null +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -0,0 +1,36 @@ + +=== run +>>> [CLI] experimental air run +Error: `air run` is not implemented yet + +Exit code: 1 + +=== status +>>> [CLI] experimental air status 123 +Error: `air status` is not implemented yet + +Exit code: 1 + +=== list +>>> [CLI] experimental air list +Error: `air list` is not implemented yet + +Exit code: 1 + +=== logs +>>> [CLI] experimental air logs 123 +Error: `air logs` is not implemented yet + +Exit code: 1 + +=== cancel +>>> [CLI] experimental air cancel 123 +Error: `air cancel` is not implemented yet + +Exit code: 1 + +=== register-image +>>> [CLI] experimental air register-image my-image:latest +Error: `air register-image` is not implemented yet + +Exit code: 1 diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script new file mode 100644 index 00000000000..83397b4b741 --- /dev/null +++ b/acceptance/experimental/air/unimplemented/script @@ -0,0 +1,19 @@ +# Each stub must fail with "not implemented"; errcode records the exit code. + +title "run" +errcode trace $CLI experimental air run + +title "status" +errcode trace $CLI experimental air status 123 + +title "list" +errcode trace $CLI experimental air list + +title "logs" +errcode trace $CLI experimental air logs 123 + +title "cancel" +errcode trace $CLI experimental air cancel 123 + +title "register-image" +errcode trace $CLI experimental air register-image my-image:latest diff --git a/acceptance/experimental/air/unimplemented/test.toml b/acceptance/experimental/air/unimplemented/test.toml new file mode 100644 index 00000000000..c233c30a86c --- /dev/null +++ b/acceptance/experimental/air/unimplemented/test.toml @@ -0,0 +1,3 @@ +# Stubs fail locally before any API call, so no server stubs needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/cmd/experimental/experimental.go b/cmd/experimental/experimental.go index 8d9827c5c94..d87c893abc5 100644 --- a/cmd/experimental/experimental.go +++ b/cmd/experimental/experimental.go @@ -1,6 +1,7 @@ package experimental import ( + aircmd "github.com/databricks/cli/experimental/air/cmd" aitoolscmd "github.com/databricks/cli/experimental/aitools/cmd" geniecmd "github.com/databricks/cli/experimental/genie/cmd" postgrescmd "github.com/databricks/cli/experimental/postgres/cmd" @@ -22,6 +23,7 @@ These commands provide early access to new features that are still under development. They may change or be removed in future versions without notice.`, } + cmd.AddCommand(aircmd.New()) cmd.AddCommand(aitoolscmd.NewAitoolsCmd()) cmd.AddCommand(geniecmd.NewGenieCmd()) cmd.AddCommand(postgrescmd.New()) diff --git a/experimental/air/cmd/air.go b/experimental/air/cmd/air.go new file mode 100644 index 00000000000..3f9122c828c --- /dev/null +++ b/experimental/air/cmd/air.go @@ -0,0 +1,36 @@ +package aircmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +// New returns the root command for the experimental AI runtime CLI. +// +// Milestone 0: scaffolds the command group with every subcommand registered as a +// stub (not yet implemented), pending the port from the Python `air` CLI. +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "air", + Short: "Run and manage AI runtime training workloads", + Long: `Run and manage AI runtime training workloads on Databricks serverless GPU compute. + +This command set is the Go port of the standalone Python "air" CLI. It is +experimental and may change in future versions.`, + } + + cmd.AddCommand(newRunCommand()) + cmd.AddCommand(newStatusCommand()) + cmd.AddCommand(newListCommand()) + cmd.AddCommand(newLogsCommand()) + cmd.AddCommand(newCancelCommand()) + cmd.AddCommand(newRegisterImageCommand()) + + return cmd +} + +// notImplemented returns the placeholder error used by milestone-0 stubs. +func notImplemented(name string) error { + return fmt.Errorf("`air %s` is not implemented yet", name) +} diff --git a/experimental/air/cmd/air_test.go b/experimental/air/cmd/air_test.go new file mode 100644 index 00000000000..26268690850 --- /dev/null +++ b/experimental/air/cmd/air_test.go @@ -0,0 +1,22 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestNewRegistersAllSubcommands asserts the `air` command wires up every +// expected subcommand, so none is accidentally dropped from New. +func TestNewRegistersAllSubcommands(t *testing.T) { + registered := make(map[string]bool) + for _, c := range New().Commands() { + registered[c.Name()] = true + } + + want := []string{"run", "status", "list", "logs", "cancel", "register-image"} + for _, name := range want { + assert.True(t, registered[name], "subcommand %q is not registered", name) + } + assert.Len(t, registered, len(want), "unexpected number of subcommands") +} diff --git a/experimental/air/cmd/cancel.go b/experimental/air/cmd/cancel.go new file mode 100644 index 00000000000..ad7fffc7125 --- /dev/null +++ b/experimental/air/cmd/cancel.go @@ -0,0 +1,39 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newCancelCommand() *cobra.Command { + var ( + all bool + yes bool + ) + + cmd := &cobra.Command{ + Use: "cancel [RUN_ID...]", + Short: "Cancel one or more runs", + Long: `Cancel one or more runs by ID, or cancel all of your active runs with --all.`, + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("cancel") + }, + } + + cmd.Flags().BoolVar(&all, "all", false, "Cancel all of your active runs") + cmd.Flags().BoolVarP(&yes, "yes", "y", false, "Skip the confirmation prompt") + + // Require exactly one of: one or more RUN_IDs, or --all. Cobra parses flags + // before running this, so `all` reflects the user's input. + cmd.Args = func(cmd *cobra.Command, args []string) error { + switch { + case all && len(args) > 0: + return &root.InvalidArgsError{Command: cmd, Message: "cannot combine RUN_ID arguments with --all"} + case !all && len(args) == 0: + return &root.InvalidArgsError{Command: cmd, Message: "provide at least one RUN_ID, or use --all"} + } + return nil + } + + return cmd +} diff --git a/experimental/air/cmd/list.go b/experimental/air/cmd/list.go new file mode 100644 index 00000000000..bf24cff9b23 --- /dev/null +++ b/experimental/air/cmd/list.go @@ -0,0 +1,31 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newListCommand() *cobra.Command { + var ( + limit int + active bool + allUsers bool + filters []string + ) + + cmd := &cobra.Command{ + Use: "list", + Args: root.NoArgs, + Short: "List recent runs", + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("list") + }, + } + + cmd.Flags().IntVar(&limit, "limit", 20, "Maximum number of runs to show") + cmd.Flags().BoolVar(&active, "active", false, "Show only active runs") + cmd.Flags().BoolVar(&allUsers, "all-users", false, "Show runs from all users") + cmd.Flags().StringArrayVar(&filters, "filter", nil, "Filter runs, e.g. experiment=foo* (repeatable)") + + return cmd +} diff --git a/experimental/air/cmd/logs.go b/experimental/air/cmd/logs.go new file mode 100644 index 00000000000..4dbbe41c278 --- /dev/null +++ b/experimental/air/cmd/logs.go @@ -0,0 +1,34 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newLogsCommand() *cobra.Command { + var ( + node int + lines int + retry int + downloadTo string + review bool + ) + + cmd := &cobra.Command{ + Use: "logs RUN_ID", + Args: root.ExactArgs(1), + Short: "Stream or fetch logs for a run", + Long: `Stream logs from an active run, or fetch logs from a completed run.`, + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("logs") + }, + } + + cmd.Flags().IntVar(&node, "node", 0, "Fetch logs from this node") + cmd.Flags().IntVar(&lines, "lines", 10000, "For completed runs, print the last N lines") + cmd.Flags().IntVar(&retry, "retry", -1, "View logs from a specific retry attempt; -1 means latest") + cmd.Flags().StringVar(&downloadTo, "download-to", "", "Download all logs to this directory instead of printing") + cmd.Flags().BoolVar(&review, "review", false, "Download logs from all nodes and filter for error signatures") + + return cmd +} diff --git a/experimental/air/cmd/register_image.go b/experimental/air/cmd/register_image.go new file mode 100644 index 00000000000..a5be3df408b --- /dev/null +++ b/experimental/air/cmd/register_image.go @@ -0,0 +1,33 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newRegisterImageCommand() *cobra.Command { + var ( + scope string + key string + interactiveAuth bool + tagPolicy string + timeoutMinutes int + ) + + cmd := &cobra.Command{ + Use: "register-image IMAGE_URL", + Args: root.ExactArgs(1), + Short: "Mirror a Docker image into the workspace registry", + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("register-image") + }, + } + + cmd.Flags().StringVar(&scope, "scope", "", "Databricks secret scope holding registry credentials") + cmd.Flags().StringVar(&key, "key", "", "Databricks secret key holding registry credentials") + cmd.Flags().BoolVar(&interactiveAuth, "interactive-authenticate", false, "Prompt for registry credentials and store them as a secret") + cmd.Flags().StringVar(&tagPolicy, "tag-policy", "auto", "Image resolution policy: auto or latest") + cmd.Flags().IntVar(&timeoutMinutes, "timeout-minutes", 60, "Timeout to wait for the image to become available") + + return cmd +} diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go new file mode 100644 index 00000000000..0bc3d1fd94b --- /dev/null +++ b/experimental/air/cmd/run.go @@ -0,0 +1,36 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newRunCommand() *cobra.Command { + var ( + file string + watch bool + overrides []string + dryRun bool + idempotencyKey string + ) + + cmd := &cobra.Command{ + Use: "run", + Args: root.NoArgs, + Short: "Submit a training workload from a YAML config", + Long: `Submit a training workload to Databricks serverless GPU compute. + +The workload is described by a YAML config file (see --file).`, + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("run") + }, + } + + cmd.Flags().StringVarP(&file, "file", "f", "", "Path to the workload YAML config") + cmd.Flags().BoolVar(&watch, "watch", false, "Stream logs until the run completes") + cmd.Flags().StringArrayVar(&overrides, "override", nil, "Override a YAML field, e.g. compute.num_accelerators=8 (repeatable)") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Validate the config without submitting") + cmd.Flags().StringVar(&idempotencyKey, "idempotency-key", "", "Return the existing run if this key was already used") + + return cmd +} diff --git a/experimental/air/cmd/status.go b/experimental/air/cmd/status.go new file mode 100644 index 00000000000..a0db0619331 --- /dev/null +++ b/experimental/air/cmd/status.go @@ -0,0 +1,19 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newStatusCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "status RUN_ID", + Args: root.ExactArgs(1), + Short: "Show status and configuration for a run", + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("status") + }, + } + + return cmd +} diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go new file mode 100644 index 00000000000..8ffd197973f --- /dev/null +++ b/experimental/air/cmd/stubs_test.go @@ -0,0 +1,31 @@ +package aircmd + +import ( + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestStubCommandsReturnNotImplemented asserts each unimplemented subcommand +// fails with a "not implemented" error. Drop a command here once it lands. +func TestStubCommandsReturnNotImplemented(t *testing.T) { + stubs := map[string]*cobra.Command{ + "run": newRunCommand(), + "status": newStatusCommand(), + "list": newListCommand(), + "logs": newLogsCommand(), + "cancel": newCancelCommand(), + "register-image": newRegisterImageCommand(), + } + + for name, cmd := range stubs { + t.Run(name, func(t *testing.T) { + require.NotNil(t, cmd.RunE, "command should define RunE") + err := cmd.RunE(cmd, nil) + assert.EqualError(t, err, fmt.Sprintf("`air %s` is not implemented yet", name)) + }) + } +} From 059bd61ca8e04f4e2ecd4f3c6c92c45f98208b99 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Fri, 12 Jun 2026 23:10:23 +0000 Subject: [PATCH 02/18] experimental/air: rename `status` subcommand to `get` Rename the run-details subcommand from `status` to `get`, matching the Python air CLI's current `air get run` naming (it replaced `get status`). Renames the file, constructor, command name, and updates the stub/help/unimplemented tests and goldens accordingly. Co-authored-by: Isaac --- acceptance/experimental/air/help/output.txt | 2 +- acceptance/experimental/air/unimplemented/output.txt | 6 +++--- acceptance/experimental/air/unimplemented/script | 4 ++-- experimental/air/cmd/air.go | 2 +- experimental/air/cmd/air_test.go | 2 +- experimental/air/cmd/{status.go => get.go} | 8 ++++---- experimental/air/cmd/stubs_test.go | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) rename experimental/air/cmd/{status.go => get.go} (59%) diff --git a/acceptance/experimental/air/help/output.txt b/acceptance/experimental/air/help/output.txt index cf7e5af634c..3a0f86e164f 100644 --- a/acceptance/experimental/air/help/output.txt +++ b/acceptance/experimental/air/help/output.txt @@ -11,11 +11,11 @@ Usage: Available Commands: cancel Cancel one or more runs + get Show details for a run list List recent runs logs Stream or fetch logs for a run register-image Mirror a Docker image into the workspace registry run Submit a training workload from a YAML config - status Show status and configuration for a run Flags: -h, --help help for air diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 3dc88de3b77..4a07a38a378 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -5,9 +5,9 @@ Error: `air run` is not implemented yet Exit code: 1 -=== status ->>> [CLI] experimental air status 123 -Error: `air status` is not implemented yet +=== get +>>> [CLI] experimental air get 123 +Error: `air get` is not implemented yet Exit code: 1 diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index 83397b4b741..2ed885c0e66 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -3,8 +3,8 @@ title "run" errcode trace $CLI experimental air run -title "status" -errcode trace $CLI experimental air status 123 +title "get" +errcode trace $CLI experimental air get 123 title "list" errcode trace $CLI experimental air list diff --git a/experimental/air/cmd/air.go b/experimental/air/cmd/air.go index 3f9122c828c..81ffb2dd346 100644 --- a/experimental/air/cmd/air.go +++ b/experimental/air/cmd/air.go @@ -21,7 +21,7 @@ experimental and may change in future versions.`, } cmd.AddCommand(newRunCommand()) - cmd.AddCommand(newStatusCommand()) + cmd.AddCommand(newGetCommand()) cmd.AddCommand(newListCommand()) cmd.AddCommand(newLogsCommand()) cmd.AddCommand(newCancelCommand()) diff --git a/experimental/air/cmd/air_test.go b/experimental/air/cmd/air_test.go index 26268690850..7efac253a2b 100644 --- a/experimental/air/cmd/air_test.go +++ b/experimental/air/cmd/air_test.go @@ -14,7 +14,7 @@ func TestNewRegistersAllSubcommands(t *testing.T) { registered[c.Name()] = true } - want := []string{"run", "status", "list", "logs", "cancel", "register-image"} + want := []string{"run", "get", "list", "logs", "cancel", "register-image"} for _, name := range want { assert.True(t, registered[name], "subcommand %q is not registered", name) } diff --git a/experimental/air/cmd/status.go b/experimental/air/cmd/get.go similarity index 59% rename from experimental/air/cmd/status.go rename to experimental/air/cmd/get.go index a0db0619331..0ab0b8226bf 100644 --- a/experimental/air/cmd/status.go +++ b/experimental/air/cmd/get.go @@ -5,13 +5,13 @@ import ( "github.com/spf13/cobra" ) -func newStatusCommand() *cobra.Command { +func newGetCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "status RUN_ID", + Use: "get RUN_ID", Args: root.ExactArgs(1), - Short: "Show status and configuration for a run", + Short: "Show details for a run", RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("status") + return notImplemented("get") }, } diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index 8ffd197973f..a6e24177f33 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -14,7 +14,7 @@ import ( func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ "run": newRunCommand(), - "status": newStatusCommand(), + "get": newGetCommand(), "list": newListCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), From 2ccd0697c1c4ac9373685286c96eab0024686860 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 21:38:13 +0000 Subject: [PATCH 03/18] experimental/air: implement the `air get` command Implement the read-only run-details command (renamed from `status` to `get`). It fetches a job run via the Jobs API and renders the run's status, start time, duration, retries, experiment, accelerators, dashboard URL, MLflow deep-link, and a foreach/sweep summary. Output is the air-style {v, ts, data} JSON envelope under -o json, or a text view. Renames the command-level identifiers (status -> get) while keeping the run's "status" field/label. Adds format/mlflow/sweep/output helpers with unit tests and an acceptance test, and drops `get` from the not-implemented stub coverage. Co-authored-by: Isaac --- acceptance/experimental/air/get/out.test.toml | 3 + acceptance/experimental/air/get/output.txt | 36 +++ acceptance/experimental/air/get/script | 8 + acceptance/experimental/air/get/test.toml | 40 ++++ .../experimental/air/unimplemented/output.txt | 6 - .../experimental/air/unimplemented/script | 3 - experimental/air/cmd/format.go | 154 +++++++++++++ experimental/air/cmd/format_test.go | 131 +++++++++++ experimental/air/cmd/get.go | 172 +++++++++++++- experimental/air/cmd/get_test.go | 211 ++++++++++++++++++ experimental/air/cmd/mlflow.go | 65 ++++++ experimental/air/cmd/mlflow_test.go | 64 ++++++ experimental/air/cmd/output.go | 39 ++++ experimental/air/cmd/output_test.go | 13 ++ experimental/air/cmd/stubs_test.go | 1 - experimental/air/cmd/sweep.go | 76 +++++++ experimental/air/cmd/sweep_test.go | 81 +++++++ 17 files changed, 1091 insertions(+), 12 deletions(-) create mode 100644 acceptance/experimental/air/get/out.test.toml create mode 100644 acceptance/experimental/air/get/output.txt create mode 100644 acceptance/experimental/air/get/script create mode 100644 acceptance/experimental/air/get/test.toml create mode 100644 experimental/air/cmd/format.go create mode 100644 experimental/air/cmd/format_test.go create mode 100644 experimental/air/cmd/get_test.go create mode 100644 experimental/air/cmd/mlflow.go create mode 100644 experimental/air/cmd/mlflow_test.go create mode 100644 experimental/air/cmd/output.go create mode 100644 experimental/air/cmd/output_test.go create mode 100644 experimental/air/cmd/sweep.go create mode 100644 experimental/air/cmd/sweep_test.go diff --git a/acceptance/experimental/air/get/out.test.toml b/acceptance/experimental/air/get/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/get/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/get/output.txt b/acceptance/experimental/air/get/output.txt new file mode 100644 index 00000000000..6ce803659b4 --- /dev/null +++ b/acceptance/experimental/air/get/output.txt @@ -0,0 +1,36 @@ + +=== get (text) +>>> [CLI] experimental air get 123 +Run ID: 123 +Status: SUCCESS +Submitted: [TIMESTAMP] +Duration: 12s +Retries: 0 +Experiment: my-exp +User: user@example.com +Accelerators: 8x H100 +MLflow: [DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0 +Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 + +=== get (json) +>>> [CLI] experimental air get 123 -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "run_id": "123", + "status": "SUCCESS", + "started_at": "[TIMESTAMP]", + "duration_seconds": 12, + "attempt_number": 0, + "experiment_name": "my-exp", + "dashboard_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "mlflow_url": "[DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0" + } +} + +=== invalid run id +>>> [CLI] experimental air get notanumber +Error: invalid RUN_ID "notanumber": must be a positive integer + +Exit code: 1 diff --git a/acceptance/experimental/air/get/script b/acceptance/experimental/air/get/script new file mode 100644 index 00000000000..e0ea8d10f85 --- /dev/null +++ b/acceptance/experimental/air/get/script @@ -0,0 +1,8 @@ +title "get (text)" +trace $CLI experimental air get 123 + +title "get (json)" +trace $CLI experimental air get 123 -o json + +title "invalid run id" +errcode trace $CLI experimental air get notanumber diff --git a/acceptance/experimental/air/get/test.toml b/acceptance/experimental/air/get/test.toml new file mode 100644 index 00000000000..b6219b87f07 --- /dev/null +++ b/acceptance/experimental/air/get/test.toml @@ -0,0 +1,40 @@ +# This command does not deploy a bundle, so no engine matrix is needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] + +# The SDK occasionally probes host reachability with a HEAD request; stub it so +# the test is deterministic. +[[Server]] +Pattern = "HEAD /" +Response.Body = '' + +# A single GenAI-compute run with an experiment, GPUs, and a creator. +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get" +Response.Body = ''' +{ + "run_id": 123, + "run_page_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "creator_user_name": "user@example.com", + "start_time": 1700000000000, + "end_time": 1700000012000, + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}, + "tasks": [ + { + "task_key": "train", + "attempt_number": 0, + "gen_ai_compute_task": { + "mlflow_experiment_name": "/Users/user@example.com/my-exp", + "compute": {"gpu_type": "GPU_8xH100", "num_gpus": 8} + } + } + ] +} +''' + +# MLflow identifiers for the deep-link (runs/get-output is not modeled by the typed SDK). +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get-output" +Response.Body = ''' +{"gen_ai_compute_output": {"run_info": {"mlflow_experiment_id": "exp1", "mlflow_run_id": "run1"}}} +''' diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 4a07a38a378..0a86360c78f 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -5,12 +5,6 @@ Error: `air run` is not implemented yet Exit code: 1 -=== get ->>> [CLI] experimental air get 123 -Error: `air get` is not implemented yet - -Exit code: 1 - === list >>> [CLI] experimental air list Error: `air list` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index 2ed885c0e66..e6e8d33ef9d 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -3,9 +3,6 @@ title "run" errcode trace $CLI experimental air run -title "get" -errcode trace $CLI experimental air get 123 - title "list" errcode trace $CLI experimental air list diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go new file mode 100644 index 00000000000..88f620ee7c3 --- /dev/null +++ b/experimental/air/cmd/format.go @@ -0,0 +1,154 @@ +package aircmd + +import ( + "fmt" + "strings" + "time" + + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// gpuDisplayNames maps the GPU identifiers returned by the backend to the short +// names we show to users. Unknown identifiers are shown unchanged. +var gpuDisplayNames = map[string]string{ + "h100_80gb": "H100", + "a10": "A10", + "GPU_1xA10": "A10", + "GPU_8xH100": "H100", + "GPU_1xH100": "H100", +} + +// runStatus returns the single status word to show for a run. The backend +// reports two values: a lifecycle state (e.g. PENDING, RUNNING) and, once the +// run has finished, a result state (e.g. SUCCESS, FAILED). The result state is +// the more meaningful one, so we prefer it when it is set. +func runStatus(state *jobs.RunState) string { + if state == nil { + return "UNKNOWN" + } + if state.ResultState != "" { + return string(state.ResultState) + } + if state.LifeCycleState != "" { + return string(state.LifeCycleState) + } + return "UNKNOWN" +} + +// startedAt converts the run's start time (epoch milliseconds) to an RFC 3339 +// UTC string, or returns nil if the run has not started yet. +func startedAt(run *jobs.Run) *string { + if run.StartTime == 0 { + return nil + } + s := time.UnixMilli(run.StartTime).UTC().Format(time.RFC3339) + return &s +} + +// durationSeconds returns how long the run has taken, in whole seconds, or nil +// if it has not started. For a finished run this is the elapsed time; for a +// still-running run it is the time since it started. +func durationSeconds(run *jobs.Run) *int64 { + if run.StartTime == 0 { + return nil + } + + var endMillis int64 + switch { + case run.RunDuration > 0: + // The backend already computed the duration for us. + d := run.RunDuration / 1000 + return &d + case run.EndTime > 0: + endMillis = run.EndTime + default: + // Still running: measure against the current time. + endMillis = time.Now().UnixMilli() + } + + d := (endMillis - run.StartTime) / 1000 + return &d +} + +// formatDuration turns a number of seconds into a compact human string such as +// "1h 2m 3s". Trailing zero units are dropped, but a lone "0s" is kept so the +// result is never empty. +func formatDuration(totalSeconds int64) string { + hours := totalSeconds / 3600 + minutes := (totalSeconds % 3600) / 60 + seconds := totalSeconds % 60 + + var parts []string + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + return strings.Join(parts, " ") +} + +// latestAttemptNumber returns the retry count of the run's most recent task. +// Tasks start at attempt 0, so a value of 0 means the run has not been retried. +func latestAttemptNumber(run *jobs.Run) int { + if len(run.Tasks) == 0 { + return 0 + } + return run.Tasks[len(run.Tasks)-1].AttemptNumber +} + +// experimentName returns the MLflow experiment name for the run, or nil if there +// isn't one. Experiment names are often stored under a user's home folder (e.g. +// "/Users/me@example.com/my-experiment"); we strip that prefix so users see just +// the experiment name they chose. +func experimentName(run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.MlflowExperimentName == "" { + return nil + } + name := stripExperimentUserPrefix(task.MlflowExperimentName) + return &name +} + +// stripExperimentUserPrefix removes a leading "/Users//" from an +// experiment name, leaving the remainder. Names without that prefix are returned +// unchanged. +func stripExperimentUserPrefix(name string) string { + if !strings.HasPrefix(name, "/Users/") { + return name + } + // Split into ["", "Users", "", ""]; keep "". + parts := strings.SplitN(name, "/", 4) + if len(parts) == 4 { + return parts[3] + } + return name +} + +// accelerators returns a short description of the GPUs the run uses, such as +// "8x H100", or an empty string if the run has no GPU compute attached. +func accelerators(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.Compute == nil || task.Compute.NumGpus == 0 { + return "" + } + return fmt.Sprintf("%dx %s", task.Compute.NumGpus, gpuDisplayName(task.Compute.GpuType)) +} + +// gpuDisplayName returns the friendly name for a GPU identifier, falling back to +// the identifier itself when it is not one we recognize. +func gpuDisplayName(gpuType string) string { + if name, ok := gpuDisplayNames[gpuType]; ok { + return name + } + return gpuType +} diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go new file mode 100644 index 00000000000..c3e2e865b81 --- /dev/null +++ b/experimental/air/cmd/format_test.go @@ -0,0 +1,131 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatDuration(t *testing.T) { + cases := []struct { + seconds int64 + want string + }{ + {0, "0s"}, + {45, "45s"}, + {60, "1m"}, + {63, "1m 3s"}, + {3600, "1h"}, + {3723, "1h 2m 3s"}, + {7260, "2h 1m"}, + } + for _, c := range cases { + assert.Equal(t, c.want, formatDuration(c.seconds)) + } +} + +func TestStripExperimentUserPrefix(t *testing.T) { + cases := []struct { + name string + want string + }{ + {"/Users/me@example.com/my-experiment", "my-experiment"}, + {"/Users/me@example.com/nested/path", "nested/path"}, + {"my-experiment", "my-experiment"}, + {"/Shared/team-experiment", "/Shared/team-experiment"}, + {"/Users/me@example.com", "/Users/me@example.com"}, + } + for _, c := range cases { + assert.Equal(t, c.want, stripExperimentUserPrefix(c.name)) + } +} + +func TestGpuDisplayName(t *testing.T) { + assert.Equal(t, "H100", gpuDisplayName("h100_80gb")) + assert.Equal(t, "A10", gpuDisplayName("GPU_1xA10")) + assert.Equal(t, "A10", gpuDisplayName("a10")) + assert.Equal(t, "H100", gpuDisplayName("GPU_8xH100")) + assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) + // Unknown identifiers pass through unchanged. + assert.Equal(t, "b200", gpuDisplayName("b200")) + assert.Equal(t, "", gpuDisplayName("")) +} + +func TestRunStatusPrefersResultState(t *testing.T) { + // Result state wins once the run has finished. + assert.Equal(t, "SUCCESS", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateTerminated, + ResultState: jobs.RunResultStateSuccess, + })) + // Before completion only the lifecycle state is set. + assert.Equal(t, "RUNNING", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateRunning, + })) + // Non-nil state with neither field set, and nil state. + assert.Equal(t, "UNKNOWN", runStatus(&jobs.RunState{})) + assert.Equal(t, "UNKNOWN", runStatus(nil)) +} + +func TestStartedAt(t *testing.T) { + // Not started yet. + assert.Nil(t, startedAt(&jobs.Run{})) + // 1700000000000 ms == 2023-11-14T22:13:20Z. + got := startedAt(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, got) + assert.Equal(t, "2023-11-14T22:13:20Z", *got) +} + +func TestDurationSeconds(t *testing.T) { + // Not started yet. + assert.Nil(t, durationSeconds(&jobs.Run{})) + + // Backend-provided duration wins (milliseconds → seconds). + d := durationSeconds(&jobs.Run{StartTime: 1700000000000, RunDuration: 5000}) + require.NotNil(t, d) + assert.Equal(t, int64(5), *d) + + // Finished run with no RunDuration: end - start. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000012000}) + require.NotNil(t, d) + assert.Equal(t, int64(12), *d) + + // Still running: measured against the current time, so positive. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, d) + assert.Positive(t, *d) +} + +func TestLatestAttemptNumber(t *testing.T) { + assert.Equal(t, 0, latestAttemptNumber(&jobs.Run{})) + run := &jobs.Run{Tasks: []jobs.RunTask{{AttemptNumber: 0}, {AttemptNumber: 2}}} + assert.Equal(t, 2, latestAttemptNumber(run)) +} + +func TestExperimentName(t *testing.T) { + assert.Nil(t, experimentName(&jobs.Run{})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: ""}, + }}})) + got := experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}) + require.NotNil(t, got) + assert.Equal(t, "exp", *got) +} + +func TestAccelerators(t *testing.T) { + assert.Equal(t, "", accelerators(&jobs.Run{})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{}, + }}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, + }}})) + assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}}, + }}})) +} diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 0ab0b8226bf..cc486b722f8 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -1,19 +1,187 @@ package aircmd import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/spf13/cobra" ) +// getData is the payload printed by `air get`. The json-tagged fields form +// the machine-readable output; fields tagged `json:"-"` are shown only in the +// human-readable text view. +type getData struct { + RunID string `json:"run_id"` + Status string `json:"status"` + StartedAt *string `json:"started_at"` + DurationSeconds *int64 `json:"duration_seconds"` + AttemptNumber int `json:"attempt_number"` + ExperimentName *string `json:"experiment_name"` + DashboardURL string `json:"dashboard_url"` + MLflowURL *string `json:"mlflow_url"` + + // Duration is the human-readable form of DurationSeconds, e.g. "12m 3s". + Duration string `json:"-"` + // Accelerators describes the run's GPUs, e.g. "8x H100". + Accelerators string `json:"-"` + // User is the run's creator. Text-only; JSON omits it, matching `air get --json`. + User string `json:"-"` + // Sweep replaces the single-run view for foreach runs. Text-only; JSON omits it. + Sweep *sweepInfo `json:"-"` +} + +// getTemplate is the text-mode layout. It reads from the JSON envelope, so +// every field is reached through ".Data". Optional rows are hidden when empty. +const getTemplate = `{{- if .Data.Sweep -}} +Sweep Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +Total: {{.Data.Sweep.Total}} +Completed: {{.Data.Sweep.Completed}} +Succeeded: {{.Data.Sweep.Succeeded}} +Failed: {{.Data.Sweep.Failed}} +Active: {{.Data.Sweep.Active}} +{{- if .Data.Sweep.Tasks}} + +Sweep Tasks: +{{printf " %-24s %-14s %-12s %s" "TASK" "RUN ID" "STATUS" "EXPERIMENT"}} +{{- range .Data.Sweep.Tasks}} +{{printf " %-24s %-14s %-12s %s" .TaskKey .RunID .Status .Experiment}} +{{- end}} +{{- end}} +{{- else -}} +Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +{{- if .Data.StartedAt}} +Submitted: {{.Data.StartedAt}} +{{- end}} +{{- if .Data.Duration}} +Duration: {{.Data.Duration}} +{{- end}} +Retries: {{.Data.AttemptNumber}} +{{- if .Data.ExperimentName}} +Experiment: {{.Data.ExperimentName}} +{{- end}} +{{- if .Data.User}} +User: {{.Data.User}} +{{- end}} +{{- if .Data.Accelerators}} +Accelerators: {{.Data.Accelerators}} +{{- end}} +{{- if .Data.MLflowURL}} +MLflow: {{.Data.MLflowURL}} +{{- end}} +Dashboard: {{.Data.DashboardURL}} +{{- end}} +` + func newGetCommand() *cobra.Command { cmd := &cobra.Command{ Use: "get RUN_ID", Args: root.ExactArgs(1), Short: "Show details for a run", - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("get") + Annotations: map[string]string{ + "template": getTemplate, }, } + cmd.PreRunE = root.MustWorkspaceClient + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + + runID, err := strconv.ParseInt(args[0], 10, 64) + if err != nil || runID <= 0 { + return fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0]) + } + + run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}) + if err != nil { + // The backend returns this when the run ID is unknown to the user. + if errors.Is(err, apierr.ErrResourceDoesNotExist) { + return fmt.Errorf("run %d not found: check the run ID and that it is a job run ID", runID) + } + return fmt.Errorf("failed to get status for run %d: %w", runID, err) + } + + data := buildGetData(run) + data.MLflowURL = mlflowURL(ctx, w, run) + if task := findForEachTask(run); task != nil { + data.Sweep = buildSweepInfo(ctx, w, task) + } + + // Text mode shows the training-config YAML before the status, mirroring + // `air get`. JSON output omits it, matching `air get --json`. + if root.OutputType(cmd) == flags.OutputText { + if path := yamlConfigPath(run); path != "" { + printConfigYAML(ctx, w, path) + } + } + return renderEnvelope(ctx, data) + } + return cmd } + +// buildGetData extracts the fields we display from a run. +func buildGetData(run *jobs.Run) getData { + data := getData{ + RunID: strconv.FormatInt(run.RunId, 10), + Status: runStatus(run.State), + StartedAt: startedAt(run), + DurationSeconds: durationSeconds(run), + AttemptNumber: latestAttemptNumber(run), + ExperimentName: experimentName(run), + DashboardURL: run.RunPageUrl, + Accelerators: accelerators(run), + User: run.CreatorUserName, + } + if data.DurationSeconds != nil { + data.Duration = formatDuration(*data.DurationSeconds) + } + return data +} + +// yamlConfigPath returns the run's training-config YAML path, or "" if none. +func yamlConfigPath(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil { + return "" + } + return task.YamlParametersFilePath +} + +// printConfigYAML downloads the run's training-config YAML and prints it. It is +// best-effort: a failure is surfaced as a warning but does not fail status. +func printConfigYAML(ctx context.Context, w *databricks.WorkspaceClient, path string) { + r, err := w.Workspace.Download(ctx, path) + if err != nil { + log.Warnf(ctx, "air get: could not download training config %s: %v", path, err) + return + } + defer r.Close() + + content, err := io.ReadAll(r) + if err != nil { + log.Warnf(ctx, "air get: could not read training config %s: %v", path, err) + return + } + + cmdio.LogString(ctx, "Training Configuration:") + cmdio.LogString(ctx, string(content)) + cmdio.LogString(ctx, "") +} diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go new file mode 100644 index 00000000000..6dfdc54db7b --- /dev/null +++ b/experimental/air/cmd/get_test.go @@ -0,0 +1,211 @@ +package aircmd + +import ( + "bytes" + "io" + "strings" + "testing" + "text/template" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// renderGet renders the status template against the JSON envelope, exactly as +// the command does, so the test covers the real template branches. +func renderGet(t *testing.T, data getData) string { + t.Helper() + tmpl, err := template.New("status").Parse(getTemplate) + require.NoError(t, err) + var buf bytes.Buffer + require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) + return buf.String() +} + +func TestGetTemplateSingleRun(t *testing.T) { + out := renderGet(t, getData{ + RunID: "123", + Status: "RUNNING", + User: "me@example.com", + DashboardURL: "https://example.test/run/123", + }) + assert.Contains(t, out, "Run ID: 123") + assert.Contains(t, out, "Status: RUNNING") + assert.Contains(t, out, "User:") + assert.Contains(t, out, "me@example.com") + assert.Contains(t, out, "Dashboard: https://example.test/run/123") + assert.NotContains(t, out, "Sweep") +} + +func TestGetRunInvalidID(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"abc"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid RUN_ID") +} + +func TestGetRunNotFound(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( + nil, apierr.ErrResourceDoesNotExist) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"5"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "run 5 not found") +} + +func TestPrintConfigYAML(t *testing.T) { + t.Run("downloads and prints", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + // The mock asserts Download is called with the resolved path. + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/cfg.yaml"). + Return(io.NopCloser(strings.NewReader("epochs: 3\n")), nil) + + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/cfg.yaml") + }) + + t.Run("download failure is non-fatal", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/missing.yaml"). + Return(nil, apierr.ErrResourceDoesNotExist) + + // Must not panic: a failed config fetch is best-effort. + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/missing.yaml") + }) +} + +func TestYAMLConfigPath(t *testing.T) { + // No tasks, or a task without GenAiComputeTask, yields no path. + assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) + assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + + run := &jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, + }}} + assert.Equal(t, "/Workspace/cfg.yaml", yamlConfigPath(run)) +} + +func TestGetTemplateSweep(t *testing.T) { + out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{ + Total: 4, Completed: 2, Succeeded: 1, Failed: 1, Active: 2, + Tasks: []sweepTask{ + {TaskKey: "iter_0", RunID: "789", Status: "SUCCESS", Experiment: "my-exp"}, + {TaskKey: "iter_1", RunID: "790", Status: "FAILED", Experiment: "my-exp"}, + }, + }, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.Contains(t, out, "Sweep Tasks:") + assert.Contains(t, out, "iter_0") + assert.Contains(t, out, "iter_1") + assert.Contains(t, out, "FAILED") + assert.Contains(t, out, "my-exp") + // The single-run rows must not appear in the sweep view. + assert.NotContains(t, out, "Dashboard:") +} + +func TestGetTemplateSweepNoTasks(t *testing.T) { + // A sweep whose iterations haven't materialized yet: counts show, but the + // task table header is hidden. + out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{Total: 4, Active: 4}, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.NotContains(t, out, "Sweep Tasks:") +} + +func TestGetTemplateMinimal(t *testing.T) { + // Only the always-present rows render; optional rows are hidden when empty. + out := renderGet(t, getData{RunID: "1", Status: "PENDING", DashboardURL: "https://example.test/1"}) + assert.Contains(t, out, "Run ID: 1") + assert.Contains(t, out, "Status: PENDING") + assert.Contains(t, out, "Retries: 0") + assert.Contains(t, out, "Dashboard: https://example.test/1") + for _, hidden := range []string{"Submitted:", "Duration:", "Experiment:", "User:", "Accelerators:", "MLflow:"} { + assert.NotContains(t, out, hidden) + } +} + +func TestGetTemplateAllFields(t *testing.T) { + started := "2023-11-14T22:13:20Z" + exp := "exp" + mlflow := "https://example.test/ml/exp/1" + out := renderGet(t, getData{ + RunID: "1", + Status: "SUCCESS", + StartedAt: &started, + Duration: "12s", + AttemptNumber: 2, + ExperimentName: &exp, + User: "me@example.com", + Accelerators: "8x H100", + MLflowURL: &mlflow, + DashboardURL: "https://example.test/1", + }) + for _, want := range []string{ + "Submitted: 2023-11-14T22:13:20Z", + "Duration: 12s", + "Retries: 2", + "Experiment: exp", + "User: me@example.com", + "Accelerators: 8x H100", + "MLflow: https://example.test/ml/exp/1", + "Dashboard: https://example.test/1", + } { + assert.Contains(t, out, want) + } +} + +func TestBuildStatusData(t *testing.T) { + run := &jobs.Run{ + RunId: 123, + RunPageUrl: "https://example.test/run/123", + CreatorUserName: "me@example.com", + StartTime: 1700000000000, + EndTime: 1700000012000, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + Tasks: []jobs.RunTask{{ + AttemptNumber: 1, + GenAiComputeTask: &jobs.GenAiComputeTask{ + MlflowExperimentName: "/Users/me@example.com/exp", + Compute: &jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}, + }, + }}, + } + d := buildGetData(run) + assert.Equal(t, "123", d.RunID) + assert.Equal(t, "SUCCESS", d.Status) + assert.Equal(t, 1, d.AttemptNumber) + assert.Equal(t, "https://example.test/run/123", d.DashboardURL) + assert.Equal(t, "me@example.com", d.User) + assert.Equal(t, "8x H100", d.Accelerators) + assert.Equal(t, "12s", d.Duration) + require.NotNil(t, d.ExperimentName) + assert.Equal(t, "exp", *d.ExperimentName) + require.NotNil(t, d.DurationSeconds) + assert.Equal(t, int64(12), *d.DurationSeconds) +} diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go new file mode 100644 index 00000000000..97d085b0128 --- /dev/null +++ b/experimental/air/cmd/mlflow.go @@ -0,0 +1,65 @@ +package aircmd + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// getRunOutputResponse is the slice of the jobs runs/get-output response we care +// about. The MLflow identifiers live under a gen_ai_compute_output field that +// the typed SDK does not model, so we call the endpoint directly and parse just +// these fields. +type getRunOutputResponse struct { + GenAiComputeOutput *struct { + RunInfo *struct { + MlflowExperimentID string `json:"mlflow_experiment_id"` + MlflowRunID string `json:"mlflow_run_id"` + } `json:"run_info"` + } `json:"gen_ai_compute_output"` +} + +// mlflowURL returns a link to the run's MLflow logs, or nil if it can't be +// built. The link is a convenience, so any failure here (missing task, endpoint +// error, run not yet started) is logged and treated as "no link" rather than +// failing the whole command. +func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + // The MLflow output is attached to the task run, not the parent job run. + taskRunID := run.Tasks[0].RunId + + apiClient, err := client.New(w.Config) + if err != nil { + log.Debugf(ctx, "air get: could not build API client for MLflow link: %v", err) + return nil + } + + var out getRunOutputResponse + err = apiClient.Do(ctx, http.MethodGet, "/api/2.2/jobs/runs/get-output", + nil, map[string]any{"run_id": taskRunID}, nil, &out) + if err != nil { + log.Debugf(ctx, "air get: could not fetch run output for MLflow link: %v", err) + return nil + } + + if out.GenAiComputeOutput == nil || out.GenAiComputeOutput.RunInfo == nil { + return nil + } + info := out.GenAiComputeOutput.RunInfo + if info.MlflowExperimentID == "" || info.MlflowRunID == "" { + return nil + } + + host := strings.TrimRight(w.Config.Host, "/") + url := fmt.Sprintf("%s/ml/experiments/%s/runs/%s/artifacts/logs/node_0", + host, info.MlflowExperimentID, info.MlflowRunID) + return &url +} diff --git a/experimental/air/cmd/mlflow_test.go b/experimental/air/cmd/mlflow_test.go new file mode 100644 index 00000000000..bbc4fef9822 --- /dev/null +++ b/experimental/air/cmd/mlflow_test.go @@ -0,0 +1,64 @@ +package aircmd + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestWorkspaceClient builds a WorkspaceClient pointed at a mock HTTP server. +// mlflowURL calls the runs/get-output REST endpoint directly (the field it needs +// is not modeled by the typed SDK), so it must be exercised over HTTP. +func newTestWorkspaceClient(t *testing.T, host string) *databricks.WorkspaceClient { + t.Helper() + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: host, Token: "token"}) + require.NoError(t, err) + return w +} + +// runOutputServer serves the given runs/get-output body and a stub for the SDK's +// well-known config discovery request. *hit is set when get-output is called. +func runOutputServer(t *testing.T, body string, hit *bool) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/2.2/jobs/runs/get-output" { + *hit = true + _, _ = w.Write([]byte(body)) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestMLflowURL(t *testing.T) { + ctx := t.Context() + run := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}}} + + t.Run("builds the deep-link on success", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{"gen_ai_compute_output":{"run_info":{"mlflow_experiment_id":"E1","mlflow_run_id":"R1"}}}`, &hit) + + got := mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run) + require.NotNil(t, got) + assert.True(t, hit, "runs/get-output should have been called") + assert.Equal(t, srv.URL+"/ml/experiments/E1/runs/R1/artifacts/logs/node_0", *got) + }) + + t.Run("nil when the run has no MLflow info", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{}`, &hit) + assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run)) + }) + + t.Run("nil when the run has no tasks", func(t *testing.T) { + // Returns before any HTTP call, so the host is never contacted. + assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) + }) +} diff --git a/experimental/air/cmd/output.go b/experimental/air/cmd/output.go new file mode 100644 index 00000000000..3da766a7d4f --- /dev/null +++ b/experimental/air/cmd/output.go @@ -0,0 +1,39 @@ +package aircmd + +import ( + "context" + "time" + + "github.com/databricks/cli/libs/cmdio" +) + +// envelopeVersion is the envelope's format-version marker. The Python `air` CLI +// hardcodes it to 1; it lets consumers detect a future incompatible change to +// the envelope shape. +const envelopeVersion = 1 + +// envelope is the JSON shape that the AI runtime CLI prints: +// +// { "v": 1, "ts": "2024-01-15T14:30:45Z", "data": { ... } } +// +// It mirrors the envelope used by the original Python `air` CLI so existing +// consumers keep working after the port to Go. +type envelope struct { + // V is the envelope format-version marker (always 1). + V int `json:"v"` + // TS is the wall-clock time the response was produced, in RFC 3339 UTC. + // It is an absolute timestamp, not an elapsed duration. + TS string `json:"ts"` + // Data is the command-specific payload. + Data any `json:"data"` +} + +// renderEnvelope wraps data in the JSON envelope and prints it. +// Fields that should appear only in text output are tagged `json:"-"` on the payload struct. +func renderEnvelope(ctx context.Context, data any) error { + return cmdio.Render(ctx, envelope{ + V: envelopeVersion, + TS: time.Now().UTC().Format(time.RFC3339), + Data: data, + }) +} diff --git a/experimental/air/cmd/output_test.go b/experimental/air/cmd/output_test.go new file mode 100644 index 00000000000..73a5572c3f5 --- /dev/null +++ b/experimental/air/cmd/output_test.go @@ -0,0 +1,13 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/require" +) + +func TestRenderEnvelope(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + require.NoError(t, renderEnvelope(ctx, getData{RunID: "1", Status: "RUNNING"})) +} diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index a6e24177f33..5e35bcdcd14 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -14,7 +14,6 @@ import ( func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ "run": newRunCommand(), - "get": newGetCommand(), "list": newListCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), diff --git a/experimental/air/cmd/sweep.go b/experimental/air/cmd/sweep.go new file mode 100644 index 00000000000..b346f43f1b6 --- /dev/null +++ b/experimental/air/cmd/sweep.go @@ -0,0 +1,76 @@ +package aircmd + +import ( + "context" + "strconv" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// sweepInfo summarizes a "foreach" run, which fans a single config out into many +// iterations (a hyperparameter sweep). It is shown only in text output. +type sweepInfo struct { + Total int + Succeeded int + Failed int + Active int + Completed int + Tasks []sweepTask +} + +// sweepTask is one iteration of a sweep. +type sweepTask struct { + TaskKey string + RunID string + Status string + Experiment string +} + +// findForEachTask returns the run's foreach task if it has one, or nil. A run is +// a sweep when one of its tasks fans out into iterations. +func findForEachTask(run *jobs.Run) *jobs.RunTask { + for i := range run.Tasks { + if run.Tasks[i].ForEachTask != nil { + return &run.Tasks[i] + } + } + return nil +} + +// buildSweepInfo gathers the iteration counts and per-iteration rows for a +// sweep. The counts come from the task we already have; the individual +// iterations require a second lookup. If that lookup fails we still return the +// counts (logging the failure) so the user sees the summary. +func buildSweepInfo(ctx context.Context, w *databricks.WorkspaceClient, task *jobs.RunTask) *sweepInfo { + info := &sweepInfo{} + if task.ForEachTask.Stats != nil && task.ForEachTask.Stats.TaskRunStats != nil { + stats := task.ForEachTask.Stats.TaskRunStats + info.Total = stats.TotalIterations + info.Succeeded = stats.SucceededIterations + info.Failed = stats.FailedIterations + info.Active = stats.ActiveIterations + info.Completed = stats.CompletedIterations + } + + // The iterations are returned as part of a run lookup on the foreach task. + iterated, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: task.RunId}) + if err != nil { + log.Debugf(ctx, "air get: could not fetch sweep iterations: %v", err) + return info + } + + for _, it := range iterated.Iterations { + row := sweepTask{ + TaskKey: it.TaskKey, + RunID: strconv.FormatInt(it.RunId, 10), + Status: runStatus(it.State), + } + if it.GenAiComputeTask != nil && it.GenAiComputeTask.MlflowExperimentName != "" { + row.Experiment = stripExperimentUserPrefix(it.GenAiComputeTask.MlflowExperimentName) + } + info.Tasks = append(info.Tasks, row) + } + return info +} diff --git a/experimental/air/cmd/sweep_test.go b/experimental/air/cmd/sweep_test.go new file mode 100644 index 00000000000..10134c0df42 --- /dev/null +++ b/experimental/air/cmd/sweep_test.go @@ -0,0 +1,81 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestFindForEachTask(t *testing.T) { + // No tasks at all. + assert.Nil(t, findForEachTask(&jobs.Run{})) + + // A task that is not a foreach. + assert.Nil(t, findForEachTask(&jobs.Run{Tasks: []jobs.RunTask{{TaskKey: "a"}}})) + + // The foreach task is found even when it isn't first. + run := &jobs.Run{Tasks: []jobs.RunTask{ + {TaskKey: "a"}, + {TaskKey: "sweep", ForEachTask: &jobs.RunForEachTask{}}, + }} + got := findForEachTask(run) + require.NotNil(t, got) + assert.Equal(t, "sweep", got.TaskKey) +} + +func sweepTaskFixture() *jobs.RunTask { + return &jobs.RunTask{ + RunId: 99, + ForEachTask: &jobs.RunForEachTask{ + Stats: &jobs.ForEachStats{TaskRunStats: &jobs.ForEachTaskTaskRunStats{ + TotalIterations: 4, + SucceededIterations: 1, + FailedIterations: 1, + ActiveIterations: 2, + CompletedIterations: 2, + }}, + }, + } +} + +func TestBuildSweepInfo(t *testing.T) { + ctx := t.Context() + + t.Run("counts and iteration rows", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return( + &jobs.Run{Iterations: []jobs.RunTask{{ + TaskKey: "iter_0", + RunId: 100, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}, nil) + + info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture()) + assert.Equal(t, 4, info.Total) + assert.Equal(t, 2, info.Completed) + assert.Equal(t, 1, info.Succeeded) + assert.Equal(t, 1, info.Failed) + assert.Equal(t, 2, info.Active) + require.Len(t, info.Tasks, 1) + assert.Equal(t, "iter_0", info.Tasks[0].TaskKey) + assert.Equal(t, "100", info.Tasks[0].RunID) + assert.Equal(t, "SUCCESS", info.Tasks[0].Status) + assert.Equal(t, "exp", info.Tasks[0].Experiment) + }) + + t.Run("iteration lookup failure still returns counts", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return( + nil, apierr.ErrResourceDoesNotExist) + + info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture()) + assert.Equal(t, 4, info.Total) + assert.Empty(t, info.Tasks) + }) +} From 89042d08bf1ebfb6352abd1078a597e880ec21bb Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:27:54 +0000 Subject: [PATCH 04/18] experimental/air: rename stale TestBuildStatusData to TestBuildGetData Co-authored-by: Isaac --- experimental/air/cmd/get_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 6dfdc54db7b..64fd1aa1f68 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -17,11 +17,11 @@ import ( "github.com/stretchr/testify/require" ) -// renderGet renders the status template against the JSON envelope, exactly as +// renderGet renders the get template against the JSON envelope, exactly as // the command does, so the test covers the real template branches. func renderGet(t *testing.T, data getData) string { t.Helper() - tmpl, err := template.New("status").Parse(getTemplate) + tmpl, err := template.New("get").Parse(getTemplate) require.NoError(t, err) var buf bytes.Buffer require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) @@ -180,7 +180,7 @@ func TestGetTemplateAllFields(t *testing.T) { } } -func TestBuildStatusData(t *testing.T) { +func TestBuildGetData(t *testing.T) { run := &jobs.Run{ RunId: 123, RunPageUrl: "https://example.test/run/123", From c99239ca432a5001ff6ae5f58439399d262916eb Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:30:07 +0000 Subject: [PATCH 05/18] experimental/air: apply testifylint fixes in get/format tests Co-authored-by: Isaac --- experimental/air/cmd/format_test.go | 10 +++++----- experimental/air/cmd/get_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index c3e2e865b81..583f3cc3111 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -50,7 +50,7 @@ func TestGpuDisplayName(t *testing.T) { assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) // Unknown identifiers pass through unchanged. assert.Equal(t, "b200", gpuDisplayName("b200")) - assert.Equal(t, "", gpuDisplayName("")) + assert.Empty(t, gpuDisplayName("")) } func TestRunStatusPrefersResultState(t *testing.T) { @@ -117,12 +117,12 @@ func TestExperimentName(t *testing.T) { } func TestAccelerators(t *testing.T) { - assert.Equal(t, "", accelerators(&jobs.Run{})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + assert.Empty(t, accelerators(&jobs.Run{})) + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{}, }}})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, }}})) assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 64fd1aa1f68..b6d6d0baab9 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -93,8 +93,8 @@ func TestPrintConfigYAML(t *testing.T) { func TestYAMLConfigPath(t *testing.T) { // No tasks, or a task without GenAiComputeTask, yields no path. - assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) - assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Empty(t, yamlConfigPath(&jobs.Run{})) + assert.Empty(t, yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) run := &jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, From 235f1bd137e0a27f30d432d7205c7ac77855030d Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 18:29:42 +0000 Subject: [PATCH 06/18] experimental/air: disambiguate JOB_RUN_ID, hide --review, add -i alias Rename the RUN_ID arg placeholder to JOB_RUN_ID across get/logs/cancel to disambiguate it from other run identifiers. Hide the `logs --review` flag to match the Python CLI (help=argparse.SUPPRESS), and add the `-i` shorthand for `register-image --interactive-authenticate`. Co-authored-by: Isaac --- experimental/air/cmd/cancel.go | 8 ++++---- experimental/air/cmd/get.go | 2 +- experimental/air/cmd/logs.go | 4 +++- experimental/air/cmd/register_image.go | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/experimental/air/cmd/cancel.go b/experimental/air/cmd/cancel.go index ad7fffc7125..ae5514e5b04 100644 --- a/experimental/air/cmd/cancel.go +++ b/experimental/air/cmd/cancel.go @@ -12,7 +12,7 @@ func newCancelCommand() *cobra.Command { ) cmd := &cobra.Command{ - Use: "cancel [RUN_ID...]", + Use: "cancel [JOB_RUN_ID...]", Short: "Cancel one or more runs", Long: `Cancel one or more runs by ID, or cancel all of your active runs with --all.`, RunE: func(cmd *cobra.Command, args []string) error { @@ -23,14 +23,14 @@ func newCancelCommand() *cobra.Command { cmd.Flags().BoolVar(&all, "all", false, "Cancel all of your active runs") cmd.Flags().BoolVarP(&yes, "yes", "y", false, "Skip the confirmation prompt") - // Require exactly one of: one or more RUN_IDs, or --all. Cobra parses flags + // Require exactly one of: one or more JOB_RUN_IDs, or --all. Cobra parses flags // before running this, so `all` reflects the user's input. cmd.Args = func(cmd *cobra.Command, args []string) error { switch { case all && len(args) > 0: - return &root.InvalidArgsError{Command: cmd, Message: "cannot combine RUN_ID arguments with --all"} + return &root.InvalidArgsError{Command: cmd, Message: "cannot combine JOB_RUN_ID arguments with --all"} case !all && len(args) == 0: - return &root.InvalidArgsError{Command: cmd, Message: "provide at least one RUN_ID, or use --all"} + return &root.InvalidArgsError{Command: cmd, Message: "provide at least one JOB_RUN_ID, or use --all"} } return nil } diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 0ab0b8226bf..45f93df10a6 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -7,7 +7,7 @@ import ( func newGetCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "get RUN_ID", + Use: "get JOB_RUN_ID", Args: root.ExactArgs(1), Short: "Show details for a run", RunE: func(cmd *cobra.Command, args []string) error { diff --git a/experimental/air/cmd/logs.go b/experimental/air/cmd/logs.go index 4dbbe41c278..c34fb62a7df 100644 --- a/experimental/air/cmd/logs.go +++ b/experimental/air/cmd/logs.go @@ -15,7 +15,7 @@ func newLogsCommand() *cobra.Command { ) cmd := &cobra.Command{ - Use: "logs RUN_ID", + Use: "logs JOB_RUN_ID", Args: root.ExactArgs(1), Short: "Stream or fetch logs for a run", Long: `Stream logs from an active run, or fetch logs from a completed run.`, @@ -29,6 +29,8 @@ func newLogsCommand() *cobra.Command { cmd.Flags().IntVar(&retry, "retry", -1, "View logs from a specific retry attempt; -1 means latest") cmd.Flags().StringVar(&downloadTo, "download-to", "", "Download all logs to this directory instead of printing") cmd.Flags().BoolVar(&review, "review", false, "Download logs from all nodes and filter for error signatures") + // Hidden in the Python `air` CLI (help=argparse.SUPPRESS); keep it internal here to match. + cmd.Flags().MarkHidden("review") return cmd } diff --git a/experimental/air/cmd/register_image.go b/experimental/air/cmd/register_image.go index a5be3df408b..1d8b45044a7 100644 --- a/experimental/air/cmd/register_image.go +++ b/experimental/air/cmd/register_image.go @@ -25,7 +25,7 @@ func newRegisterImageCommand() *cobra.Command { cmd.Flags().StringVar(&scope, "scope", "", "Databricks secret scope holding registry credentials") cmd.Flags().StringVar(&key, "key", "", "Databricks secret key holding registry credentials") - cmd.Flags().BoolVar(&interactiveAuth, "interactive-authenticate", false, "Prompt for registry credentials and store them as a secret") + cmd.Flags().BoolVarP(&interactiveAuth, "interactive-authenticate", "i", false, "Prompt for registry credentials and store them as a secret") cmd.Flags().StringVar(&tagPolicy, "tag-policy", "auto", "Image resolution policy: auto or latest") cmd.Flags().IntVar(&timeoutMinutes, "timeout-minutes", 60, "Timeout to wait for the image to become available") From 31121ad181fa34b6a4326ccc95ce62f6bff5ffaa Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 21:38:13 +0000 Subject: [PATCH 07/18] experimental/air: implement the `air get` command Implement the read-only run-details command (renamed from `status` to `get`). It fetches a job run via the Jobs API and renders the run's status, start time, duration, retries, experiment, accelerators, dashboard URL, MLflow deep-link, and a foreach/sweep summary. Output is the air-style {v, ts, data} JSON envelope under -o json, or a text view. Renames the command-level identifiers (status -> get) while keeping the run's "status" field/label. Adds format/mlflow/sweep/output helpers with unit tests and an acceptance test, and drops `get` from the not-implemented stub coverage. Co-authored-by: Isaac --- acceptance/experimental/air/get/out.test.toml | 3 + acceptance/experimental/air/get/output.txt | 36 +++ acceptance/experimental/air/get/script | 8 + acceptance/experimental/air/get/test.toml | 40 ++++ .../experimental/air/unimplemented/output.txt | 6 - .../experimental/air/unimplemented/script | 3 - experimental/air/cmd/format.go | 154 +++++++++++++ experimental/air/cmd/format_test.go | 131 +++++++++++ experimental/air/cmd/get.go | 172 +++++++++++++- experimental/air/cmd/get_test.go | 211 ++++++++++++++++++ experimental/air/cmd/mlflow.go | 65 ++++++ experimental/air/cmd/mlflow_test.go | 64 ++++++ experimental/air/cmd/output.go | 39 ++++ experimental/air/cmd/output_test.go | 13 ++ experimental/air/cmd/stubs_test.go | 1 - experimental/air/cmd/sweep.go | 76 +++++++ experimental/air/cmd/sweep_test.go | 81 +++++++ 17 files changed, 1091 insertions(+), 12 deletions(-) create mode 100644 acceptance/experimental/air/get/out.test.toml create mode 100644 acceptance/experimental/air/get/output.txt create mode 100644 acceptance/experimental/air/get/script create mode 100644 acceptance/experimental/air/get/test.toml create mode 100644 experimental/air/cmd/format.go create mode 100644 experimental/air/cmd/format_test.go create mode 100644 experimental/air/cmd/get_test.go create mode 100644 experimental/air/cmd/mlflow.go create mode 100644 experimental/air/cmd/mlflow_test.go create mode 100644 experimental/air/cmd/output.go create mode 100644 experimental/air/cmd/output_test.go create mode 100644 experimental/air/cmd/sweep.go create mode 100644 experimental/air/cmd/sweep_test.go diff --git a/acceptance/experimental/air/get/out.test.toml b/acceptance/experimental/air/get/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/get/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/get/output.txt b/acceptance/experimental/air/get/output.txt new file mode 100644 index 00000000000..6ce803659b4 --- /dev/null +++ b/acceptance/experimental/air/get/output.txt @@ -0,0 +1,36 @@ + +=== get (text) +>>> [CLI] experimental air get 123 +Run ID: 123 +Status: SUCCESS +Submitted: [TIMESTAMP] +Duration: 12s +Retries: 0 +Experiment: my-exp +User: user@example.com +Accelerators: 8x H100 +MLflow: [DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0 +Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 + +=== get (json) +>>> [CLI] experimental air get 123 -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "run_id": "123", + "status": "SUCCESS", + "started_at": "[TIMESTAMP]", + "duration_seconds": 12, + "attempt_number": 0, + "experiment_name": "my-exp", + "dashboard_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "mlflow_url": "[DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0" + } +} + +=== invalid run id +>>> [CLI] experimental air get notanumber +Error: invalid RUN_ID "notanumber": must be a positive integer + +Exit code: 1 diff --git a/acceptance/experimental/air/get/script b/acceptance/experimental/air/get/script new file mode 100644 index 00000000000..e0ea8d10f85 --- /dev/null +++ b/acceptance/experimental/air/get/script @@ -0,0 +1,8 @@ +title "get (text)" +trace $CLI experimental air get 123 + +title "get (json)" +trace $CLI experimental air get 123 -o json + +title "invalid run id" +errcode trace $CLI experimental air get notanumber diff --git a/acceptance/experimental/air/get/test.toml b/acceptance/experimental/air/get/test.toml new file mode 100644 index 00000000000..b6219b87f07 --- /dev/null +++ b/acceptance/experimental/air/get/test.toml @@ -0,0 +1,40 @@ +# This command does not deploy a bundle, so no engine matrix is needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] + +# The SDK occasionally probes host reachability with a HEAD request; stub it so +# the test is deterministic. +[[Server]] +Pattern = "HEAD /" +Response.Body = '' + +# A single GenAI-compute run with an experiment, GPUs, and a creator. +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get" +Response.Body = ''' +{ + "run_id": 123, + "run_page_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "creator_user_name": "user@example.com", + "start_time": 1700000000000, + "end_time": 1700000012000, + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}, + "tasks": [ + { + "task_key": "train", + "attempt_number": 0, + "gen_ai_compute_task": { + "mlflow_experiment_name": "/Users/user@example.com/my-exp", + "compute": {"gpu_type": "GPU_8xH100", "num_gpus": 8} + } + } + ] +} +''' + +# MLflow identifiers for the deep-link (runs/get-output is not modeled by the typed SDK). +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get-output" +Response.Body = ''' +{"gen_ai_compute_output": {"run_info": {"mlflow_experiment_id": "exp1", "mlflow_run_id": "run1"}}} +''' diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 4a07a38a378..0a86360c78f 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -5,12 +5,6 @@ Error: `air run` is not implemented yet Exit code: 1 -=== get ->>> [CLI] experimental air get 123 -Error: `air get` is not implemented yet - -Exit code: 1 - === list >>> [CLI] experimental air list Error: `air list` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index 2ed885c0e66..e6e8d33ef9d 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -3,9 +3,6 @@ title "run" errcode trace $CLI experimental air run -title "get" -errcode trace $CLI experimental air get 123 - title "list" errcode trace $CLI experimental air list diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go new file mode 100644 index 00000000000..88f620ee7c3 --- /dev/null +++ b/experimental/air/cmd/format.go @@ -0,0 +1,154 @@ +package aircmd + +import ( + "fmt" + "strings" + "time" + + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// gpuDisplayNames maps the GPU identifiers returned by the backend to the short +// names we show to users. Unknown identifiers are shown unchanged. +var gpuDisplayNames = map[string]string{ + "h100_80gb": "H100", + "a10": "A10", + "GPU_1xA10": "A10", + "GPU_8xH100": "H100", + "GPU_1xH100": "H100", +} + +// runStatus returns the single status word to show for a run. The backend +// reports two values: a lifecycle state (e.g. PENDING, RUNNING) and, once the +// run has finished, a result state (e.g. SUCCESS, FAILED). The result state is +// the more meaningful one, so we prefer it when it is set. +func runStatus(state *jobs.RunState) string { + if state == nil { + return "UNKNOWN" + } + if state.ResultState != "" { + return string(state.ResultState) + } + if state.LifeCycleState != "" { + return string(state.LifeCycleState) + } + return "UNKNOWN" +} + +// startedAt converts the run's start time (epoch milliseconds) to an RFC 3339 +// UTC string, or returns nil if the run has not started yet. +func startedAt(run *jobs.Run) *string { + if run.StartTime == 0 { + return nil + } + s := time.UnixMilli(run.StartTime).UTC().Format(time.RFC3339) + return &s +} + +// durationSeconds returns how long the run has taken, in whole seconds, or nil +// if it has not started. For a finished run this is the elapsed time; for a +// still-running run it is the time since it started. +func durationSeconds(run *jobs.Run) *int64 { + if run.StartTime == 0 { + return nil + } + + var endMillis int64 + switch { + case run.RunDuration > 0: + // The backend already computed the duration for us. + d := run.RunDuration / 1000 + return &d + case run.EndTime > 0: + endMillis = run.EndTime + default: + // Still running: measure against the current time. + endMillis = time.Now().UnixMilli() + } + + d := (endMillis - run.StartTime) / 1000 + return &d +} + +// formatDuration turns a number of seconds into a compact human string such as +// "1h 2m 3s". Trailing zero units are dropped, but a lone "0s" is kept so the +// result is never empty. +func formatDuration(totalSeconds int64) string { + hours := totalSeconds / 3600 + minutes := (totalSeconds % 3600) / 60 + seconds := totalSeconds % 60 + + var parts []string + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + return strings.Join(parts, " ") +} + +// latestAttemptNumber returns the retry count of the run's most recent task. +// Tasks start at attempt 0, so a value of 0 means the run has not been retried. +func latestAttemptNumber(run *jobs.Run) int { + if len(run.Tasks) == 0 { + return 0 + } + return run.Tasks[len(run.Tasks)-1].AttemptNumber +} + +// experimentName returns the MLflow experiment name for the run, or nil if there +// isn't one. Experiment names are often stored under a user's home folder (e.g. +// "/Users/me@example.com/my-experiment"); we strip that prefix so users see just +// the experiment name they chose. +func experimentName(run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.MlflowExperimentName == "" { + return nil + } + name := stripExperimentUserPrefix(task.MlflowExperimentName) + return &name +} + +// stripExperimentUserPrefix removes a leading "/Users//" from an +// experiment name, leaving the remainder. Names without that prefix are returned +// unchanged. +func stripExperimentUserPrefix(name string) string { + if !strings.HasPrefix(name, "/Users/") { + return name + } + // Split into ["", "Users", "", ""]; keep "". + parts := strings.SplitN(name, "/", 4) + if len(parts) == 4 { + return parts[3] + } + return name +} + +// accelerators returns a short description of the GPUs the run uses, such as +// "8x H100", or an empty string if the run has no GPU compute attached. +func accelerators(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.Compute == nil || task.Compute.NumGpus == 0 { + return "" + } + return fmt.Sprintf("%dx %s", task.Compute.NumGpus, gpuDisplayName(task.Compute.GpuType)) +} + +// gpuDisplayName returns the friendly name for a GPU identifier, falling back to +// the identifier itself when it is not one we recognize. +func gpuDisplayName(gpuType string) string { + if name, ok := gpuDisplayNames[gpuType]; ok { + return name + } + return gpuType +} diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go new file mode 100644 index 00000000000..c3e2e865b81 --- /dev/null +++ b/experimental/air/cmd/format_test.go @@ -0,0 +1,131 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatDuration(t *testing.T) { + cases := []struct { + seconds int64 + want string + }{ + {0, "0s"}, + {45, "45s"}, + {60, "1m"}, + {63, "1m 3s"}, + {3600, "1h"}, + {3723, "1h 2m 3s"}, + {7260, "2h 1m"}, + } + for _, c := range cases { + assert.Equal(t, c.want, formatDuration(c.seconds)) + } +} + +func TestStripExperimentUserPrefix(t *testing.T) { + cases := []struct { + name string + want string + }{ + {"/Users/me@example.com/my-experiment", "my-experiment"}, + {"/Users/me@example.com/nested/path", "nested/path"}, + {"my-experiment", "my-experiment"}, + {"/Shared/team-experiment", "/Shared/team-experiment"}, + {"/Users/me@example.com", "/Users/me@example.com"}, + } + for _, c := range cases { + assert.Equal(t, c.want, stripExperimentUserPrefix(c.name)) + } +} + +func TestGpuDisplayName(t *testing.T) { + assert.Equal(t, "H100", gpuDisplayName("h100_80gb")) + assert.Equal(t, "A10", gpuDisplayName("GPU_1xA10")) + assert.Equal(t, "A10", gpuDisplayName("a10")) + assert.Equal(t, "H100", gpuDisplayName("GPU_8xH100")) + assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) + // Unknown identifiers pass through unchanged. + assert.Equal(t, "b200", gpuDisplayName("b200")) + assert.Equal(t, "", gpuDisplayName("")) +} + +func TestRunStatusPrefersResultState(t *testing.T) { + // Result state wins once the run has finished. + assert.Equal(t, "SUCCESS", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateTerminated, + ResultState: jobs.RunResultStateSuccess, + })) + // Before completion only the lifecycle state is set. + assert.Equal(t, "RUNNING", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateRunning, + })) + // Non-nil state with neither field set, and nil state. + assert.Equal(t, "UNKNOWN", runStatus(&jobs.RunState{})) + assert.Equal(t, "UNKNOWN", runStatus(nil)) +} + +func TestStartedAt(t *testing.T) { + // Not started yet. + assert.Nil(t, startedAt(&jobs.Run{})) + // 1700000000000 ms == 2023-11-14T22:13:20Z. + got := startedAt(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, got) + assert.Equal(t, "2023-11-14T22:13:20Z", *got) +} + +func TestDurationSeconds(t *testing.T) { + // Not started yet. + assert.Nil(t, durationSeconds(&jobs.Run{})) + + // Backend-provided duration wins (milliseconds → seconds). + d := durationSeconds(&jobs.Run{StartTime: 1700000000000, RunDuration: 5000}) + require.NotNil(t, d) + assert.Equal(t, int64(5), *d) + + // Finished run with no RunDuration: end - start. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000012000}) + require.NotNil(t, d) + assert.Equal(t, int64(12), *d) + + // Still running: measured against the current time, so positive. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, d) + assert.Positive(t, *d) +} + +func TestLatestAttemptNumber(t *testing.T) { + assert.Equal(t, 0, latestAttemptNumber(&jobs.Run{})) + run := &jobs.Run{Tasks: []jobs.RunTask{{AttemptNumber: 0}, {AttemptNumber: 2}}} + assert.Equal(t, 2, latestAttemptNumber(run)) +} + +func TestExperimentName(t *testing.T) { + assert.Nil(t, experimentName(&jobs.Run{})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: ""}, + }}})) + got := experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}) + require.NotNil(t, got) + assert.Equal(t, "exp", *got) +} + +func TestAccelerators(t *testing.T) { + assert.Equal(t, "", accelerators(&jobs.Run{})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{}, + }}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, + }}})) + assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}}, + }}})) +} diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 45f93df10a6..b5b15bd6003 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -1,19 +1,187 @@ package aircmd import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/spf13/cobra" ) +// getData is the payload printed by `air get`. The json-tagged fields form +// the machine-readable output; fields tagged `json:"-"` are shown only in the +// human-readable text view. +type getData struct { + RunID string `json:"run_id"` + Status string `json:"status"` + StartedAt *string `json:"started_at"` + DurationSeconds *int64 `json:"duration_seconds"` + AttemptNumber int `json:"attempt_number"` + ExperimentName *string `json:"experiment_name"` + DashboardURL string `json:"dashboard_url"` + MLflowURL *string `json:"mlflow_url"` + + // Duration is the human-readable form of DurationSeconds, e.g. "12m 3s". + Duration string `json:"-"` + // Accelerators describes the run's GPUs, e.g. "8x H100". + Accelerators string `json:"-"` + // User is the run's creator. Text-only; JSON omits it, matching `air get --json`. + User string `json:"-"` + // Sweep replaces the single-run view for foreach runs. Text-only; JSON omits it. + Sweep *sweepInfo `json:"-"` +} + +// getTemplate is the text-mode layout. It reads from the JSON envelope, so +// every field is reached through ".Data". Optional rows are hidden when empty. +const getTemplate = `{{- if .Data.Sweep -}} +Sweep Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +Total: {{.Data.Sweep.Total}} +Completed: {{.Data.Sweep.Completed}} +Succeeded: {{.Data.Sweep.Succeeded}} +Failed: {{.Data.Sweep.Failed}} +Active: {{.Data.Sweep.Active}} +{{- if .Data.Sweep.Tasks}} + +Sweep Tasks: +{{printf " %-24s %-14s %-12s %s" "TASK" "RUN ID" "STATUS" "EXPERIMENT"}} +{{- range .Data.Sweep.Tasks}} +{{printf " %-24s %-14s %-12s %s" .TaskKey .RunID .Status .Experiment}} +{{- end}} +{{- end}} +{{- else -}} +Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +{{- if .Data.StartedAt}} +Submitted: {{.Data.StartedAt}} +{{- end}} +{{- if .Data.Duration}} +Duration: {{.Data.Duration}} +{{- end}} +Retries: {{.Data.AttemptNumber}} +{{- if .Data.ExperimentName}} +Experiment: {{.Data.ExperimentName}} +{{- end}} +{{- if .Data.User}} +User: {{.Data.User}} +{{- end}} +{{- if .Data.Accelerators}} +Accelerators: {{.Data.Accelerators}} +{{- end}} +{{- if .Data.MLflowURL}} +MLflow: {{.Data.MLflowURL}} +{{- end}} +Dashboard: {{.Data.DashboardURL}} +{{- end}} +` + func newGetCommand() *cobra.Command { cmd := &cobra.Command{ Use: "get JOB_RUN_ID", Args: root.ExactArgs(1), Short: "Show details for a run", - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("get") + Annotations: map[string]string{ + "template": getTemplate, }, } + cmd.PreRunE = root.MustWorkspaceClient + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + + runID, err := strconv.ParseInt(args[0], 10, 64) + if err != nil || runID <= 0 { + return fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0]) + } + + run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}) + if err != nil { + // The backend returns this when the run ID is unknown to the user. + if errors.Is(err, apierr.ErrResourceDoesNotExist) { + return fmt.Errorf("run %d not found: check the run ID and that it is a job run ID", runID) + } + return fmt.Errorf("failed to get status for run %d: %w", runID, err) + } + + data := buildGetData(run) + data.MLflowURL = mlflowURL(ctx, w, run) + if task := findForEachTask(run); task != nil { + data.Sweep = buildSweepInfo(ctx, w, task) + } + + // Text mode shows the training-config YAML before the status, mirroring + // `air get`. JSON output omits it, matching `air get --json`. + if root.OutputType(cmd) == flags.OutputText { + if path := yamlConfigPath(run); path != "" { + printConfigYAML(ctx, w, path) + } + } + return renderEnvelope(ctx, data) + } + return cmd } + +// buildGetData extracts the fields we display from a run. +func buildGetData(run *jobs.Run) getData { + data := getData{ + RunID: strconv.FormatInt(run.RunId, 10), + Status: runStatus(run.State), + StartedAt: startedAt(run), + DurationSeconds: durationSeconds(run), + AttemptNumber: latestAttemptNumber(run), + ExperimentName: experimentName(run), + DashboardURL: run.RunPageUrl, + Accelerators: accelerators(run), + User: run.CreatorUserName, + } + if data.DurationSeconds != nil { + data.Duration = formatDuration(*data.DurationSeconds) + } + return data +} + +// yamlConfigPath returns the run's training-config YAML path, or "" if none. +func yamlConfigPath(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil { + return "" + } + return task.YamlParametersFilePath +} + +// printConfigYAML downloads the run's training-config YAML and prints it. It is +// best-effort: a failure is surfaced as a warning but does not fail status. +func printConfigYAML(ctx context.Context, w *databricks.WorkspaceClient, path string) { + r, err := w.Workspace.Download(ctx, path) + if err != nil { + log.Warnf(ctx, "air get: could not download training config %s: %v", path, err) + return + } + defer r.Close() + + content, err := io.ReadAll(r) + if err != nil { + log.Warnf(ctx, "air get: could not read training config %s: %v", path, err) + return + } + + cmdio.LogString(ctx, "Training Configuration:") + cmdio.LogString(ctx, string(content)) + cmdio.LogString(ctx, "") +} diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go new file mode 100644 index 00000000000..6dfdc54db7b --- /dev/null +++ b/experimental/air/cmd/get_test.go @@ -0,0 +1,211 @@ +package aircmd + +import ( + "bytes" + "io" + "strings" + "testing" + "text/template" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// renderGet renders the status template against the JSON envelope, exactly as +// the command does, so the test covers the real template branches. +func renderGet(t *testing.T, data getData) string { + t.Helper() + tmpl, err := template.New("status").Parse(getTemplate) + require.NoError(t, err) + var buf bytes.Buffer + require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) + return buf.String() +} + +func TestGetTemplateSingleRun(t *testing.T) { + out := renderGet(t, getData{ + RunID: "123", + Status: "RUNNING", + User: "me@example.com", + DashboardURL: "https://example.test/run/123", + }) + assert.Contains(t, out, "Run ID: 123") + assert.Contains(t, out, "Status: RUNNING") + assert.Contains(t, out, "User:") + assert.Contains(t, out, "me@example.com") + assert.Contains(t, out, "Dashboard: https://example.test/run/123") + assert.NotContains(t, out, "Sweep") +} + +func TestGetRunInvalidID(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"abc"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid RUN_ID") +} + +func TestGetRunNotFound(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( + nil, apierr.ErrResourceDoesNotExist) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"5"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "run 5 not found") +} + +func TestPrintConfigYAML(t *testing.T) { + t.Run("downloads and prints", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + // The mock asserts Download is called with the resolved path. + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/cfg.yaml"). + Return(io.NopCloser(strings.NewReader("epochs: 3\n")), nil) + + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/cfg.yaml") + }) + + t.Run("download failure is non-fatal", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/missing.yaml"). + Return(nil, apierr.ErrResourceDoesNotExist) + + // Must not panic: a failed config fetch is best-effort. + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/missing.yaml") + }) +} + +func TestYAMLConfigPath(t *testing.T) { + // No tasks, or a task without GenAiComputeTask, yields no path. + assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) + assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + + run := &jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, + }}} + assert.Equal(t, "/Workspace/cfg.yaml", yamlConfigPath(run)) +} + +func TestGetTemplateSweep(t *testing.T) { + out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{ + Total: 4, Completed: 2, Succeeded: 1, Failed: 1, Active: 2, + Tasks: []sweepTask{ + {TaskKey: "iter_0", RunID: "789", Status: "SUCCESS", Experiment: "my-exp"}, + {TaskKey: "iter_1", RunID: "790", Status: "FAILED", Experiment: "my-exp"}, + }, + }, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.Contains(t, out, "Sweep Tasks:") + assert.Contains(t, out, "iter_0") + assert.Contains(t, out, "iter_1") + assert.Contains(t, out, "FAILED") + assert.Contains(t, out, "my-exp") + // The single-run rows must not appear in the sweep view. + assert.NotContains(t, out, "Dashboard:") +} + +func TestGetTemplateSweepNoTasks(t *testing.T) { + // A sweep whose iterations haven't materialized yet: counts show, but the + // task table header is hidden. + out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{Total: 4, Active: 4}, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.NotContains(t, out, "Sweep Tasks:") +} + +func TestGetTemplateMinimal(t *testing.T) { + // Only the always-present rows render; optional rows are hidden when empty. + out := renderGet(t, getData{RunID: "1", Status: "PENDING", DashboardURL: "https://example.test/1"}) + assert.Contains(t, out, "Run ID: 1") + assert.Contains(t, out, "Status: PENDING") + assert.Contains(t, out, "Retries: 0") + assert.Contains(t, out, "Dashboard: https://example.test/1") + for _, hidden := range []string{"Submitted:", "Duration:", "Experiment:", "User:", "Accelerators:", "MLflow:"} { + assert.NotContains(t, out, hidden) + } +} + +func TestGetTemplateAllFields(t *testing.T) { + started := "2023-11-14T22:13:20Z" + exp := "exp" + mlflow := "https://example.test/ml/exp/1" + out := renderGet(t, getData{ + RunID: "1", + Status: "SUCCESS", + StartedAt: &started, + Duration: "12s", + AttemptNumber: 2, + ExperimentName: &exp, + User: "me@example.com", + Accelerators: "8x H100", + MLflowURL: &mlflow, + DashboardURL: "https://example.test/1", + }) + for _, want := range []string{ + "Submitted: 2023-11-14T22:13:20Z", + "Duration: 12s", + "Retries: 2", + "Experiment: exp", + "User: me@example.com", + "Accelerators: 8x H100", + "MLflow: https://example.test/ml/exp/1", + "Dashboard: https://example.test/1", + } { + assert.Contains(t, out, want) + } +} + +func TestBuildStatusData(t *testing.T) { + run := &jobs.Run{ + RunId: 123, + RunPageUrl: "https://example.test/run/123", + CreatorUserName: "me@example.com", + StartTime: 1700000000000, + EndTime: 1700000012000, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + Tasks: []jobs.RunTask{{ + AttemptNumber: 1, + GenAiComputeTask: &jobs.GenAiComputeTask{ + MlflowExperimentName: "/Users/me@example.com/exp", + Compute: &jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}, + }, + }}, + } + d := buildGetData(run) + assert.Equal(t, "123", d.RunID) + assert.Equal(t, "SUCCESS", d.Status) + assert.Equal(t, 1, d.AttemptNumber) + assert.Equal(t, "https://example.test/run/123", d.DashboardURL) + assert.Equal(t, "me@example.com", d.User) + assert.Equal(t, "8x H100", d.Accelerators) + assert.Equal(t, "12s", d.Duration) + require.NotNil(t, d.ExperimentName) + assert.Equal(t, "exp", *d.ExperimentName) + require.NotNil(t, d.DurationSeconds) + assert.Equal(t, int64(12), *d.DurationSeconds) +} diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go new file mode 100644 index 00000000000..97d085b0128 --- /dev/null +++ b/experimental/air/cmd/mlflow.go @@ -0,0 +1,65 @@ +package aircmd + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// getRunOutputResponse is the slice of the jobs runs/get-output response we care +// about. The MLflow identifiers live under a gen_ai_compute_output field that +// the typed SDK does not model, so we call the endpoint directly and parse just +// these fields. +type getRunOutputResponse struct { + GenAiComputeOutput *struct { + RunInfo *struct { + MlflowExperimentID string `json:"mlflow_experiment_id"` + MlflowRunID string `json:"mlflow_run_id"` + } `json:"run_info"` + } `json:"gen_ai_compute_output"` +} + +// mlflowURL returns a link to the run's MLflow logs, or nil if it can't be +// built. The link is a convenience, so any failure here (missing task, endpoint +// error, run not yet started) is logged and treated as "no link" rather than +// failing the whole command. +func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + // The MLflow output is attached to the task run, not the parent job run. + taskRunID := run.Tasks[0].RunId + + apiClient, err := client.New(w.Config) + if err != nil { + log.Debugf(ctx, "air get: could not build API client for MLflow link: %v", err) + return nil + } + + var out getRunOutputResponse + err = apiClient.Do(ctx, http.MethodGet, "/api/2.2/jobs/runs/get-output", + nil, map[string]any{"run_id": taskRunID}, nil, &out) + if err != nil { + log.Debugf(ctx, "air get: could not fetch run output for MLflow link: %v", err) + return nil + } + + if out.GenAiComputeOutput == nil || out.GenAiComputeOutput.RunInfo == nil { + return nil + } + info := out.GenAiComputeOutput.RunInfo + if info.MlflowExperimentID == "" || info.MlflowRunID == "" { + return nil + } + + host := strings.TrimRight(w.Config.Host, "/") + url := fmt.Sprintf("%s/ml/experiments/%s/runs/%s/artifacts/logs/node_0", + host, info.MlflowExperimentID, info.MlflowRunID) + return &url +} diff --git a/experimental/air/cmd/mlflow_test.go b/experimental/air/cmd/mlflow_test.go new file mode 100644 index 00000000000..bbc4fef9822 --- /dev/null +++ b/experimental/air/cmd/mlflow_test.go @@ -0,0 +1,64 @@ +package aircmd + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestWorkspaceClient builds a WorkspaceClient pointed at a mock HTTP server. +// mlflowURL calls the runs/get-output REST endpoint directly (the field it needs +// is not modeled by the typed SDK), so it must be exercised over HTTP. +func newTestWorkspaceClient(t *testing.T, host string) *databricks.WorkspaceClient { + t.Helper() + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: host, Token: "token"}) + require.NoError(t, err) + return w +} + +// runOutputServer serves the given runs/get-output body and a stub for the SDK's +// well-known config discovery request. *hit is set when get-output is called. +func runOutputServer(t *testing.T, body string, hit *bool) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/2.2/jobs/runs/get-output" { + *hit = true + _, _ = w.Write([]byte(body)) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestMLflowURL(t *testing.T) { + ctx := t.Context() + run := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}}} + + t.Run("builds the deep-link on success", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{"gen_ai_compute_output":{"run_info":{"mlflow_experiment_id":"E1","mlflow_run_id":"R1"}}}`, &hit) + + got := mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run) + require.NotNil(t, got) + assert.True(t, hit, "runs/get-output should have been called") + assert.Equal(t, srv.URL+"/ml/experiments/E1/runs/R1/artifacts/logs/node_0", *got) + }) + + t.Run("nil when the run has no MLflow info", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{}`, &hit) + assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run)) + }) + + t.Run("nil when the run has no tasks", func(t *testing.T) { + // Returns before any HTTP call, so the host is never contacted. + assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) + }) +} diff --git a/experimental/air/cmd/output.go b/experimental/air/cmd/output.go new file mode 100644 index 00000000000..3da766a7d4f --- /dev/null +++ b/experimental/air/cmd/output.go @@ -0,0 +1,39 @@ +package aircmd + +import ( + "context" + "time" + + "github.com/databricks/cli/libs/cmdio" +) + +// envelopeVersion is the envelope's format-version marker. The Python `air` CLI +// hardcodes it to 1; it lets consumers detect a future incompatible change to +// the envelope shape. +const envelopeVersion = 1 + +// envelope is the JSON shape that the AI runtime CLI prints: +// +// { "v": 1, "ts": "2024-01-15T14:30:45Z", "data": { ... } } +// +// It mirrors the envelope used by the original Python `air` CLI so existing +// consumers keep working after the port to Go. +type envelope struct { + // V is the envelope format-version marker (always 1). + V int `json:"v"` + // TS is the wall-clock time the response was produced, in RFC 3339 UTC. + // It is an absolute timestamp, not an elapsed duration. + TS string `json:"ts"` + // Data is the command-specific payload. + Data any `json:"data"` +} + +// renderEnvelope wraps data in the JSON envelope and prints it. +// Fields that should appear only in text output are tagged `json:"-"` on the payload struct. +func renderEnvelope(ctx context.Context, data any) error { + return cmdio.Render(ctx, envelope{ + V: envelopeVersion, + TS: time.Now().UTC().Format(time.RFC3339), + Data: data, + }) +} diff --git a/experimental/air/cmd/output_test.go b/experimental/air/cmd/output_test.go new file mode 100644 index 00000000000..73a5572c3f5 --- /dev/null +++ b/experimental/air/cmd/output_test.go @@ -0,0 +1,13 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/require" +) + +func TestRenderEnvelope(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + require.NoError(t, renderEnvelope(ctx, getData{RunID: "1", Status: "RUNNING"})) +} diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index a6e24177f33..5e35bcdcd14 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -14,7 +14,6 @@ import ( func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ "run": newRunCommand(), - "get": newGetCommand(), "list": newListCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), diff --git a/experimental/air/cmd/sweep.go b/experimental/air/cmd/sweep.go new file mode 100644 index 00000000000..b346f43f1b6 --- /dev/null +++ b/experimental/air/cmd/sweep.go @@ -0,0 +1,76 @@ +package aircmd + +import ( + "context" + "strconv" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// sweepInfo summarizes a "foreach" run, which fans a single config out into many +// iterations (a hyperparameter sweep). It is shown only in text output. +type sweepInfo struct { + Total int + Succeeded int + Failed int + Active int + Completed int + Tasks []sweepTask +} + +// sweepTask is one iteration of a sweep. +type sweepTask struct { + TaskKey string + RunID string + Status string + Experiment string +} + +// findForEachTask returns the run's foreach task if it has one, or nil. A run is +// a sweep when one of its tasks fans out into iterations. +func findForEachTask(run *jobs.Run) *jobs.RunTask { + for i := range run.Tasks { + if run.Tasks[i].ForEachTask != nil { + return &run.Tasks[i] + } + } + return nil +} + +// buildSweepInfo gathers the iteration counts and per-iteration rows for a +// sweep. The counts come from the task we already have; the individual +// iterations require a second lookup. If that lookup fails we still return the +// counts (logging the failure) so the user sees the summary. +func buildSweepInfo(ctx context.Context, w *databricks.WorkspaceClient, task *jobs.RunTask) *sweepInfo { + info := &sweepInfo{} + if task.ForEachTask.Stats != nil && task.ForEachTask.Stats.TaskRunStats != nil { + stats := task.ForEachTask.Stats.TaskRunStats + info.Total = stats.TotalIterations + info.Succeeded = stats.SucceededIterations + info.Failed = stats.FailedIterations + info.Active = stats.ActiveIterations + info.Completed = stats.CompletedIterations + } + + // The iterations are returned as part of a run lookup on the foreach task. + iterated, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: task.RunId}) + if err != nil { + log.Debugf(ctx, "air get: could not fetch sweep iterations: %v", err) + return info + } + + for _, it := range iterated.Iterations { + row := sweepTask{ + TaskKey: it.TaskKey, + RunID: strconv.FormatInt(it.RunId, 10), + Status: runStatus(it.State), + } + if it.GenAiComputeTask != nil && it.GenAiComputeTask.MlflowExperimentName != "" { + row.Experiment = stripExperimentUserPrefix(it.GenAiComputeTask.MlflowExperimentName) + } + info.Tasks = append(info.Tasks, row) + } + return info +} diff --git a/experimental/air/cmd/sweep_test.go b/experimental/air/cmd/sweep_test.go new file mode 100644 index 00000000000..10134c0df42 --- /dev/null +++ b/experimental/air/cmd/sweep_test.go @@ -0,0 +1,81 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestFindForEachTask(t *testing.T) { + // No tasks at all. + assert.Nil(t, findForEachTask(&jobs.Run{})) + + // A task that is not a foreach. + assert.Nil(t, findForEachTask(&jobs.Run{Tasks: []jobs.RunTask{{TaskKey: "a"}}})) + + // The foreach task is found even when it isn't first. + run := &jobs.Run{Tasks: []jobs.RunTask{ + {TaskKey: "a"}, + {TaskKey: "sweep", ForEachTask: &jobs.RunForEachTask{}}, + }} + got := findForEachTask(run) + require.NotNil(t, got) + assert.Equal(t, "sweep", got.TaskKey) +} + +func sweepTaskFixture() *jobs.RunTask { + return &jobs.RunTask{ + RunId: 99, + ForEachTask: &jobs.RunForEachTask{ + Stats: &jobs.ForEachStats{TaskRunStats: &jobs.ForEachTaskTaskRunStats{ + TotalIterations: 4, + SucceededIterations: 1, + FailedIterations: 1, + ActiveIterations: 2, + CompletedIterations: 2, + }}, + }, + } +} + +func TestBuildSweepInfo(t *testing.T) { + ctx := t.Context() + + t.Run("counts and iteration rows", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return( + &jobs.Run{Iterations: []jobs.RunTask{{ + TaskKey: "iter_0", + RunId: 100, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}, nil) + + info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture()) + assert.Equal(t, 4, info.Total) + assert.Equal(t, 2, info.Completed) + assert.Equal(t, 1, info.Succeeded) + assert.Equal(t, 1, info.Failed) + assert.Equal(t, 2, info.Active) + require.Len(t, info.Tasks, 1) + assert.Equal(t, "iter_0", info.Tasks[0].TaskKey) + assert.Equal(t, "100", info.Tasks[0].RunID) + assert.Equal(t, "SUCCESS", info.Tasks[0].Status) + assert.Equal(t, "exp", info.Tasks[0].Experiment) + }) + + t.Run("iteration lookup failure still returns counts", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return( + nil, apierr.ErrResourceDoesNotExist) + + info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture()) + assert.Equal(t, 4, info.Total) + assert.Empty(t, info.Tasks) + }) +} From 38837910b91a51dd3254b7a0e8f0271710e53f41 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:27:54 +0000 Subject: [PATCH 08/18] experimental/air: rename stale TestBuildStatusData to TestBuildGetData Co-authored-by: Isaac --- experimental/air/cmd/get_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 6dfdc54db7b..64fd1aa1f68 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -17,11 +17,11 @@ import ( "github.com/stretchr/testify/require" ) -// renderGet renders the status template against the JSON envelope, exactly as +// renderGet renders the get template against the JSON envelope, exactly as // the command does, so the test covers the real template branches. func renderGet(t *testing.T, data getData) string { t.Helper() - tmpl, err := template.New("status").Parse(getTemplate) + tmpl, err := template.New("get").Parse(getTemplate) require.NoError(t, err) var buf bytes.Buffer require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) @@ -180,7 +180,7 @@ func TestGetTemplateAllFields(t *testing.T) { } } -func TestBuildStatusData(t *testing.T) { +func TestBuildGetData(t *testing.T) { run := &jobs.Run{ RunId: 123, RunPageUrl: "https://example.test/run/123", From 472a1fe8a921f9334f7e71b2c3e30e7843375154 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:30:07 +0000 Subject: [PATCH 09/18] experimental/air: apply testifylint fixes in get/format tests Co-authored-by: Isaac --- experimental/air/cmd/format_test.go | 10 +++++----- experimental/air/cmd/get_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index c3e2e865b81..583f3cc3111 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -50,7 +50,7 @@ func TestGpuDisplayName(t *testing.T) { assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) // Unknown identifiers pass through unchanged. assert.Equal(t, "b200", gpuDisplayName("b200")) - assert.Equal(t, "", gpuDisplayName("")) + assert.Empty(t, gpuDisplayName("")) } func TestRunStatusPrefersResultState(t *testing.T) { @@ -117,12 +117,12 @@ func TestExperimentName(t *testing.T) { } func TestAccelerators(t *testing.T) { - assert.Equal(t, "", accelerators(&jobs.Run{})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + assert.Empty(t, accelerators(&jobs.Run{})) + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{}, }}})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, }}})) assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 64fd1aa1f68..b6d6d0baab9 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -93,8 +93,8 @@ func TestPrintConfigYAML(t *testing.T) { func TestYAMLConfigPath(t *testing.T) { // No tasks, or a task without GenAiComputeTask, yields no path. - assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) - assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Empty(t, yamlConfigPath(&jobs.Run{})) + assert.Empty(t, yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) run := &jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, From 0ab1008a7bef43bc348ccd9e76112973fafb8d9f Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 21:25:38 +0000 Subject: [PATCH 10/18] experimental/air: print `air get` training config to stdout The training-config block is command result data, but it was emitted via cmdio.LogString, which targets stderr. Write it to cmd.OutOrStdout() instead so it lands on stdout, matching the Python `air get`. Download/read failures stay on stderr as warnings. Co-authored-by: Isaac --- experimental/air/cmd/get.go | 16 ++++++++-------- experimental/air/cmd/get_test.go | 18 +++++++++++++----- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index b5b15bd6003..7ed1fbaff01 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -9,7 +9,6 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdctx" - "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" @@ -125,7 +124,7 @@ func newGetCommand() *cobra.Command { // `air get`. JSON output omits it, matching `air get --json`. if root.OutputType(cmd) == flags.OutputText { if path := yamlConfigPath(run); path != "" { - printConfigYAML(ctx, w, path) + printConfigYAML(ctx, cmd.OutOrStdout(), w, path) } } return renderEnvelope(ctx, data) @@ -165,9 +164,10 @@ func yamlConfigPath(run *jobs.Run) string { return task.YamlParametersFilePath } -// printConfigYAML downloads the run's training-config YAML and prints it. It is -// best-effort: a failure is surfaced as a warning but does not fail status. -func printConfigYAML(ctx context.Context, w *databricks.WorkspaceClient, path string) { +// printConfigYAML downloads the run's training-config YAML and writes it to out +// (stdout), mirroring the Python `air get`. It is best-effort: a download or read +// failure is surfaced as a warning on stderr but does not fail the command. +func printConfigYAML(ctx context.Context, out io.Writer, w *databricks.WorkspaceClient, path string) { r, err := w.Workspace.Download(ctx, path) if err != nil { log.Warnf(ctx, "air get: could not download training config %s: %v", path, err) @@ -181,7 +181,7 @@ func printConfigYAML(ctx context.Context, w *databricks.WorkspaceClient, path st return } - cmdio.LogString(ctx, "Training Configuration:") - cmdio.LogString(ctx, string(content)) - cmdio.LogString(ctx, "") + fmt.Fprintln(out, "Training Configuration:") + fmt.Fprintln(out, string(content)) + fmt.Fprintln(out) } diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index b6d6d0baab9..20974e326c4 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -68,7 +68,7 @@ func TestGetRunNotFound(t *testing.T) { } func TestPrintConfigYAML(t *testing.T) { - t.Run("downloads and prints", func(t *testing.T) { + t.Run("downloads and prints to stdout", func(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) m := mocks.NewMockWorkspaceClient(t) // The mock asserts Download is called with the resolved path. @@ -76,18 +76,26 @@ func TestPrintConfigYAML(t *testing.T) { Download(mock.Anything, "/Workspace/cfg.yaml"). Return(io.NopCloser(strings.NewReader("epochs: 3\n")), nil) - printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/cfg.yaml") + // The config is data output and must land on stdout (the out writer), + // matching the Python `air get` behavior. + var out bytes.Buffer + printConfigYAML(ctx, &out, m.WorkspaceClient, "/Workspace/cfg.yaml") + assert.Contains(t, out.String(), "Training Configuration:") + assert.Contains(t, out.String(), "epochs: 3") }) - t.Run("download failure is non-fatal", func(t *testing.T) { + t.Run("download failure is non-fatal and writes nothing", func(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) m := mocks.NewMockWorkspaceClient(t) m.GetMockWorkspaceAPI().EXPECT(). Download(mock.Anything, "/Workspace/missing.yaml"). Return(nil, apierr.ErrResourceDoesNotExist) - // Must not panic: a failed config fetch is best-effort. - printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/missing.yaml") + // Must not panic and must not write to stdout: a failed config fetch is + // best-effort, surfaced only as a stderr warning. + var out bytes.Buffer + printConfigYAML(ctx, &out, m.WorkspaceClient, "/Workspace/missing.yaml") + assert.Empty(t, out.String()) }) } From 2615cd740e2d7fea3f11fbdaedebed9054e034bf Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 22:28:44 +0000 Subject: [PATCH 11/18] experimental/air: report latest-attempt timing and round duration `air get` derived Submitted and Duration from run-level start/end and truncated milliseconds to seconds. Port Python's _reported_attempt_timing so a retried run reports its latest attempt, and round to the nearest second to match Python's round(). Drops the run-level RunDuration shortcut, which diverged on retries. Co-authored-by: Isaac --- experimental/air/cmd/format.go | 55 +++++++++++++++++++---------- experimental/air/cmd/format_test.go | 53 ++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 24 deletions(-) diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go index 88f620ee7c3..77c5e5a100a 100644 --- a/experimental/air/cmd/format.go +++ b/experimental/air/cmd/format.go @@ -35,41 +35,58 @@ func runStatus(state *jobs.RunState) string { return "UNKNOWN" } -// startedAt converts the run's start time (epoch milliseconds) to an RFC 3339 -// UTC string, or returns nil if the run has not started yet. +// reportedTiming returns the run's start and end times (epoch milliseconds), +// preferring the last task's window over the run-level times so a retried run +// reports its latest attempt. Mirrors Python's _reported_attempt_timing +// (cli_display.py:78-87). +func reportedTiming(run *jobs.Run) (startMillis, endMillis int64) { + startMillis, endMillis = run.StartTime, run.EndTime + if n := len(run.Tasks); n > 0 { + last := run.Tasks[n-1] + if last.StartTime > 0 { + startMillis = last.StartTime + } + if last.EndTime > 0 { + endMillis = last.EndTime + } + } + return startMillis, endMillis +} + +// startedAt converts the run's reported start time (epoch milliseconds) to an +// RFC 3339 UTC string, or returns nil if the run has not started yet. func startedAt(run *jobs.Run) *string { - if run.StartTime == 0 { + startMillis, _ := reportedTiming(run) + if startMillis == 0 { return nil } - s := time.UnixMilli(run.StartTime).UTC().Format(time.RFC3339) + s := time.UnixMilli(startMillis).UTC().Format(time.RFC3339) return &s } // durationSeconds returns how long the run has taken, in whole seconds, or nil -// if it has not started. For a finished run this is the elapsed time; for a -// still-running run it is the time since it started. +// if it has not started. For a finished run this is the elapsed time of the +// reported attempt; for a still-running run it is the time since it started. func durationSeconds(run *jobs.Run) *int64 { - if run.StartTime == 0 { + startMillis, endMillis := reportedTiming(run) + if startMillis == 0 { return nil } - - var endMillis int64 - switch { - case run.RunDuration > 0: - // The backend already computed the duration for us. - d := run.RunDuration / 1000 - return &d - case run.EndTime > 0: - endMillis = run.EndTime - default: + if endMillis == 0 { // Still running: measure against the current time. endMillis = time.Now().UnixMilli() } - - d := (endMillis - run.StartTime) / 1000 + d := roundMillisToSeconds(endMillis - startMillis) return &d } +// roundMillisToSeconds converts milliseconds to whole seconds, rounding to the +// nearest second to match Python's round() (cli_entrypoint.py:1934) rather than +// truncating. +func roundMillisToSeconds(ms int64) int64 { + return (ms + 500) / 1000 +} + // formatDuration turns a number of seconds into a compact human string such as // "1h 2m 3s". Trailing zero units are dropped, but a lone "0s" is kept so the // result is never empty. diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index 583f3cc3111..211484c5229 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -68,6 +68,33 @@ func TestRunStatusPrefersResultState(t *testing.T) { assert.Equal(t, "UNKNOWN", runStatus(nil)) } +func TestReportedTiming(t *testing.T) { + // No tasks: run-level times are used. + start, end := reportedTiming(&jobs.Run{StartTime: 100, EndTime: 200}) + assert.Equal(t, int64(100), start) + assert.Equal(t, int64(200), end) + + // The last task's window is preferred over the run-level window, so a + // retried run reports its most recent attempt. + start, end = reportedTiming(&jobs.Run{ + StartTime: 100, EndTime: 200, + Tasks: []jobs.RunTask{ + {StartTime: 100, EndTime: 150}, + {StartTime: 300, EndTime: 450}, + }, + }) + assert.Equal(t, int64(300), start) + assert.Equal(t, int64(450), end) + + // A task missing a field falls back to the run-level value for that field. + start, end = reportedTiming(&jobs.Run{ + StartTime: 100, EndTime: 200, + Tasks: []jobs.RunTask{{StartTime: 300}}, + }) + assert.Equal(t, int64(300), start) + assert.Equal(t, int64(200), end) +} + func TestStartedAt(t *testing.T) { // Not started yet. assert.Nil(t, startedAt(&jobs.Run{})) @@ -75,22 +102,38 @@ func TestStartedAt(t *testing.T) { got := startedAt(&jobs.Run{StartTime: 1700000000000}) require.NotNil(t, got) assert.Equal(t, "2023-11-14T22:13:20Z", *got) + // The last attempt's start time is reported. 1700000060000 ms == 22:14:20Z. + got = startedAt(&jobs.Run{ + StartTime: 1700000000000, + Tasks: []jobs.RunTask{{StartTime: 1700000060000}}, + }) + require.NotNil(t, got) + assert.Equal(t, "2023-11-14T22:14:20Z", *got) } func TestDurationSeconds(t *testing.T) { // Not started yet. assert.Nil(t, durationSeconds(&jobs.Run{})) - // Backend-provided duration wins (milliseconds → seconds). - d := durationSeconds(&jobs.Run{StartTime: 1700000000000, RunDuration: 5000}) + // Finished run: reported end - start. + d := durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000012000}) require.NotNil(t, d) - assert.Equal(t, int64(5), *d) + assert.Equal(t, int64(12), *d) - // Finished run with no RunDuration: end - start. - d = durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000012000}) + // Sub-second remainders round to the nearest second, matching Python: an + // 11,500 ms run reports 12s, not 11s. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000011500}) require.NotNil(t, d) assert.Equal(t, int64(12), *d) + // The last attempt's window drives the duration for a retried run. + d = durationSeconds(&jobs.Run{ + StartTime: 1700000000000, EndTime: 1700000012000, + Tasks: []jobs.RunTask{{StartTime: 1700000000000, EndTime: 1700000005000}}, + }) + require.NotNil(t, d) + assert.Equal(t, int64(5), *d) + // Still running: measured against the current time, so positive. d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) require.NotNil(t, d) From d3bb64b95fb89f0765d83cfbd16758f1ba250485 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 22:42:00 +0000 Subject: [PATCH 12/18] experimental/air: link MLflow output to the latest attempt mlflowURL resolved runs/get-output against Tasks[0], linking a retried run to its stale first attempt. Use the last task (latest attempt) to match Python (jobs_api_client.py:68). Co-authored-by: Isaac --- experimental/air/cmd/mlflow.go | 6 ++++-- experimental/air/cmd/mlflow_test.go | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go index 97d085b0128..0d2a74c56ba 100644 --- a/experimental/air/cmd/mlflow.go +++ b/experimental/air/cmd/mlflow.go @@ -33,8 +33,10 @@ func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run if len(run.Tasks) == 0 { return nil } - // The MLflow output is attached to the task run, not the parent job run. - taskRunID := run.Tasks[0].RunId + // The MLflow output is attached to the task run, not the parent job run. Use + // the last task (latest attempt) so a retried run links to its newest output, + // matching Python (jobs_api_client.py:68). + taskRunID := run.Tasks[len(run.Tasks)-1].RunId apiClient, err := client.New(w.Config) if err != nil { diff --git a/experimental/air/cmd/mlflow_test.go b/experimental/air/cmd/mlflow_test.go index bbc4fef9822..bdc54c18d53 100644 --- a/experimental/air/cmd/mlflow_test.go +++ b/experimental/air/cmd/mlflow_test.go @@ -61,4 +61,20 @@ func TestMLflowURL(t *testing.T) { // Returns before any HTTP call, so the host is never contacted. assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) }) + + t.Run("uses the latest attempt's task run", func(t *testing.T) { + // A retried run must link to the last task, not the stale first attempt. + var gotRunID string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/2.2/jobs/runs/get-output" { + gotRunID = r.URL.Query().Get("run_id") + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + + retried := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}, {RunId: 100}}} + mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), retried) + assert.Equal(t, "100", gotRunID) + }) } From 3deceabc2283ca381177feac7714feee54ee61f8 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 23:03:47 +0000 Subject: [PATCH 13/18] experimental/air: render JSON error envelopes and align `air get` JSON with Python In -o json mode, error paths now emit the structured error envelope ({v, ts, error:{code, kind, message, retryable}}) and exit non-zero, matching the Python air CLI's print_json_error instead of letting the framework print a bare "Error: ..." string. Covers invalid RUN_ID, run-not-found, backend failures, and client/auth failures (wrapped PreRunE). Also align the success envelope with the Python CLI: - dashboard_url: construct {host}/jobs/runs/{id}?o={workspace_id} (via CurrentWorkspaceID) instead of using the API's run_page_url - started_at: datetime.isoformat() form ("+00:00" with microseconds), not RFC3339 "Z" - duration_seconds: rounded half-to-even to match Python's round() - use run-level start/end times for started_at and duration_seconds, dropping the last-attempt preference, which had no Python equivalent Co-authored-by: Isaac --- acceptance/experimental/air/get/output.txt | 23 ++++++-- acceptance/experimental/air/get/script | 3 ++ experimental/air/cmd/format.go | 61 ++++++++++------------ experimental/air/cmd/format_test.go | 48 ++++++----------- experimental/air/cmd/get.go | 27 ++++++++-- experimental/air/cmd/get_test.go | 28 ++++++++-- experimental/air/cmd/mlflow.go | 4 +- experimental/air/cmd/output.go | 43 +++++++++++++++ experimental/air/cmd/output_test.go | 39 ++++++++++++++ 9 files changed, 194 insertions(+), 82 deletions(-) diff --git a/acceptance/experimental/air/get/output.txt b/acceptance/experimental/air/get/output.txt index 6ce803659b4..b31d0eaa0c5 100644 --- a/acceptance/experimental/air/get/output.txt +++ b/acceptance/experimental/air/get/output.txt @@ -3,14 +3,14 @@ >>> [CLI] experimental air get 123 Run ID: 123 Status: SUCCESS -Submitted: [TIMESTAMP] +Submitted: [TIMESTAMP]+00:00 Duration: 12s Retries: 0 Experiment: my-exp User: user@example.com Accelerators: 8x H100 MLflow: [DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0 -Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 +Dashboard: [DATABRICKS_URL]/jobs/runs/123?o=[NUMID] === get (json) >>> [CLI] experimental air get 123 -o json @@ -20,11 +20,11 @@ Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 "data": { "run_id": "123", "status": "SUCCESS", - "started_at": "[TIMESTAMP]", + "started_at": "[TIMESTAMP]+00:00", "duration_seconds": 12, "attempt_number": 0, "experiment_name": "my-exp", - "dashboard_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "dashboard_url": "[DATABRICKS_URL]/jobs/runs/123?o=[NUMID]", "mlflow_url": "[DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0" } } @@ -34,3 +34,18 @@ Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 Error: invalid RUN_ID "notanumber": must be a positive integer Exit code: 1 + +=== invalid run id (json) +>>> [CLI] experimental air get notanumber -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "error": { + "code": "INVALID_ARGS", + "kind": "PERMANENT", + "message": "invalid RUN_ID \"notanumber\": must be a positive integer", + "retryable": true + } +} + +Exit code: 1 diff --git a/acceptance/experimental/air/get/script b/acceptance/experimental/air/get/script index e0ea8d10f85..ee66b4aff04 100644 --- a/acceptance/experimental/air/get/script +++ b/acceptance/experimental/air/get/script @@ -6,3 +6,6 @@ trace $CLI experimental air get 123 -o json title "invalid run id" errcode trace $CLI experimental air get notanumber + +title "invalid run id (json)" +errcode trace $CLI experimental air get notanumber -o json diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go index 77c5e5a100a..a1601f0eab1 100644 --- a/experimental/air/cmd/format.go +++ b/experimental/air/cmd/format.go @@ -2,6 +2,7 @@ package aircmd import ( "fmt" + "math" "strings" "time" @@ -35,56 +36,50 @@ func runStatus(state *jobs.RunState) string { return "UNKNOWN" } -// reportedTiming returns the run's start and end times (epoch milliseconds), -// preferring the last task's window over the run-level times so a retried run -// reports its latest attempt. Mirrors Python's _reported_attempt_timing -// (cli_display.py:78-87). -func reportedTiming(run *jobs.Run) (startMillis, endMillis int64) { - startMillis, endMillis = run.StartTime, run.EndTime - if n := len(run.Tasks); n > 0 { - last := run.Tasks[n-1] - if last.StartTime > 0 { - startMillis = last.StartTime - } - if last.EndTime > 0 { - endMillis = last.EndTime - } - } - return startMillis, endMillis -} - -// startedAt converts the run's reported start time (epoch milliseconds) to an -// RFC 3339 UTC string, or returns nil if the run has not started yet. +// startedAt returns the run's start time as a Python-isoformat string ("+00:00", +// not "Z"; microseconds only when non-zero, cli_entrypoint.py:1899), or nil if it +// hasn't started. func startedAt(run *jobs.Run) *string { - startMillis, _ := reportedTiming(run) - if startMillis == 0 { + if run.StartTime == 0 { return nil } - s := time.UnixMilli(startMillis).UTC().Format(time.RFC3339) + t := time.UnixMilli(run.StartTime).UTC() + layout := "2006-01-02T15:04:05-07:00" + if t.Nanosecond() != 0 { + layout = "2006-01-02T15:04:05.000000-07:00" + } + s := t.Format(layout) return &s } -// durationSeconds returns how long the run has taken, in whole seconds, or nil -// if it has not started. For a finished run this is the elapsed time of the -// reported attempt; for a still-running run it is the time since it started. +// durationSeconds returns how long the run has taken, in whole seconds, or nil if +// it hasn't started. For a finished run this is end-start; for a running one it is +// the time since it started. Both come from the run-level times, matching Python +// (cli_entrypoint.py:1900-1903). func durationSeconds(run *jobs.Run) *int64 { - startMillis, endMillis := reportedTiming(run) - if startMillis == 0 { + if run.StartTime == 0 { return nil } + endMillis := run.EndTime if endMillis == 0 { // Still running: measure against the current time. endMillis = time.Now().UnixMilli() } - d := roundMillisToSeconds(endMillis - startMillis) + d := roundMillisToSeconds(endMillis - run.StartTime) return &d } -// roundMillisToSeconds converts milliseconds to whole seconds, rounding to the -// nearest second to match Python's round() (cli_entrypoint.py:1934) rather than -// truncating. +// roundMillisToSeconds rounds milliseconds to whole seconds, half to even, to +// match Python's round() (cli_entrypoint.py:1903). func roundMillisToSeconds(ms int64) int64 { - return (ms + 500) / 1000 + return int64(math.RoundToEven(float64(ms) / 1000)) +} + +// dashboardURL builds {host}/jobs/runs/{id}?o={workspace_id}, matching Python +// (cli_entrypoint.py:1911). The ?o= workspace id deep-links to the right +// workspace on multi-workspace accounts. +func dashboardURL(host string, runID, workspaceID int64) string { + return fmt.Sprintf("%s/jobs/runs/%d?o=%d", strings.TrimRight(host, "/"), runID, workspaceID) } // formatDuration turns a number of seconds into a compact human string such as diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index 211484c5229..e2c2e3d2915 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -68,47 +68,24 @@ func TestRunStatusPrefersResultState(t *testing.T) { assert.Equal(t, "UNKNOWN", runStatus(nil)) } -func TestReportedTiming(t *testing.T) { - // No tasks: run-level times are used. - start, end := reportedTiming(&jobs.Run{StartTime: 100, EndTime: 200}) - assert.Equal(t, int64(100), start) - assert.Equal(t, int64(200), end) - - // The last task's window is preferred over the run-level window, so a - // retried run reports its most recent attempt. - start, end = reportedTiming(&jobs.Run{ - StartTime: 100, EndTime: 200, - Tasks: []jobs.RunTask{ - {StartTime: 100, EndTime: 150}, - {StartTime: 300, EndTime: 450}, - }, - }) - assert.Equal(t, int64(300), start) - assert.Equal(t, int64(450), end) - - // A task missing a field falls back to the run-level value for that field. - start, end = reportedTiming(&jobs.Run{ - StartTime: 100, EndTime: 200, - Tasks: []jobs.RunTask{{StartTime: 300}}, - }) - assert.Equal(t, int64(300), start) - assert.Equal(t, int64(200), end) -} - func TestStartedAt(t *testing.T) { // Not started yet. assert.Nil(t, startedAt(&jobs.Run{})) - // 1700000000000 ms == 2023-11-14T22:13:20Z. + // 1700000000000 ms == 2023-11-14T22:13:20+00:00 (Python isoformat, not "Z"). got := startedAt(&jobs.Run{StartTime: 1700000000000}) require.NotNil(t, got) - assert.Equal(t, "2023-11-14T22:13:20Z", *got) - // The last attempt's start time is reported. 1700000060000 ms == 22:14:20Z. + assert.Equal(t, "2023-11-14T22:13:20+00:00", *got) + // A sub-second start time carries microsecond precision. + got = startedAt(&jobs.Run{StartTime: 1700000000500}) + require.NotNil(t, got) + assert.Equal(t, "2023-11-14T22:13:20.500000+00:00", *got) + // Task-level times are ignored; the run-level start is reported (matching Python). got = startedAt(&jobs.Run{ StartTime: 1700000000000, Tasks: []jobs.RunTask{{StartTime: 1700000060000}}, }) require.NotNil(t, got) - assert.Equal(t, "2023-11-14T22:14:20Z", *got) + assert.Equal(t, "2023-11-14T22:13:20+00:00", *got) } func TestDurationSeconds(t *testing.T) { @@ -126,13 +103,13 @@ func TestDurationSeconds(t *testing.T) { require.NotNil(t, d) assert.Equal(t, int64(12), *d) - // The last attempt's window drives the duration for a retried run. + // Task-level times are ignored; run-level end-start is used (matching Python). d = durationSeconds(&jobs.Run{ StartTime: 1700000000000, EndTime: 1700000012000, Tasks: []jobs.RunTask{{StartTime: 1700000000000, EndTime: 1700000005000}}, }) require.NotNil(t, d) - assert.Equal(t, int64(5), *d) + assert.Equal(t, int64(12), *d) // Still running: measured against the current time, so positive. d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) @@ -140,6 +117,11 @@ func TestDurationSeconds(t *testing.T) { assert.Positive(t, *d) } +func TestDashboardURL(t *testing.T) { + // The ?o= workspace id and a trailing-slash-trimmed host, matching Python. + assert.Equal(t, "https://example.test/jobs/runs/123?o=42", dashboardURL("https://example.test/", 123, 42)) +} + func TestLatestAttemptNumber(t *testing.T) { assert.Equal(t, 0, latestAttemptNumber(&jobs.Run{})) run := &jobs.Run{Tasks: []jobs.RunTask{{AttemptNumber: 0}, {AttemptNumber: 2}}} diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 7ed1fbaff01..03d7a3aec42 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -94,7 +94,15 @@ func newGetCommand() *cobra.Command { }, } - cmd.PreRunE = root.MustWorkspaceClient + // Match Python: a client/auth failure is a JSON error envelope in -o json mode, + // not a bare error. ErrAlreadyPrinted passes through (it was handled upstream). + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { + err := root.MustWorkspaceClient(cmd, args) + if err == nil || errors.Is(err, root.ErrAlreadyPrinted) { + return err + } + return renderError(cmd.Context(), cmd, "INTERNAL_ERROR", "TRANSIENT", true, err) + } cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -102,19 +110,29 @@ func newGetCommand() *cobra.Command { runID, err := strconv.ParseInt(args[0], 10, 64) if err != nil || runID <= 0 { - return fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0]) + return renderError(ctx, cmd, "INVALID_ARGS", "PERMANENT", true, + fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0])) } run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}) if err != nil { // The backend returns this when the run ID is unknown to the user. if errors.Is(err, apierr.ErrResourceDoesNotExist) { - return fmt.Errorf("run %d not found: check the run ID and that it is a job run ID", runID) + return renderError(ctx, cmd, "NOT_FOUND", "NOT_FOUND", false, + fmt.Errorf("run %d not found: check the run ID and that it is a job run ID", runID)) } - return fmt.Errorf("failed to get status for run %d: %w", runID, err) + return renderError(ctx, cmd, "INTERNAL_ERROR", "TRANSIENT", true, + fmt.Errorf("failed to get status for run %d: %w", runID, err)) + } + + workspaceID, err := w.CurrentWorkspaceID(ctx) + if err != nil { + return renderError(ctx, cmd, "INTERNAL_ERROR", "TRANSIENT", true, + fmt.Errorf("failed to get workspace id for run %d: %w", runID, err)) } data := buildGetData(run) + data.DashboardURL = dashboardURL(w.Config.Host, runID, workspaceID) data.MLflowURL = mlflowURL(ctx, w, run) if task := findForEachTask(run); task != nil { data.Sweep = buildSweepInfo(ctx, w, task) @@ -142,7 +160,6 @@ func buildGetData(run *jobs.Run) getData { DurationSeconds: durationSeconds(run), AttemptNumber: latestAttemptNumber(run), ExperimentName: experimentName(run), - DashboardURL: run.RunPageUrl, Accelerators: accelerators(run), User: run.CreatorUserName, } diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 20974e326c4..5e83fd3420e 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -2,13 +2,16 @@ package aircmd import ( "bytes" + "encoding/json" "io" "strings" "testing" "text/template" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/jobs" @@ -46,7 +49,7 @@ func TestGetTemplateSingleRun(t *testing.T) { func TestGetRunInvalidID(t *testing.T) { m := mocks.NewMockWorkspaceClient(t) ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) - cmd := newGetCommand() + cmd := withOutput(newGetCommand(), flags.OutputText) cmd.SetContext(ctx) err := cmd.RunE(cmd, []string{"abc"}) @@ -59,7 +62,7 @@ func TestGetRunNotFound(t *testing.T) { m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( nil, apierr.ErrResourceDoesNotExist) ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) - cmd := newGetCommand() + cmd := withOutput(newGetCommand(), flags.OutputText) cmd.SetContext(ctx) err := cmd.RunE(cmd, []string{"5"}) @@ -67,6 +70,25 @@ func TestGetRunNotFound(t *testing.T) { assert.Contains(t, err.Error(), "run 5 not found") } +func TestGetRunNotFoundJSON(t *testing.T) { + var buf bytes.Buffer + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( + nil, apierr.ErrResourceDoesNotExist) + ctx := cmdctx.SetWorkspaceClient(t.Context(), m.WorkspaceClient) + ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputJSON, nil, &buf, &buf, "", "")) + cmd := withOutput(newGetCommand(), flags.OutputJSON) + cmd.SetContext(ctx) + + // In JSON mode the not-found error is a structured envelope, not a bare error. + err := cmd.RunE(cmd, []string{"5"}) + require.ErrorIs(t, err, root.ErrAlreadyPrinted) + + var got errorEnvelope + require.NoError(t, json.Unmarshal(buf.Bytes(), &got)) + assert.Equal(t, jsonError{Code: "NOT_FOUND", Kind: "NOT_FOUND", Message: "run 5 not found: check the run ID and that it is a job run ID"}, got.Error) +} + func TestPrintConfigYAML(t *testing.T) { t.Run("downloads and prints to stdout", func(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) @@ -191,7 +213,6 @@ func TestGetTemplateAllFields(t *testing.T) { func TestBuildGetData(t *testing.T) { run := &jobs.Run{ RunId: 123, - RunPageUrl: "https://example.test/run/123", CreatorUserName: "me@example.com", StartTime: 1700000000000, EndTime: 1700000012000, @@ -208,7 +229,6 @@ func TestBuildGetData(t *testing.T) { assert.Equal(t, "123", d.RunID) assert.Equal(t, "SUCCESS", d.Status) assert.Equal(t, 1, d.AttemptNumber) - assert.Equal(t, "https://example.test/run/123", d.DashboardURL) assert.Equal(t, "me@example.com", d.User) assert.Equal(t, "8x H100", d.Accelerators) assert.Equal(t, "12s", d.Duration) diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go index 0d2a74c56ba..b4e95918b41 100644 --- a/experimental/air/cmd/mlflow.go +++ b/experimental/air/cmd/mlflow.go @@ -33,9 +33,7 @@ func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run if len(run.Tasks) == 0 { return nil } - // The MLflow output is attached to the task run, not the parent job run. Use - // the last task (latest attempt) so a retried run links to its newest output, - // matching Python (jobs_api_client.py:68). + // The MLflow output is attached to the task run, not the parent job run. taskRunID := run.Tasks[len(run.Tasks)-1].RunId apiClient, err := client.New(w.Config) diff --git a/experimental/air/cmd/output.go b/experimental/air/cmd/output.go index 3da766a7d4f..c00d870e386 100644 --- a/experimental/air/cmd/output.go +++ b/experimental/air/cmd/output.go @@ -4,7 +4,10 @@ import ( "context" "time" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/spf13/cobra" ) // envelopeVersion is the envelope's format-version marker. The Python `air` CLI @@ -37,3 +40,43 @@ func renderEnvelope(ctx context.Context, data any) error { Data: data, }) } + +// jsonError is the error payload, matching the Python `air` CLI's shape (cli/json_output.py). +type jsonError struct { + Code string `json:"code"` + Kind string `json:"kind"` + Message string `json:"message"` + Retryable bool `json:"retryable"` +} + +// errorEnvelope is what a failed command prints in JSON mode: +// +// { "v": 1, "ts": "...", "error": { "code": ..., "kind": ..., "message": ..., "retryable": ... } } +type errorEnvelope struct { + V int `json:"v"` + TS string `json:"ts"` + Error jsonError `json:"error"` +} + +// renderError prints err as a JSON error envelope when output is JSON, returning +// root.ErrAlreadyPrinted so the command exits non-zero without Cobra reprinting +// it; in text mode it returns err unchanged. code/kind/retryable match the +// Python CLI's call site. +func renderError(ctx context.Context, cmd *cobra.Command, code, kind string, retryable bool, err error) error { + if root.OutputType(cmd) != flags.OutputJSON { + return err + } + if rerr := cmdio.Render(ctx, errorEnvelope{ + V: envelopeVersion, + TS: time.Now().UTC().Format(time.RFC3339), + Error: jsonError{ + Code: code, + Kind: kind, + Message: err.Error(), + Retryable: retryable, + }, + }); rerr != nil { + return rerr + } + return root.ErrAlreadyPrinted +} diff --git a/experimental/air/cmd/output_test.go b/experimental/air/cmd/output_test.go index 73a5572c3f5..3c35dc60a67 100644 --- a/experimental/air/cmd/output_test.go +++ b/experimental/air/cmd/output_test.go @@ -1,9 +1,16 @@ package aircmd import ( + "bytes" + "encoding/json" + "errors" "testing" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -11,3 +18,35 @@ func TestRenderEnvelope(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) require.NoError(t, renderEnvelope(ctx, getData{RunID: "1", Status: "RUNNING"})) } + +// withOutput registers the --output flag on cmd and sets it, mirroring how the +// root command wires output mode in production. Subcommand unit tests need it +// because they invoke RunE without going through the root command. +func withOutput(cmd *cobra.Command, output flags.Output) *cobra.Command { + cmd.Flags().Var(&output, "output", "") + return cmd +} + +func TestRenderErrorJSON(t *testing.T) { + var buf bytes.Buffer + ctx := cmdio.InContext(t.Context(), cmdio.NewIO(t.Context(), flags.OutputJSON, nil, &buf, &buf, "", "")) + cmd := withOutput(&cobra.Command{}, flags.OutputJSON) + + err := renderError(ctx, cmd, "NOT_FOUND", "NOT_FOUND", false, errors.New("run 1 not found")) + // JSON mode prints the envelope, so Cobra must stay silent but still exit non-zero. + require.ErrorIs(t, err, root.ErrAlreadyPrinted) + + // The envelope must match the Python air CLI's print_json_error shape exactly. + var got errorEnvelope + require.NoError(t, json.Unmarshal(buf.Bytes(), &got)) + assert.Equal(t, 1, got.V) + assert.NotEmpty(t, got.TS) + assert.Equal(t, jsonError{Code: "NOT_FOUND", Kind: "NOT_FOUND", Message: "run 1 not found"}, got.Error) +} + +func TestRenderErrorText(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + cmd := withOutput(&cobra.Command{}, flags.OutputText) + want := errors.New("run 1 not found") + require.Equal(t, want, renderError(ctx, cmd, "NOT_FOUND", "NOT_FOUND", false, want)) +} From 9072c290ccbe30af2085d86e9517f7ff324fd071 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Wed, 17 Jun 2026 08:25:23 +0000 Subject: [PATCH 14/18] experimental/air: restore last-attempt timing for `air get` Revert the run-level timing change from the previous commit: started_at and duration_seconds read from the last task's window again (reportedTiming), matching the released Python `air` output, which reports the latest attempt. The isoformat timestamp ("+00:00") and half-to-even rounding are kept. Co-authored-by: Isaac --- experimental/air/cmd/format.go | 36 +++++++++++++++++++++-------- experimental/air/cmd/format_test.go | 35 ++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go index a1601f0eab1..e246c1b3976 100644 --- a/experimental/air/cmd/format.go +++ b/experimental/air/cmd/format.go @@ -36,14 +36,33 @@ func runStatus(state *jobs.RunState) string { return "UNKNOWN" } +// reportedTiming returns the run's start and end times (epoch milliseconds), +// preferring the last task's window over the run-level times so a retried run +// reports its latest attempt. Mirrors Python's _reported_attempt_timing +// (cli_display.py:78-87). +func reportedTiming(run *jobs.Run) (startMillis, endMillis int64) { + startMillis, endMillis = run.StartTime, run.EndTime + if n := len(run.Tasks); n > 0 { + last := run.Tasks[n-1] + if last.StartTime > 0 { + startMillis = last.StartTime + } + if last.EndTime > 0 { + endMillis = last.EndTime + } + } + return startMillis, endMillis +} + // startedAt returns the run's start time as a Python-isoformat string ("+00:00", // not "Z"; microseconds only when non-zero, cli_entrypoint.py:1899), or nil if it // hasn't started. func startedAt(run *jobs.Run) *string { - if run.StartTime == 0 { + startMillis, _ := reportedTiming(run) + if startMillis == 0 { return nil } - t := time.UnixMilli(run.StartTime).UTC() + t := time.UnixMilli(startMillis).UTC() layout := "2006-01-02T15:04:05-07:00" if t.Nanosecond() != 0 { layout = "2006-01-02T15:04:05.000000-07:00" @@ -52,20 +71,19 @@ func startedAt(run *jobs.Run) *string { return &s } -// durationSeconds returns how long the run has taken, in whole seconds, or nil if -// it hasn't started. For a finished run this is end-start; for a running one it is -// the time since it started. Both come from the run-level times, matching Python -// (cli_entrypoint.py:1900-1903). +// durationSeconds returns how long the run has taken, in whole seconds, or nil +// if it has not started. For a finished run this is the elapsed time of the +// reported attempt; for a still-running run it is the time since it started. func durationSeconds(run *jobs.Run) *int64 { - if run.StartTime == 0 { + startMillis, endMillis := reportedTiming(run) + if startMillis == 0 { return nil } - endMillis := run.EndTime if endMillis == 0 { // Still running: measure against the current time. endMillis = time.Now().UnixMilli() } - d := roundMillisToSeconds(endMillis - run.StartTime) + d := roundMillisToSeconds(endMillis - startMillis) return &d } diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index e2c2e3d2915..8400ab5d6b7 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -68,6 +68,33 @@ func TestRunStatusPrefersResultState(t *testing.T) { assert.Equal(t, "UNKNOWN", runStatus(nil)) } +func TestReportedTiming(t *testing.T) { + // No tasks: run-level times are used. + start, end := reportedTiming(&jobs.Run{StartTime: 100, EndTime: 200}) + assert.Equal(t, int64(100), start) + assert.Equal(t, int64(200), end) + + // The last task's window is preferred over the run-level window, so a + // retried run reports its most recent attempt. + start, end = reportedTiming(&jobs.Run{ + StartTime: 100, EndTime: 200, + Tasks: []jobs.RunTask{ + {StartTime: 100, EndTime: 150}, + {StartTime: 300, EndTime: 450}, + }, + }) + assert.Equal(t, int64(300), start) + assert.Equal(t, int64(450), end) + + // A task missing a field falls back to the run-level value for that field. + start, end = reportedTiming(&jobs.Run{ + StartTime: 100, EndTime: 200, + Tasks: []jobs.RunTask{{StartTime: 300}}, + }) + assert.Equal(t, int64(300), start) + assert.Equal(t, int64(200), end) +} + func TestStartedAt(t *testing.T) { // Not started yet. assert.Nil(t, startedAt(&jobs.Run{})) @@ -79,13 +106,13 @@ func TestStartedAt(t *testing.T) { got = startedAt(&jobs.Run{StartTime: 1700000000500}) require.NotNil(t, got) assert.Equal(t, "2023-11-14T22:13:20.500000+00:00", *got) - // Task-level times are ignored; the run-level start is reported (matching Python). + // The last attempt's start time is reported. 1700000060000 ms == 22:14:20. got = startedAt(&jobs.Run{ StartTime: 1700000000000, Tasks: []jobs.RunTask{{StartTime: 1700000060000}}, }) require.NotNil(t, got) - assert.Equal(t, "2023-11-14T22:13:20+00:00", *got) + assert.Equal(t, "2023-11-14T22:14:20+00:00", *got) } func TestDurationSeconds(t *testing.T) { @@ -103,13 +130,13 @@ func TestDurationSeconds(t *testing.T) { require.NotNil(t, d) assert.Equal(t, int64(12), *d) - // Task-level times are ignored; run-level end-start is used (matching Python). + // The last attempt's window drives the duration for a retried run. d = durationSeconds(&jobs.Run{ StartTime: 1700000000000, EndTime: 1700000012000, Tasks: []jobs.RunTask{{StartTime: 1700000000000, EndTime: 1700000005000}}, }) require.NotNil(t, d) - assert.Equal(t, int64(12), *d) + assert.Equal(t, int64(5), *d) // Still running: measured against the current time, so positive. d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) From fddbdfd40bdd14c75090b181396769498cfa4749 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Wed, 17 Jun 2026 18:31:13 +0000 Subject: [PATCH 15/18] experimental/air: fix `air get` MLflow link request The runs/get-output call passed run_id via the query-param arg and a nil request body, which this endpoint rejects with "expected a map", so the MLflow link was never produced for completed runs. Pass run_id through the request arg instead (the SDK serializes it to the query string for GET), which sends a valid body and returns the gen_ai_compute_output run info. Failed runs without MLflow output still yield no link: get-output 404s for them, so mlflowURL returns nil as before. Co-authored-by: Isaac --- experimental/air/cmd/mlflow.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go index b4e95918b41..6cebecb0af5 100644 --- a/experimental/air/cmd/mlflow.go +++ b/experimental/air/cmd/mlflow.go @@ -42,9 +42,12 @@ func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run return nil } + // Pass run_id through the request arg (the SDK serializes it to the query + // string for GET); passing it via queryParams instead leaves a nil body that + // this endpoint rejects with "expected a map". var out getRunOutputResponse err = apiClient.Do(ctx, http.MethodGet, "/api/2.2/jobs/runs/get-output", - nil, map[string]any{"run_id": taskRunID}, nil, &out) + nil, nil, map[string]any{"run_id": taskRunID}, &out) if err != nil { log.Debugf(ctx, "air get: could not fetch run output for MLflow link: %v", err) return nil From a69e0d348013b0c43ab7a35d39c3f7b82c686298 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Wed, 17 Jun 2026 20:25:34 +0000 Subject: [PATCH 16/18] experimental/air: rename `air get` to `air get run` and match Python output Nest the run-status command under a `get` parent group so the command is `air get run JOB_RUN_ID`, mirroring the Python CLI (the JOB_RUN_ID arg name matches the sibling air commands and avoids confusion with the MLflow run id). Align the text output with Python's `air get run`: lead with the dashboard link (hyperlinked, falling back to the bare URL off a terminal) followed by a gap, then the training config, then the status table. The table uses Python's field order, "N/A" for empty cells, a "2006-01-02 15:04 UTC" Submitted timestamp, and terminal hyperlinks on the Run ID, Experiment, and MLflow Run cells (the MLflow Run cell shows the run's name from the MLflow REST API). The JSON envelope is unchanged. Also reformat the training-config YAML shown in text mode so multi-line fields (e.g. command) render as block literals instead of escaped one-liners. Co-authored-by: Isaac --- acceptance/experimental/air/get/output.txt | 21 ++-- acceptance/experimental/air/get/script | 8 +- acceptance/experimental/air/get/test.toml | 7 ++ acceptance/experimental/air/help/output.txt | 2 +- experimental/air/cmd/format.go | 81 +++++++++++++ experimental/air/cmd/format_test.go | 44 +++++++ experimental/air/cmd/get.go | 124 +++++++++++++------- experimental/air/cmd/get_test.go | 106 +++++++++++------ experimental/air/cmd/mlflow.go | 76 ++++++++++-- experimental/air/cmd/mlflow_test.go | 48 ++++++-- 10 files changed, 405 insertions(+), 112 deletions(-) diff --git a/acceptance/experimental/air/get/output.txt b/acceptance/experimental/air/get/output.txt index b31d0eaa0c5..0e938cdd494 100644 --- a/acceptance/experimental/air/get/output.txt +++ b/acceptance/experimental/air/get/output.txt @@ -1,19 +1,20 @@ === get (text) ->>> [CLI] experimental air get 123 +>>> [CLI] experimental air get run 123 +Dashboard: [DATABRICKS_URL]/jobs/runs/123?o=[NUMID] + Run ID: 123 Status: SUCCESS -Submitted: [TIMESTAMP]+00:00 -Duration: 12s +Submitted: 2023-11-14 22:13 UTC Retries: 0 +Duration: 12s Experiment: my-exp +MLflow Run: my-run User: user@example.com Accelerators: 8x H100 -MLflow: [DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0 -Dashboard: [DATABRICKS_URL]/jobs/runs/123?o=[NUMID] === get (json) ->>> [CLI] experimental air get 123 -o json +>>> [CLI] experimental air get run 123 -o json { "v": 1, "ts": "[TIMESTAMP]", @@ -30,20 +31,20 @@ Dashboard: [DATABRICKS_URL]/jobs/runs/123?o=[NUMID] } === invalid run id ->>> [CLI] experimental air get notanumber -Error: invalid RUN_ID "notanumber": must be a positive integer +>>> [CLI] experimental air get run notanumber +Error: invalid JOB_RUN_ID "notanumber": must be a positive integer Exit code: 1 === invalid run id (json) ->>> [CLI] experimental air get notanumber -o json +>>> [CLI] experimental air get run notanumber -o json { "v": 1, "ts": "[TIMESTAMP]", "error": { "code": "INVALID_ARGS", "kind": "PERMANENT", - "message": "invalid RUN_ID \"notanumber\": must be a positive integer", + "message": "invalid JOB_RUN_ID \"notanumber\": must be a positive integer", "retryable": true } } diff --git a/acceptance/experimental/air/get/script b/acceptance/experimental/air/get/script index ee66b4aff04..b775d06a48d 100644 --- a/acceptance/experimental/air/get/script +++ b/acceptance/experimental/air/get/script @@ -1,11 +1,11 @@ title "get (text)" -trace $CLI experimental air get 123 +trace $CLI experimental air get run 123 title "get (json)" -trace $CLI experimental air get 123 -o json +trace $CLI experimental air get run 123 -o json title "invalid run id" -errcode trace $CLI experimental air get notanumber +errcode trace $CLI experimental air get run notanumber title "invalid run id (json)" -errcode trace $CLI experimental air get notanumber -o json +errcode trace $CLI experimental air get run notanumber -o json diff --git a/acceptance/experimental/air/get/test.toml b/acceptance/experimental/air/get/test.toml index b6219b87f07..6162cebdb5d 100644 --- a/acceptance/experimental/air/get/test.toml +++ b/acceptance/experimental/air/get/test.toml @@ -38,3 +38,10 @@ Pattern = "GET /api/2.2/jobs/runs/get-output" Response.Body = ''' {"gen_ai_compute_output": {"run_info": {"mlflow_experiment_id": "exp1", "mlflow_run_id": "run1"}}} ''' + +# The MLflow Run cell shows the run's name, fetched from the MLflow REST API. +[[Server]] +Pattern = "GET /api/2.0/mlflow/runs/get" +Response.Body = ''' +{"run": {"info": {"run_name": "my-run"}}} +''' diff --git a/acceptance/experimental/air/help/output.txt b/acceptance/experimental/air/help/output.txt index 3a0f86e164f..41a1d6815a8 100644 --- a/acceptance/experimental/air/help/output.txt +++ b/acceptance/experimental/air/help/output.txt @@ -11,7 +11,7 @@ Usage: Available Commands: cancel Cancel one or more runs - get Show details for a run + get Show details for a specific resource list List recent runs logs Stream or fetch logs for a run register-image Mirror a Docker image into the workspace registry diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go index e246c1b3976..8694ed69e9f 100644 --- a/experimental/air/cmd/format.go +++ b/experimental/air/cmd/format.go @@ -1,14 +1,83 @@ package aircmd import ( + "bytes" + "context" "fmt" + "io" "math" "strings" "time" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go/service/jobs" + "go.yaml.in/yaml/v3" ) +// na is the placeholder shown for an empty text-table cell, matching the Python CLI. +const na = "N/A" + +// orNA returns s, or "N/A" when s is empty, for text-table cells. +func orNA(s string) string { + if s == "" { + return na + } + return s +} + +// osc8Link wraps label in an OSC 8 terminal hyperlink to url. +// See https://gist.github.com/egmontkob/eb114294efbcd5adb1944c9f3cb5feda +func osc8Link(label, url string) string { + return "\x1b]8;;" + url + "\x1b\\" + label + "\x1b]8;;\x1b\\" +} + +// hyperlink renders label as a terminal hyperlink to url when out is a rich +// terminal, otherwise it returns label unchanged. This mirrors the Python CLI's +// Rich link markup, which drops the URL on non-terminals (so piped or captured +// output stays plain text). +func hyperlink(ctx context.Context, out io.Writer, label, url string) string { + if url == "" || !cmdio.SupportsColor(ctx, out) { + return label + } + return osc8Link(label, url) +} + +// reformatYAMLForDisplay re-renders a training-config YAML so multi-line strings +// (notably the `command:` field) appear as `|` block literals instead of the +// quoted "\n"-escaped single line they are stored as, which is unreadable. It +// mirrors Python's _reformat_yaml_for_display (cli_display.py); we skip the +// Rich syntax-highlighted panel and only fix the whitespace. On any parse or +// re-encode failure it returns the original content unchanged. +func reformatYAMLForDisplay(content []byte) string { + var node yaml.Node + if err := yaml.Unmarshal(content, &node); err != nil { + return string(content) + } + forceLiteralBlockStrings(&node) + + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + enc.SetIndent(2) + if err := enc.Encode(&node); err != nil { + return string(content) + } + enc.Close() + return buf.String() +} + +// forceLiteralBlockStrings walks a YAML node tree and marks every multi-line +// string scalar for `|` block-literal rendering. The encoder automatically +// falls back to a quoted style when a value can't be represented as a block +// literal (e.g. lines with trailing whitespace), so no explicit guard is needed. +func forceLiteralBlockStrings(node *yaml.Node) { + if node.Kind == yaml.ScalarNode && node.Tag == "!!str" && strings.Contains(node.Value, "\n") { + node.Style = yaml.LiteralStyle + } + for _, child := range node.Content { + forceLiteralBlockStrings(child) + } +} + // gpuDisplayNames maps the GPU identifiers returned by the backend to the short // names we show to users. Unknown identifiers are shown unchanged. var gpuDisplayNames = map[string]string{ @@ -71,6 +140,18 @@ func startedAt(run *jobs.Run) *string { return &s } +// submittedDisplay formats the run's start time for the text table as +// "2006-01-02 15:04 UTC", or "N/A" if it hasn't started. Mirrors Python's +// _format_timestamp (cli_display.py); we render in UTC for stable output rather +// than the local zone Python uses. +func submittedDisplay(run *jobs.Run) string { + startMillis, _ := reportedTiming(run) + if startMillis == 0 { + return na + } + return time.UnixMilli(startMillis).UTC().Format("2006-01-02 15:04 MST") +} + // durationSeconds returns how long the run has taken, in whole seconds, or nil // if it has not started. For a finished run this is the elapsed time of the // reported attempt; for a still-running run it is the time since it started. diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index 8400ab5d6b7..acecd0ce901 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -1,13 +1,38 @@ package aircmd import ( + "io" "testing" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestOrNA(t *testing.T) { + assert.Equal(t, "x", orNA("x")) + assert.Equal(t, "N/A", orNA("")) +} + +func TestSubmittedDisplay(t *testing.T) { + assert.Equal(t, "N/A", submittedDisplay(&jobs.Run{})) + // 1700000000000 ms == 2023-11-14 22:13:20 UTC. + assert.Equal(t, "2023-11-14 22:13 UTC", submittedDisplay(&jobs.Run{StartTime: 1700000000000})) +} + +func TestOSC8Link(t *testing.T) { + assert.Equal(t, "\x1b]8;;https://h.test/x\x1b\\label\x1b]8;;\x1b\\", osc8Link("label", "https://h.test/x")) +} + +func TestHyperlink(t *testing.T) { + // On a non-terminal (no color), the URL is dropped and only the label shows. + ctx := cmdio.MockDiscard(t.Context()) + assert.Equal(t, "label", hyperlink(ctx, io.Discard, "label", "https://h.test/x")) + // An empty URL is always rendered as the bare label. + assert.Equal(t, "label", hyperlink(ctx, io.Discard, "label", "")) +} + func TestFormatDuration(t *testing.T) { cases := []struct { seconds int64 @@ -168,6 +193,25 @@ func TestExperimentName(t *testing.T) { assert.Equal(t, "exp", *got) } +func TestReformatYAMLForDisplay(t *testing.T) { + // A multi-line command stored as a quoted "\n"-escaped string is re-rendered + // as a `|` block literal, while key order is preserved. + in := "experiment_name: foo\ncommand: \"set -e\\npython train.py --epochs 3\\n\"\nmax_retries: 2\n" + want := "experiment_name: foo\ncommand: |\n set -e\n python train.py --epochs 3\nmax_retries: 2\n" + assert.Equal(t, want, reformatYAMLForDisplay([]byte(in))) + + // A single-line command is left as a plain scalar. + assert.Equal(t, "command: bash train.sh\n", reformatYAMLForDisplay([]byte("command: bash train.sh\n"))) + + // A multi-line value with trailing whitespace can't be a block literal, so the + // encoder falls back to a quoted style rather than emitting invalid YAML. + got := reformatYAMLForDisplay([]byte("command: \"trailing space \\nsecond\"\n")) + assert.Equal(t, "command: \"trailing space \\nsecond\"\n", got) + + // Unparseable content is returned unchanged. + assert.Equal(t, "\tnot: valid: yaml:", reformatYAMLForDisplay([]byte("\tnot: valid: yaml:"))) +} + func TestAccelerators(t *testing.T) { assert.Empty(t, accelerators(&jobs.Run{})) assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 03d7a3aec42..1dbb4f786ba 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -17,7 +17,7 @@ import ( "github.com/spf13/cobra" ) -// getData is the payload printed by `air get`. The json-tagged fields form +// getData is the payload printed by `air get run`. The json-tagged fields form // the machine-readable output; fields tagged `json:"-"` are shown only in the // human-readable text view. type getData struct { @@ -30,18 +30,24 @@ type getData struct { DashboardURL string `json:"dashboard_url"` MLflowURL *string `json:"mlflow_url"` - // Duration is the human-readable form of DurationSeconds, e.g. "12m 3s". - Duration string `json:"-"` - // Accelerators describes the run's GPUs, e.g. "8x H100". - Accelerators string `json:"-"` - // User is the run's creator. Text-only; JSON omits it, matching `air get --json`. - User string `json:"-"` - // Sweep replaces the single-run view for foreach runs. Text-only; JSON omits it. + // The fields below are pre-rendered for the text table and excluded from + // JSON (matching `air get run --json`). The table always shows every row, + // with "N/A" for missing values, in the same order as the Python CLI. The + // Run ID, Experiment, and MLflow Run cells carry terminal hyperlinks when + // stdout is a terminal, so the URLs don't appear as bare text. + RunIDDisplay string `json:"-"` + SubmittedDisplay string `json:"-"` + DurationDisplay string `json:"-"` + ExperimentDisplay string `json:"-"` + MLflowDisplay string `json:"-"` + UserDisplay string `json:"-"` + AcceleratorsDisplay string `json:"-"` + // Sweep replaces the single-run view for foreach runs. Sweep *sweepInfo `json:"-"` } // getTemplate is the text-mode layout. It reads from the JSON envelope, so -// every field is reached through ".Data". Optional rows are hidden when empty. +// every field is reached through ".Data". const getTemplate = `{{- if .Data.Sweep -}} Sweep Run ID: {{.Data.RunID}} Status: {{.Data.Status}} @@ -59,36 +65,34 @@ Sweep Tasks: {{- end}} {{- end}} {{- else -}} -Run ID: {{.Data.RunID}} +Run ID: {{.Data.RunIDDisplay}} Status: {{.Data.Status}} -{{- if .Data.StartedAt}} -Submitted: {{.Data.StartedAt}} -{{- end}} -{{- if .Data.Duration}} -Duration: {{.Data.Duration}} -{{- end}} +Submitted: {{.Data.SubmittedDisplay}} Retries: {{.Data.AttemptNumber}} -{{- if .Data.ExperimentName}} -Experiment: {{.Data.ExperimentName}} -{{- end}} -{{- if .Data.User}} -User: {{.Data.User}} -{{- end}} -{{- if .Data.Accelerators}} -Accelerators: {{.Data.Accelerators}} -{{- end}} -{{- if .Data.MLflowURL}} -MLflow: {{.Data.MLflowURL}} -{{- end}} -Dashboard: {{.Data.DashboardURL}} +Duration: {{.Data.DurationDisplay}} +Experiment: {{.Data.ExperimentDisplay}} +MLflow Run: {{.Data.MLflowDisplay}} +User: {{.Data.UserDisplay}} +Accelerators: {{.Data.AcceleratorsDisplay}} {{- end}} ` +// newGetCommand is the `get` parent group. Subcommands name the resource to +// describe, e.g. `air get run JOB_RUN_ID`, mirroring the Python CLI. func newGetCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "get JOB_RUN_ID", + Use: "get", + Short: "Show details for a specific resource", + } + cmd.AddCommand(newGetRunCommand()) + return cmd +} + +func newGetRunCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "run JOB_RUN_ID", Args: root.ExactArgs(1), - Short: "Show details for a run", + Short: "Show status, configuration, and timing details for a specific run", Annotations: map[string]string{ "template": getTemplate, }, @@ -111,7 +115,7 @@ func newGetCommand() *cobra.Command { runID, err := strconv.ParseInt(args[0], 10, 64) if err != nil || runID <= 0 { return renderError(ctx, cmd, "INVALID_ARGS", "PERMANENT", true, - fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0])) + fmt.Errorf("invalid JOB_RUN_ID %q: must be a positive integer", args[0])) } run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}) @@ -133,16 +137,28 @@ func newGetCommand() *cobra.Command { data := buildGetData(run) data.DashboardURL = dashboardURL(w.Config.Host, runID, workspaceID) - data.MLflowURL = mlflowURL(ctx, w, run) + ids := mlflowIDs(ctx, w, run) + if ids != nil { + url := mlflowLogsURL(w.Config.Host, ids) + data.MLflowURL = &url + } if task := findForEachTask(run); task != nil { data.Sweep = buildSweepInfo(ctx, w, task) } - // Text mode shows the training-config YAML before the status, mirroring - // `air get`. JSON output omits it, matching `air get --json`. if root.OutputType(cmd) == flags.OutputText { + out := cmd.OutOrStdout() + addTextLinks(ctx, out, w, &data, ids) + + // Lead with the dashboard link (hyperlinked, falling back to the bare + // URL off a terminal), then a gap before the training config and the + // status table, mirroring the Python CLI's header. + fmt.Fprintf(out, "Dashboard: %s\n\n", hyperlink(ctx, out, data.DashboardURL, data.DashboardURL)) + + // Text mode shows the training-config YAML before the status, + // mirroring `air get run`. JSON output omits it. if path := yamlConfigPath(run); path != "" { - printConfigYAML(ctx, cmd.OutOrStdout(), w, path) + printConfigYAML(ctx, out, w, path) } } return renderEnvelope(ctx, data) @@ -151,7 +167,10 @@ func newGetCommand() *cobra.Command { return cmd } -// buildGetData extracts the fields we display from a run. +// buildGetData extracts the fields we display from a run. The text-table cells +// are pre-rendered here with their "N/A" fallbacks; the Run ID, Experiment, and +// MLflow Run cells are finalized later by addTextLinks once the dashboard and +// MLflow identifiers are known. func buildGetData(run *jobs.Run) getData { data := getData{ RunID: strconv.FormatInt(run.RunId, 10), @@ -160,15 +179,38 @@ func buildGetData(run *jobs.Run) getData { DurationSeconds: durationSeconds(run), AttemptNumber: latestAttemptNumber(run), ExperimentName: experimentName(run), - Accelerators: accelerators(run), - User: run.CreatorUserName, } + data.RunIDDisplay = data.RunID + data.SubmittedDisplay = submittedDisplay(run) + data.DurationDisplay = na if data.DurationSeconds != nil { - data.Duration = formatDuration(*data.DurationSeconds) + data.DurationDisplay = formatDuration(*data.DurationSeconds) + } + data.ExperimentDisplay = na + if data.ExperimentName != nil { + data.ExperimentDisplay = *data.ExperimentName } + data.MLflowDisplay = na + data.UserDisplay = orNA(run.CreatorUserName) + data.AcceleratorsDisplay = orNA(accelerators(run)) return data } +// addTextLinks adds the terminal hyperlinks shown in text mode: the Run ID links +// to the run's dashboard page (Python embeds this on the Run ID instead of a +// separate Dashboard row), and the Experiment and MLflow Run cells link to their +// MLflow pages. On a non-terminal these degrade to plain text. +func addTextLinks(ctx context.Context, out io.Writer, w *databricks.WorkspaceClient, data *getData, ids *mlflowIdentifiers) { + data.RunIDDisplay = hyperlink(ctx, out, data.RunID, data.DashboardURL) + if ids == nil { + return + } + if data.ExperimentName != nil { + data.ExperimentDisplay = hyperlink(ctx, out, *data.ExperimentName, mlflowExperimentURL(w.Config.Host, ids)) + } + data.MLflowDisplay = hyperlink(ctx, out, mlflowRunLabel(ctx, w, ids.RunID), mlflowRunURL(w.Config.Host, ids)) +} + // yamlConfigPath returns the run's training-config YAML path, or "" if none. func yamlConfigPath(run *jobs.Run) string { if len(run.Tasks) == 0 { @@ -199,6 +241,6 @@ func printConfigYAML(ctx context.Context, out io.Writer, w *databricks.Workspace } fmt.Fprintln(out, "Training Configuration:") - fmt.Fprintln(out, string(content)) + fmt.Fprintln(out, reformatYAMLForDisplay(content)) fmt.Fprintln(out) } diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 5e83fd3420e..643c521c0e7 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -33,28 +33,27 @@ func renderGet(t *testing.T, data getData) string { func TestGetTemplateSingleRun(t *testing.T) { out := renderGet(t, getData{ - RunID: "123", + RunIDDisplay: "123", Status: "RUNNING", - User: "me@example.com", - DashboardURL: "https://example.test/run/123", + UserDisplay: "me@example.com", }) assert.Contains(t, out, "Run ID: 123") assert.Contains(t, out, "Status: RUNNING") - assert.Contains(t, out, "User:") - assert.Contains(t, out, "me@example.com") - assert.Contains(t, out, "Dashboard: https://example.test/run/123") + assert.Contains(t, out, "User: me@example.com") assert.NotContains(t, out, "Sweep") + // Python embeds the dashboard link on the Run ID; there is no Dashboard row. + assert.NotContains(t, out, "Dashboard") } func TestGetRunInvalidID(t *testing.T) { m := mocks.NewMockWorkspaceClient(t) ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) - cmd := withOutput(newGetCommand(), flags.OutputText) + cmd := withOutput(newGetRunCommand(), flags.OutputText) cmd.SetContext(ctx) err := cmd.RunE(cmd, []string{"abc"}) require.Error(t, err) - assert.Contains(t, err.Error(), "invalid RUN_ID") + assert.Contains(t, err.Error(), "invalid JOB_RUN_ID") } func TestGetRunNotFound(t *testing.T) { @@ -62,7 +61,7 @@ func TestGetRunNotFound(t *testing.T) { m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( nil, apierr.ErrResourceDoesNotExist) ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) - cmd := withOutput(newGetCommand(), flags.OutputText) + cmd := withOutput(newGetRunCommand(), flags.OutputText) cmd.SetContext(ctx) err := cmd.RunE(cmd, []string{"5"}) @@ -77,7 +76,7 @@ func TestGetRunNotFoundJSON(t *testing.T) { nil, apierr.ErrResourceDoesNotExist) ctx := cmdctx.SetWorkspaceClient(t.Context(), m.WorkspaceClient) ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputJSON, nil, &buf, &buf, "", "")) - cmd := withOutput(newGetCommand(), flags.OutputJSON) + cmd := withOutput(newGetRunCommand(), flags.OutputJSON) cmd.SetContext(ctx) // In JSON mode the not-found error is a structured envelope, not a bare error. @@ -96,7 +95,7 @@ func TestPrintConfigYAML(t *testing.T) { // The mock asserts Download is called with the resolved path. m.GetMockWorkspaceAPI().EXPECT(). Download(mock.Anything, "/Workspace/cfg.yaml"). - Return(io.NopCloser(strings.NewReader("epochs: 3\n")), nil) + Return(io.NopCloser(strings.NewReader("epochs: 3\ncommand: \"set -e\\npython train.py\\n\"\n")), nil) // The config is data output and must land on stdout (the out writer), // matching the Python `air get` behavior. @@ -104,6 +103,8 @@ func TestPrintConfigYAML(t *testing.T) { printConfigYAML(ctx, &out, m.WorkspaceClient, "/Workspace/cfg.yaml") assert.Contains(t, out.String(), "Training Configuration:") assert.Contains(t, out.String(), "epochs: 3") + // The multi-line command is reformatted to a `|` block literal. + assert.Contains(t, out.String(), "command: |\n set -e\n python train.py") }) t.Run("download failure is non-fatal and writes nothing", func(t *testing.T) { @@ -169,42 +170,54 @@ func TestGetTemplateSweepNoTasks(t *testing.T) { } func TestGetTemplateMinimal(t *testing.T) { - // Only the always-present rows render; optional rows are hidden when empty. - out := renderGet(t, getData{RunID: "1", Status: "PENDING", DashboardURL: "https://example.test/1"}) - assert.Contains(t, out, "Run ID: 1") - assert.Contains(t, out, "Status: PENDING") - assert.Contains(t, out, "Retries: 0") - assert.Contains(t, out, "Dashboard: https://example.test/1") - for _, hidden := range []string{"Submitted:", "Duration:", "Experiment:", "User:", "Accelerators:", "MLflow:"} { - assert.NotContains(t, out, hidden) + // Every row always renders; missing values show "N/A", in Python's order. + out := renderGet(t, getData{ + RunIDDisplay: "1", + Status: "PENDING", + SubmittedDisplay: na, + DurationDisplay: na, + ExperimentDisplay: na, + MLflowDisplay: na, + UserDisplay: na, + AcceleratorsDisplay: na, + }) + for _, want := range []string{ + "Run ID: 1", + "Status: PENDING", + "Submitted: N/A", + "Retries: 0", + "Duration: N/A", + "Experiment: N/A", + "MLflow Run: N/A", + "User: N/A", + "Accelerators: N/A", + } { + assert.Contains(t, out, want) } } func TestGetTemplateAllFields(t *testing.T) { - started := "2023-11-14T22:13:20Z" - exp := "exp" - mlflow := "https://example.test/ml/exp/1" out := renderGet(t, getData{ - RunID: "1", - Status: "SUCCESS", - StartedAt: &started, - Duration: "12s", - AttemptNumber: 2, - ExperimentName: &exp, - User: "me@example.com", - Accelerators: "8x H100", - MLflowURL: &mlflow, - DashboardURL: "https://example.test/1", + RunIDDisplay: "1", + Status: "SUCCESS", + SubmittedDisplay: "2023-11-14 22:13 UTC", + DurationDisplay: "12s", + AttemptNumber: 2, + ExperimentDisplay: "exp", + MLflowDisplay: "sunny-cat-42", + UserDisplay: "me@example.com", + AcceleratorsDisplay: "8x H100", }) for _, want := range []string{ - "Submitted: 2023-11-14T22:13:20Z", - "Duration: 12s", + "Run ID: 1", + "Status: SUCCESS", + "Submitted: 2023-11-14 22:13 UTC", "Retries: 2", + "Duration: 12s", "Experiment: exp", + "MLflow Run: sunny-cat-42", "User: me@example.com", "Accelerators: 8x H100", - "MLflow: https://example.test/ml/exp/1", - "Dashboard: https://example.test/1", } { assert.Contains(t, out, want) } @@ -227,13 +240,28 @@ func TestBuildGetData(t *testing.T) { } d := buildGetData(run) assert.Equal(t, "123", d.RunID) + assert.Equal(t, "123", d.RunIDDisplay) assert.Equal(t, "SUCCESS", d.Status) assert.Equal(t, 1, d.AttemptNumber) - assert.Equal(t, "me@example.com", d.User) - assert.Equal(t, "8x H100", d.Accelerators) - assert.Equal(t, "12s", d.Duration) + assert.Equal(t, "2023-11-14 22:13 UTC", d.SubmittedDisplay) + assert.Equal(t, "me@example.com", d.UserDisplay) + assert.Equal(t, "8x H100", d.AcceleratorsDisplay) + assert.Equal(t, "12s", d.DurationDisplay) + assert.Equal(t, "exp", d.ExperimentDisplay) require.NotNil(t, d.ExperimentName) assert.Equal(t, "exp", *d.ExperimentName) require.NotNil(t, d.DurationSeconds) assert.Equal(t, int64(12), *d.DurationSeconds) } + +func TestBuildGetDataEmpty(t *testing.T) { + // A run with no tasks, creator, or timing renders every text cell as "N/A". + d := buildGetData(&jobs.Run{RunId: 7}) + assert.Equal(t, "7", d.RunIDDisplay) + assert.Equal(t, na, d.SubmittedDisplay) + assert.Equal(t, na, d.DurationDisplay) + assert.Equal(t, na, d.ExperimentDisplay) + assert.Equal(t, na, d.MLflowDisplay) + assert.Equal(t, na, d.UserDisplay) + assert.Equal(t, na, d.AcceleratorsDisplay) +} diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go index 6cebecb0af5..7e2b5b8fd58 100644 --- a/experimental/air/cmd/mlflow.go +++ b/experimental/air/cmd/mlflow.go @@ -25,11 +25,17 @@ type getRunOutputResponse struct { } `json:"gen_ai_compute_output"` } -// mlflowURL returns a link to the run's MLflow logs, or nil if it can't be -// built. The link is a convenience, so any failure here (missing task, endpoint -// error, run not yet started) is logged and treated as "no link" rather than -// failing the whole command. -func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run) *string { +// mlflowIdentifiers are the experiment and run IDs MLflow assigns to a run. +type mlflowIdentifiers struct { + ExperimentID string + RunID string +} + +// mlflowIDs fetches the run's MLflow experiment and run IDs, or nil if they +// can't be obtained. They drive a convenience link, so any failure here +// (missing task, endpoint error, run not yet started) is logged and treated as +// "no link" rather than failing the whole command. +func mlflowIDs(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run) *mlflowIdentifiers { if len(run.Tasks) == 0 { return nil } @@ -60,9 +66,61 @@ func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run if info.MlflowExperimentID == "" || info.MlflowRunID == "" { return nil } + return &mlflowIdentifiers{ExperimentID: info.MlflowExperimentID, RunID: info.MlflowRunID} +} + +// mlflowLogsURL is the deep link to a run's node-0 logs. It is the value of the +// JSON `mlflow_url` field, matching the Python CLI. +func mlflowLogsURL(host string, ids *mlflowIdentifiers) string { + return fmt.Sprintf("%s/ml/experiments/%s/runs/%s/artifacts/logs/node_0", + strings.TrimRight(host, "/"), ids.ExperimentID, ids.RunID) +} - host := strings.TrimRight(w.Config.Host, "/") - url := fmt.Sprintf("%s/ml/experiments/%s/runs/%s/artifacts/logs/node_0", - host, info.MlflowExperimentID, info.MlflowRunID) - return &url +// mlflowExperimentURL links to the MLflow experiment page; mlflowRunURL links to +// the run page. These back the Experiment and MLflow Run hyperlinks in text mode. +func mlflowExperimentURL(host string, ids *mlflowIdentifiers) string { + return fmt.Sprintf("%s/ml/experiments/%s", strings.TrimRight(host, "/"), ids.ExperimentID) +} + +func mlflowRunURL(host string, ids *mlflowIdentifiers) string { + return fmt.Sprintf("%s/ml/experiments/%s/runs/%s", + strings.TrimRight(host, "/"), ids.ExperimentID, ids.RunID) +} + +// mlflowRunLabel returns the MLflow run's human-readable name to use as the +// hyperlink text, falling back to "...{last 8 of run id}" when the name can't be +// fetched. Mirrors Python's _get_mlflow_run_name (cli_display.py). +func mlflowRunLabel(ctx context.Context, w *databricks.WorkspaceClient, mlflowRunID string) string { + if name := fetchMLflowRunName(ctx, w, mlflowRunID); name != "" { + return name + } + if len(mlflowRunID) > 8 { + return "..." + mlflowRunID[len(mlflowRunID)-8:] + } + return "..." + mlflowRunID +} + +// fetchMLflowRunName fetches a run's MLflow run_name via the MLflow REST API, +// returning "" if it can't be obtained. Best-effort, like the rest of the +// MLflow enrichment. +func fetchMLflowRunName(ctx context.Context, w *databricks.WorkspaceClient, mlflowRunID string) string { + apiClient, err := client.New(w.Config) + if err != nil { + log.Debugf(ctx, "air get: could not build API client for MLflow run name: %v", err) + return "" + } + var out struct { + Run struct { + Info struct { + RunName string `json:"run_name"` + } `json:"info"` + } `json:"run"` + } + err = apiClient.Do(ctx, http.MethodGet, "/api/2.0/mlflow/runs/get", + nil, nil, map[string]any{"run_id": mlflowRunID}, &out) + if err != nil { + log.Debugf(ctx, "air get: could not fetch MLflow run name: %v", err) + return "" + } + return out.Run.Info.RunName } diff --git a/experimental/air/cmd/mlflow_test.go b/experimental/air/cmd/mlflow_test.go index bdc54c18d53..a6a5c151074 100644 --- a/experimental/air/cmd/mlflow_test.go +++ b/experimental/air/cmd/mlflow_test.go @@ -12,7 +12,7 @@ import ( ) // newTestWorkspaceClient builds a WorkspaceClient pointed at a mock HTTP server. -// mlflowURL calls the runs/get-output REST endpoint directly (the field it needs +// mlflowIDs calls the runs/get-output REST endpoint directly (the field it needs // is not modeled by the typed SDK), so it must be exercised over HTTP. func newTestWorkspaceClient(t *testing.T, host string) *databricks.WorkspaceClient { t.Helper() @@ -37,29 +37,29 @@ func runOutputServer(t *testing.T, body string, hit *bool) *httptest.Server { return srv } -func TestMLflowURL(t *testing.T) { +func TestMLflowIDs(t *testing.T) { ctx := t.Context() run := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}}} - t.Run("builds the deep-link on success", func(t *testing.T) { + t.Run("returns the identifiers on success", func(t *testing.T) { var hit bool srv := runOutputServer(t, `{"gen_ai_compute_output":{"run_info":{"mlflow_experiment_id":"E1","mlflow_run_id":"R1"}}}`, &hit) - got := mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run) + got := mlflowIDs(ctx, newTestWorkspaceClient(t, srv.URL), run) require.NotNil(t, got) assert.True(t, hit, "runs/get-output should have been called") - assert.Equal(t, srv.URL+"/ml/experiments/E1/runs/R1/artifacts/logs/node_0", *got) + assert.Equal(t, &mlflowIdentifiers{ExperimentID: "E1", RunID: "R1"}, got) }) t.Run("nil when the run has no MLflow info", func(t *testing.T) { var hit bool srv := runOutputServer(t, `{}`, &hit) - assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run)) + assert.Nil(t, mlflowIDs(ctx, newTestWorkspaceClient(t, srv.URL), run)) }) t.Run("nil when the run has no tasks", func(t *testing.T) { // Returns before any HTTP call, so the host is never contacted. - assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) + assert.Nil(t, mlflowIDs(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) }) t.Run("uses the latest attempt's task run", func(t *testing.T) { @@ -74,7 +74,39 @@ func TestMLflowURL(t *testing.T) { t.Cleanup(srv.Close) retried := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}, {RunId: 100}}} - mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), retried) + mlflowIDs(ctx, newTestWorkspaceClient(t, srv.URL), retried) assert.Equal(t, "100", gotRunID) }) } + +func TestMLflowURLs(t *testing.T) { + ids := &mlflowIdentifiers{ExperimentID: "E1", RunID: "R1"} + // A trailing slash on the host must not produce a double slash in the link. + assert.Equal(t, "https://h.test/ml/experiments/E1/runs/R1/artifacts/logs/node_0", mlflowLogsURL("https://h.test/", ids)) + assert.Equal(t, "https://h.test/ml/experiments/E1", mlflowExperimentURL("https://h.test", ids)) + assert.Equal(t, "https://h.test/ml/experiments/E1/runs/R1", mlflowRunURL("https://h.test", ids)) +} + +func TestMLflowRunLabel(t *testing.T) { + ctx := t.Context() + + t.Run("uses the run name when available", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/2.0/mlflow/runs/get" { + _, _ = w.Write([]byte(`{"run":{"info":{"run_name":"sunny-cat-42"}}}`)) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + assert.Equal(t, "sunny-cat-42", mlflowRunLabel(ctx, newTestWorkspaceClient(t, srv.URL), "0123456789abcdef")) + }) + + t.Run("falls back to the last 8 characters of the run id", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + assert.Equal(t, "...9abcdef0", mlflowRunLabel(ctx, newTestWorkspaceClient(t, srv.URL), "0123456789abcdef0")) + }) +} From fa3a1a29d33c48695d6591434443fbcb5fe6555a Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:32:03 +0000 Subject: [PATCH 17/18] experimental/air: add GPU accelerator type and compute config model Add compute.go: the gpuType model and compute-block validation the upcoming `air run` config layer depends on. Defines the canonical GPU_* accelerator types, parseGPUType (exact, case-sensitive), gpusPerNode (partition counts), and computeConfig.validate (positive count, multiple-of-per-node, mutually exclusive node_pool_id/pool_name). Co-authored-by: Isaac --- experimental/air/cmd/compute.go | 88 +++++++++++++++++++++++++++ experimental/air/cmd/compute_test.go | 89 ++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 experimental/air/cmd/compute.go create mode 100644 experimental/air/cmd/compute_test.go diff --git a/experimental/air/cmd/compute.go b/experimental/air/cmd/compute.go new file mode 100644 index 00000000000..02e1916520d --- /dev/null +++ b/experimental/air/cmd/compute.go @@ -0,0 +1,88 @@ +package aircmd + +import ( + "errors" + "fmt" + "strings" +) + +// gpuType is a wire-facing accelerator type submitted to the training service. +// The number in the name is the partition count (e.g. GPU_8xH100 is 8 GPUs). +type gpuType string + +const ( + gpuType1xA10 gpuType = "GPU_1xA10" + gpuType8xH100 gpuType = "GPU_8xH100" + gpuType1xH100 gpuType = "GPU_1xH100" +) + +// gpuTypes lists every valid type. Used for validation error messages. +var gpuTypes = []gpuType{gpuType1xA10, gpuType1xH100, gpuType8xH100} + +func validGPUTypesHint() string { + names := make([]string, len(gpuTypes)) + for i, g := range gpuTypes { + names[i] = string(g) + } + return "valid types are: " + strings.Join(names, ", ") +} + +// parseGPUType resolves a YAML accelerator_type string to a gpuType. The match is +// exact: the server's lookup is case-sensitive. +func parseGPUType(value string) (gpuType, error) { + switch gpuType(value) { + case gpuType1xA10, gpuType8xH100, gpuType1xH100: + return gpuType(value), nil + } + return "", fmt.Errorf("invalid GPU type %q: %s", value, validGPUTypesHint()) +} + +// gpusPerNode returns the per-node GPU count, which is the partition count from +// the name (GPU_1xH100 -> 1, GPU_8xH100 -> 8). num_accelerators must be a +// round multiple of this since accelerators are allocated in whole nodes. +func gpusPerNode(g gpuType) (int, error) { + switch g { + case gpuType1xA10, gpuType1xH100: + return 1, nil + case gpuType8xH100: + return 8, nil + } + // Unreachable: callers resolve g through parseGPUType first, which rejects + // unknown types. Kept as a defensive guard. + return 0, fmt.Errorf("invalid GPU type %q", string(g)) +} + +// computeConfig is the `compute` block of the run YAML: which accelerators to +// use and how many. +type computeConfig struct { + NumAccelerators int `yaml:"num_accelerators"` + AcceleratorType string `yaml:"accelerator_type"` + NodePoolID string `yaml:"node_pool_id"` + PoolName string `yaml:"pool_name"` +} + +// validate checks the compute block against the backend's constraints. +func (c computeConfig) validate() error { + g, err := parseGPUType(c.AcceleratorType) + if err != nil { + return fmt.Errorf("compute.accelerator_type: %w", err) + } + + if c.NumAccelerators <= 0 { + return fmt.Errorf("compute.num_accelerators must be positive, got %d", c.NumAccelerators) + } + + perNode, err := gpusPerNode(g) + if err != nil { + return err + } + if c.NumAccelerators%perNode != 0 { + return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators) + } + + if c.NodePoolID != "" && c.PoolName != "" { + return errors.New("compute: cannot specify both node_pool_id and pool_name") + } + + return nil +} diff --git a/experimental/air/cmd/compute_test.go b/experimental/air/cmd/compute_test.go new file mode 100644 index 00000000000..ad91d470861 --- /dev/null +++ b/experimental/air/cmd/compute_test.go @@ -0,0 +1,89 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseGPUType(t *testing.T) { + tests := []struct { + in string + want gpuType + }{ + {"GPU_1xA10", gpuType1xA10}, + {"GPU_8xH100", gpuType8xH100}, + {"GPU_1xH100", gpuType1xH100}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + got, err := parseGPUType(tt.in) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseGPUTypeInvalid(t *testing.T) { + // Wrong casing is rejected rather than fixed up; legacy types (h100_80gb, a10) + // can no longer be submitted; unknown types are rejected. + for _, in := range []string{"gpu_1xa10", "GPU_1XA10", "GPU_2xH100", "h100_80gb", "a10", "b200", ""} { + t.Run(in, func(t *testing.T) { + _, err := parseGPUType(in) + require.Error(t, err) + assert.Contains(t, err.Error(), "valid types are") + }) + } +} + +func TestGPUsPerNode(t *testing.T) { + tests := []struct { + in gpuType + want int + }{ + {gpuType1xA10, 1}, + {gpuType1xH100, 1}, + {gpuType8xH100, 8}, + } + for _, tt := range tests { + t.Run(string(tt.in), func(t *testing.T) { + got, err := gpusPerNode(tt.in) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } + + _, err := gpusPerNode(gpuType("nonsense")) + require.Error(t, err) +} + +func TestComputeConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg computeConfig + wantErr string // substring; empty means the config is valid + }{ + {"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""}, + {"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""}, + {"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""}, + {"with node pool", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "pool-123"}, ""}, + {"with pool name", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", PoolName: "my-pool"}, ""}, + {"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"}, + {"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"}, + {"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"}, + {"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"}, + {"both pool fields", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "p", PoolName: "n"}, "both"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.validate() + if tt.wantErr == "" { + require.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} From 62be1a1cd17a5fc94139d0a13abdc66d56c493ba Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 18:41:30 +0000 Subject: [PATCH 18/18] experimental/air: drop node pool / pool name compute fields The training compute config no longer supports pool placement, so remove the node_pool_id and pool_name fields and the validation that rejected setting both. Co-authored-by: Isaac --- experimental/air/cmd/compute.go | 7 ------- experimental/air/cmd/compute_test.go | 3 --- 2 files changed, 10 deletions(-) diff --git a/experimental/air/cmd/compute.go b/experimental/air/cmd/compute.go index 02e1916520d..07013c53906 100644 --- a/experimental/air/cmd/compute.go +++ b/experimental/air/cmd/compute.go @@ -1,7 +1,6 @@ package aircmd import ( - "errors" "fmt" "strings" ) @@ -57,8 +56,6 @@ func gpusPerNode(g gpuType) (int, error) { type computeConfig struct { NumAccelerators int `yaml:"num_accelerators"` AcceleratorType string `yaml:"accelerator_type"` - NodePoolID string `yaml:"node_pool_id"` - PoolName string `yaml:"pool_name"` } // validate checks the compute block against the backend's constraints. @@ -80,9 +77,5 @@ func (c computeConfig) validate() error { return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators) } - if c.NodePoolID != "" && c.PoolName != "" { - return errors.New("compute: cannot specify both node_pool_id and pool_name") - } - return nil } diff --git a/experimental/air/cmd/compute_test.go b/experimental/air/cmd/compute_test.go index ad91d470861..3464afbe9ea 100644 --- a/experimental/air/cmd/compute_test.go +++ b/experimental/air/cmd/compute_test.go @@ -67,13 +67,10 @@ func TestComputeConfigValidate(t *testing.T) { {"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""}, {"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""}, {"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""}, - {"with node pool", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "pool-123"}, ""}, - {"with pool name", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", PoolName: "my-pool"}, ""}, {"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"}, {"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"}, {"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"}, {"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"}, - {"both pool fields", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "p", PoolName: "n"}, "both"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {